implement the Stream method to send new tokens through a callback
This commit is contained in:
parent
09292b06a0
commit
13eabfabcb
|
@ -20,6 +20,8 @@ namespace tle = tensorrt_llm::executor;
|
|||
|
||||
namespace huggingface::tgi::backends {
|
||||
|
||||
using TokenStreamingCallback = void(tle::TokenIdType);
|
||||
|
||||
/**
|
||||
* Initialize all the components required by TRTLLM.
|
||||
* It is required to call this function before attempting to load any engine
|
||||
|
@ -28,9 +30,8 @@ namespace huggingface::tgi::backends {
|
|||
|
||||
/**
|
||||
*
|
||||
* @param config
|
||||
* @param workerPath
|
||||
* @param channel
|
||||
* @param config TensorRT-LLM configuration object
|
||||
* @param workerPath Path to the "executorWorker" provided by TensorRT-LLM when using orchestrator mode
|
||||
* @return
|
||||
*/
|
||||
tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath);
|
||||
|
@ -58,7 +59,7 @@ namespace huggingface::tgi::backends {
|
|||
}
|
||||
|
||||
/***
|
||||
*
|
||||
* Submit a new generation task to the executor
|
||||
* @param tokens
|
||||
* @param maxNewTokens
|
||||
* @param topK
|
||||
|
@ -69,7 +70,7 @@ namespace huggingface::tgi::backends {
|
|||
* @param frequencyPenalty
|
||||
* @param seed
|
||||
* @param nTopTokens
|
||||
* @return
|
||||
* @return Request id related to this generation for reference
|
||||
*/
|
||||
[[nodiscard]] tle::IdType Submit(
|
||||
const std::vector<tle::TokenIdType> &tokens,
|
||||
|
@ -85,11 +86,13 @@ namespace huggingface::tgi::backends {
|
|||
);
|
||||
|
||||
/***
|
||||
*
|
||||
* @param reqId
|
||||
* @return
|
||||
* Unroll the token generation until end of stream is reached.
|
||||
* Every generated token is streamed back through the provided callback for further processing
|
||||
* @param reqId The request id to unroll
|
||||
* @param cb The callback to stream token back
|
||||
* @return Global number of generated tokens for this request id
|
||||
*/
|
||||
std::vector<tle::Response> Poll(tle::IdType reqId);
|
||||
size_t Stream(tle::IdType reqId, const std::function<TokenStreamingCallback>& cb);
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ void huggingface::tgi::backends::InitializeBackend() {
|
|||
initTrtLlmPlugins();
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) {
|
||||
tle::ExecutorConfig execConfig(1);
|
||||
|
||||
|
@ -64,6 +65,7 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
|
|||
SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string&>());
|
||||
}
|
||||
|
||||
[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
|
||||
tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
|
||||
const std::vector<tle::TokenIdType> &tokens,
|
||||
const int32_t maxNewTokens,
|
||||
|
@ -76,8 +78,8 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
|
|||
std::optional<uint32_t> seed,
|
||||
std::optional<uint32_t> nTopTokens
|
||||
) {
|
||||
spdlog::debug(
|
||||
FMT_STRING("Submitting inference over {:d} tokens to the executor {:d}"),
|
||||
SPDLOG_DEBUG(
|
||||
FMT_STRING("Submitting inference over {:d} tokens to the executor ({:d} already in-flight)"),
|
||||
tokens.size(),
|
||||
executor.getLatestIterationStats().back().numActiveRequests
|
||||
);
|
||||
|
@ -103,8 +105,31 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
|
|||
return executor.enqueueRequest(request);
|
||||
}
|
||||
|
||||
std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType reqId) {
|
||||
SPDLOG_DEBUG(FMT_STRING("Polling request {:d}"), reqId);
|
||||
size_t huggingface::tgi::backends::TensorRtLlmBackend::Stream(const tle::IdType reqId, const std::function<TokenStreamingCallback>& cb) {
|
||||
bool isFinal = false;
|
||||
size_t generatedTokens = 0;
|
||||
|
||||
do {
|
||||
const auto responses = executor.awaitResponses(reqId);
|
||||
return responses;
|
||||
for (const auto &response: responses){
|
||||
if(response.hasError()) {
|
||||
SPDLOG_WARN("Caught error during generation: {}", response.getErrorMsg());
|
||||
isFinal = true;
|
||||
} else {
|
||||
const auto generation = response.getResult();
|
||||
const auto token = generation.outputTokenIds[0][0];
|
||||
|
||||
// Update the end of stream detection and overall number of generated tokens
|
||||
isFinal = generation.isFinal;
|
||||
++generatedTokens;
|
||||
|
||||
// Send the token back through the callback function for further processing
|
||||
cb(token);
|
||||
}
|
||||
}
|
||||
|
||||
} while(!isFinal);
|
||||
|
||||
// Return the number of generated tokens
|
||||
return generatedTokens;
|
||||
}
|
Loading…
Reference in New Issue