implement the Stream method to send new tokens through a callback

Morgan Funtowicz 2024-07-09 13:46:48 +00:00
parent 09292b06a0
commit 13eabfabcb
2 changed files with 44 additions and 16 deletions

View File

@@ -20,6 +20,8 @@ namespace tle = tensorrt_llm::executor;
 namespace huggingface::tgi::backends {
+    using TokenStreamingCallback = void(tle::TokenIdType);
+
     /**
      * Initialize all the components required by TRTLLM.
      * It is required to call this function before attempting to load any engine
@@ -28,9 +30,8 @@ namespace huggingface::tgi::backends {
     /**
      *
-     * @param config
-     * @param workerPath
-     * @param channel
+     * @param config TensorRT-LLM configuration object
+     * @param workerPath Path to the "executorWorker" provided by TensorRT-LLM when using orchestrator mode
      * @return
      */
     tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath);
@@ -58,7 +59,7 @@ namespace huggingface::tgi::backends {
         }

         /***
-         *
+         * Submit a new generation task to the executor
          * @param tokens
          * @param maxNewTokens
          * @param topK
@@ -69,7 +70,7 @@ namespace huggingface::tgi::backends {
          * @param frequencyPenalty
          * @param seed
          * @param nTopTokens
-         * @return
+         * @return Request id related to this generation for reference
          */
         [[nodiscard]] tle::IdType Submit(
                 const std::vector<tle::TokenIdType> &tokens,
@@ -85,11 +86,13 @@ namespace huggingface::tgi::backends {
         );

         /***
-         *
-         * @param reqId
-         * @return
+         * Unroll the token generation until end of stream is reached.
+         * Every generated token is streamed back through the provided callback for further processing
+         * @param reqId The request id to unroll
+         * @param cb The callback to stream token back
+         * @return Global number of generated tokens for this request id
          */
-        std::vector<tle::Response> Poll(tle::IdType reqId);
+        size_t Stream(tle::IdType reqId, const std::function<TokenStreamingCallback>& cb);
     };
 }
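
For reference, the new TokenStreamingCallback alias declared above is a plain function signature, so Stream accepts anything convertible to std::function<void(tle::TokenIdType)>: a free function, a lambda, or a bound member. A minimal caller-side sketch, with hypothetical names and assuming the header above is included:

    // Hypothetical illustration only: both forms below bind to the cb parameter of Stream.
    #include <cstdio>
    #include <functional>

    // A free function matching TokenStreamingCallback = void(tle::TokenIdType).
    void PrintToken(tle::TokenIdType token) {
        std::printf("token id: %d\n", static_cast<int>(token));
    }

    void CallbackForms() {
        using huggingface::tgi::backends::TokenStreamingCallback;
        const std::function<TokenStreamingCallback> fromFunction = PrintToken;
        const std::function<TokenStreamingCallback> fromLambda =
                [](tle::TokenIdType token) { std::printf("token id: %d\n", static_cast<int>(token)); };
        (void) fromFunction;
        (void) fromLambda;
    }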

View File

@@ -10,6 +10,7 @@ void huggingface::tgi::backends::InitializeBackend() {
     initTrtLlmPlugins();
 }

+[[nodiscard]]
 tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) {
     tle::ExecutorConfig execConfig(1);
@@ -64,6 +65,7 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
     SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string&>());
 }

+[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
 tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
         const std::vector<tle::TokenIdType> &tokens,
         const int32_t maxNewTokens,
@@ -76,8 +78,8 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
         std::optional<uint32_t> seed,
         std::optional<uint32_t> nTopTokens
 ) {
-    spdlog::debug(
-        FMT_STRING("Submitting inference over {:d} tokens to the executor {:d}"),
+    SPDLOG_DEBUG(
+        FMT_STRING("Submitting inference over {:d} tokens to the executor ({:d} already in-flight)"),
         tokens.size(),
         executor.getLatestIterationStats().back().numActiveRequests
     );
@@ -103,8 +105,31 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
     return executor.enqueueRequest(request);
 }

-std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType reqId) {
-    SPDLOG_DEBUG(FMT_STRING("Polling request {:d}"), reqId);
-    const auto responses = executor.awaitResponses(reqId);
-    return responses;
+size_t huggingface::tgi::backends::TensorRtLlmBackend::Stream(const tle::IdType reqId, const std::function<TokenStreamingCallback>& cb) {
+    bool isFinal = false;
+    size_t generatedTokens = 0;
+
+    do {
+        const auto responses = executor.awaitResponses(reqId);
+        for (const auto &response: responses) {
+            if (response.hasError()) {
+                SPDLOG_WARN("Caught error during generation: {}", response.getErrorMsg());
+                isFinal = true;
+            } else {
+                const auto generation = response.getResult();
+                const auto token = generation.outputTokenIds[0][0];
+
+                // Update the end of stream detection and overall number of generated tokens
+                isFinal = generation.isFinal;
+                ++generatedTokens;
+
+                // Send the token back through the callback function for further processing
+                cb(token);
+            }
+        }
+    } while (!isFinal);
+
+    // Return the number of generated tokens
+    return generatedTokens;
 }
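
As a usage note, a caller might drive the new Submit/Stream pair roughly as sketched below. This is a hypothetical example, not part of the commit: the backend construction and the full Submit argument list are not shown in the diff above, so only the Stream call is spelled out.

    #include <cstddef>
    #include <vector>

    // Hypothetical helper: collect every streamed token id into a caller-owned buffer.
    // Assumes `backend` was constructed elsewhere and `reqId` was returned by Submit(...).
    size_t CollectAllTokens(huggingface::tgi::backends::TensorRtLlmBackend &backend,
                            tle::IdType reqId,
                            std::vector<tle::TokenIdType> &out) {
        return backend.Stream(reqId, [&out](tle::TokenIdType token) {
            // Each token arrives as soon as the executor produces it; a real caller would
            // typically decode it and forward it to the client instead of buffering.
            out.push_back(token);
        });
    }

The returned value is the total number of generated tokens for the request, matching the @return documentation added in the header.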