diff --git a/backends/trtllm/include/backend.h b/backends/trtllm/include/backend.h
index b5d0711b..97ab3063 100644
--- a/backends/trtllm/include/backend.h
+++ b/backends/trtllm/include/backend.h
@@ -20,6 +20,8 @@
 namespace tle = tensorrt_llm::executor;
 
 namespace huggingface::tgi::backends {
+    using TokenStreamingCallback = void(tle::TokenIdType);
+
     /**
      * Initialize all the components required by TRTLLM.
      * It is required to call this function before attempting to load any engine
@@ -28,9 +30,8 @@ namespace huggingface::tgi::backends {
 
     /**
      *
-     * @param config
-     * @param workerPath
-     * @param channel
+     * @param config TensorRT-LLM configuration object
+     * @param workerPath Path to the "executorWorker" provided by TensorRT-LLM when using orchestrator mode
      * @return
      */
     tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath);
@@ -58,7 +59,7 @@ namespace huggingface::tgi::backends {
         }
 
         /***
-         *
+         * Submit a new generation task to the executor
          * @param tokens
          * @param maxNewTokens
         * @param topK
@@ -69,7 +70,7 @@ namespace huggingface::tgi::backends {
         * @param frequencyPenalty
         * @param seed
         * @param nTopTokens
-         * @return
+         * @return Request id related to this generation for reference
         */
        [[nodiscard]] tle::IdType Submit(
                const std::vector<tle::TokenIdType> &tokens,
@@ -85,13 +86,15 @@ namespace huggingface::tgi::backends {
        );
 
        /***
-         *
-         * @param reqId
-         * @return
+         * Unroll the token generation until end of stream is reached.
+         * Every generated token is streamed back through the provided callback for further processing
+         * @param reqId The request id to unroll
+         * @param cb The callback to stream token back
+         * @return Global number of generated tokens for this request id
         */
-        std::vector<tle::Response> Poll(tle::IdType reqId);
+        size_t Stream(tle::IdType reqId, const std::function<TokenStreamingCallback>& cb);
 
    };
 }
 
-#endif //TGI_TRTLLM_BACKEND_H
+#endif //TGI_TRTLLM_BACKEND_H
\ No newline at end of file
diff --git a/backends/trtllm/lib/backend.cpp b/backends/trtllm/lib/backend.cpp
index 1a3f598a..0f058128 100644
--- a/backends/trtllm/lib/backend.cpp
+++ b/backends/trtllm/lib/backend.cpp
@@ -10,6 +10,7 @@
 void huggingface::tgi::backends::InitializeBackend() {
     initTrtLlmPlugins();
 }
 
+[[nodiscard]]
 tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) {
     tle::ExecutorConfig execConfig(1);
@@ -64,6 +65,7 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
     SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
 }
 
+[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
 tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
         const std::vector<tle::TokenIdType> &tokens,
         const int32_t maxNewTokens,
@@ -76,8 +78,8 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
         std::optional<uint64_t> seed,
         std::optional<uint32_t> nTopTokens
 ) {
-    spdlog::debug(
-        FMT_STRING("Submitting inference over {:d} tokens to the executor {:d}"),
+    SPDLOG_DEBUG(
+        FMT_STRING("Submitting inference over {:d} tokens to the executor ({:d} already in-flight)"),
         tokens.size(),
         executor.getLatestIterationStats().back().numActiveRequests
     );
@@ -103,8 +105,31 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
     return executor.enqueueRequest(request);
 }
 
-std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType reqId) {
-    SPDLOG_DEBUG(FMT_STRING("Polling request {:d}"), reqId);
-    const auto responses = executor.awaitResponses(reqId);
-    return responses;
+size_t
+huggingface::tgi::backends::TensorRtLlmBackend::Stream(const tle::IdType reqId, const std::function<TokenStreamingCallback>& cb) {
+    bool isFinal = false;
+    size_t generatedTokens = 0;
+
+    do {
+        const auto responses = executor.awaitResponses(reqId);
+        for (const auto &response: responses){
+            if(response.hasError()) {
+                SPDLOG_WARN("Caught error during generation: {}", response.getErrorMsg());
+                isFinal = true;
+            } else {
+                const auto generation = response.getResult();
+                const auto token = generation.outputTokenIds[0][0];
+
+                // Update the end of stream detection and overall number of generated tokens
+                isFinal = generation.isFinal;
+                ++generatedTokens;
+
+                // Send the token back through the callback function for further processing
+                cb(token);
+            }
+        }
+
+    } while(!isFinal);
+
+    // Return the number of generated tokens
+    return generatedTokens;
 }
\ No newline at end of file
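
A minimal usage sketch of the new Submit/Stream pair (not part of the diff). The backend constructor arguments, engine paths, sampling values and token ids below are assumptions for illustration only:

// Hypothetical caller: submit a tokenized prompt, then stream the generated token ids back
// through a callback until end of stream is reached.
#include <iostream>
#include <vector>

#include "backend.h"  // assumed include path for backends/trtllm/include/backend.h

int main() {
    using namespace huggingface::tgi::backends;

    InitializeBackend();

    // Assumption: the constructor takes the engines folder and the executorWorker path
    TensorRtLlmBackend backend("/path/to/engines", "/path/to/executorWorker");

    // Pre-tokenized prompt (placeholder token ids)
    const std::vector<tle::TokenIdType> tokens = {1, 15043, 3186};

    // Enqueue the request and keep the id: it is the only handle to the generated tokens
    const auto reqId = backend.Submit(tokens, /* maxNewTokens = */ 128, /* topK = */ 10, /* topP = */ 0.95,
                                      /* temperature = */ 1.0, /* repetitionPenalty = */ 1.0,
                                      /* frequencyPenalty = */ 0.0, /* seed = */ 2024, /* nTopTokens = */ 1);

    // Unroll the generation: each token id is handed to the callback as soon as it is produced
    const auto generated = backend.Stream(reqId, [](tle::TokenIdType token) {
        std::cout << "token id: " << token << std::endl;
    });

    std::cout << "request " << reqId << " produced " << generated << " tokens" << std::endl;
    return 0;
}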