implement the Stream method to send new tokens through a callback

This commit is contained in:
Morgan Funtowicz 2024-07-09 13:46:48 +00:00
parent 09292b06a0
commit 13eabfabcb
2 changed files with 44 additions and 16 deletions

View File

@ -20,6 +20,8 @@ namespace tle = tensorrt_llm::executor;
namespace huggingface::tgi::backends {
using TokenStreamingCallback = void(tle::TokenIdType);
/**
* Initialize all the components required by TRTLLM.
* It is required to call this function before attempting to load any engine
@ -28,9 +30,8 @@ namespace huggingface::tgi::backends {
/**
*
* @param config
* @param workerPath
* @param channel
* @param config TensorRT-LLM configuration object
* @param workerPath Path to the "executorWorker" provided by TensorRT-LLM when using orchestrator mode
* @return
*/
tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath);
@ -58,7 +59,7 @@ namespace huggingface::tgi::backends {
}
/***
*
* Submit a new generation task to the executor
* @param tokens
* @param maxNewTokens
* @param topK
@ -69,7 +70,7 @@ namespace huggingface::tgi::backends {
* @param frequencyPenalty
* @param seed
* @param nTopTokens
* @return
* @return Request id related to this generation for reference
*/
[[nodiscard]] tle::IdType Submit(
const std::vector<tle::TokenIdType> &tokens,
@ -85,13 +86,15 @@ namespace huggingface::tgi::backends {
);
/***
*
* @param reqId
* @return
* Unroll the token generation until end of stream is reached.
* Every generated token is streamed back through the provided callback for further processing
* @param reqId The request id to unroll
* @param cb The callback invoked once per generated token to stream tokens back
* @return Global number of generated tokens for this request id
*/
std::vector<tle::Response> Poll(tle::IdType reqId);
size_t Stream(tle::IdType reqId, const std::function<TokenStreamingCallback>& cb);
};
}
#endif //TGI_TRTLLM_BACKEND_H
#endif //TGI_TRTLLM_BACKEND_H

View File

@ -10,6 +10,7 @@ void huggingface::tgi::backends::InitializeBackend() {
initTrtLlmPlugins();
}
[[nodiscard]]
tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) {
tle::ExecutorConfig execConfig(1);
@ -64,6 +65,7 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string&>());
}
[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
const std::vector<tle::TokenIdType> &tokens,
const int32_t maxNewTokens,
@ -76,8 +78,8 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
std::optional<uint32_t> seed,
std::optional<uint32_t> nTopTokens
) {
spdlog::debug(
FMT_STRING("Submitting inference over {:d} tokens to the executor {:d}"),
SPDLOG_DEBUG(
FMT_STRING("Submitting inference over {:d} tokens to the executor ({:d} already in-flight)"),
tokens.size(),
executor.getLatestIterationStats().back().numActiveRequests
);
@ -103,8 +105,31 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
return executor.enqueueRequest(request);
}
std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType reqId) {
SPDLOG_DEBUG(FMT_STRING("Polling request {:d}"), reqId);
const auto responses = executor.awaitResponses(reqId);
return responses;
/**
 * Unroll token generation for the given request until end of stream is reached,
 * forwarding every generated token to the provided callback.
 * @param reqId Request id previously returned by Submit()
 * @param cb    Callback invoked once per generated token
 * @return Total number of tokens generated for this request
 */
size_t huggingface::tgi::backends::TensorRtLlmBackend::Stream(const tle::IdType reqId, const std::function<TokenStreamingCallback>& cb) {
    bool isFinal = false;
    size_t generatedTokens = 0;
    do {
        // Blocks until at least one response for this request id is available
        const auto responses = executor.awaitResponses(reqId);
        for (const auto &response: responses) {
            if (response.hasError()) {
                SPDLOG_WARN("Caught error during generation: {}", response.getErrorMsg());
                isFinal = true;
            } else {
                const auto generation = response.getResult();
                // NOTE(review): only the first token of the first beam is streamed —
                // assumes single-beam streaming generation; confirm against Submit() config
                const auto token = generation.outputTokenIds[0][0];

                // Accumulate the end-of-stream flag instead of overwriting it: a terminal
                // state observed on an earlier response in this batch (error or isFinal)
                // must not be reset by a later non-final response, which previously could
                // make the do/while loop spin forever.
                isFinal = isFinal || generation.isFinal;
                ++generatedTokens;

                // Send the token back through the callback function for further processing
                cb(token);
            }
        }
    } while (!isFinal);

    // Return the overall number of generated tokens for this request
    return generatedTokens;
}