diff --git a/backends/trtllm/include/backend.h b/backends/trtllm/include/backend.h
index 90df0fae..b19b5d7e 100644
--- a/backends/trtllm/include/backend.h
+++ b/backends/trtllm/include/backend.h
@@ -58,6 +58,12 @@ namespace huggingface::tgi::backends {
          */
         [[nodiscard]] bool IsReady() const;
 
+        /***
+         * Query the executor for the number of tokens available for pulling
+         * @return
+         */
+        [[nodiscard]] size_t NumResponsesReady() const;
+
         /***
          * Submit a new generation task to the executor
          * @param tokens
@@ -70,7 +76,6 @@ namespace huggingface::tgi::backends {
          */
         [[nodiscard]] RequestId Submit(
                 const std::vector &tokens,
-                int32_t maxNewTokens,
                 int32_t topK,
                 float_t topP,
                 float_t temperature,
@@ -84,15 +89,6 @@ namespace huggingface::tgi::backends {
          */
         std::vector Poll(RequestId requestId);
 
-        /***
-         * Unroll the token generation until end of stream is reached.
-         * Every generated token is streamed back through the provided callback for further processing
-         * @param reqId The request id to unroll
-         * @param cb The callback to stream token back
-         * @return Global number of generated tokens for this request id
-         */
-        uint32_t Stream(RequestId reqId, std::function &cb);
-
         /***
          * Stop the underlying executor
          */
diff --git a/backends/trtllm/lib/backend.cpp b/backends/trtllm/lib/backend.cpp
index 1db73651..aca718c4 100644
--- a/backends/trtllm/lib/backend.cpp
+++ b/backends/trtllm/lib/backend.cpp
@@ -1,7 +1,7 @@
 #include 
 #include 
 
-#include "fmt/format.h"
+#include 
 
 #include "backend.h"
 
@@ -72,20 +72,31 @@ bool huggingface::tgi::backends::TensorRtLlmBackend::IsReady() const {
     return executor.canEnqueueRequests();
 }
 
+size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const {
+    return executor.getNumResponsesReady();
+}
+
 [[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
 tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
         const std::vector &tokens,
-        const int32_t maxNewTokens,
         const int32_t topK,
         const float_t topP,
         const float_t temperature,
         const uint64_t seed
 ) {
-    SPDLOG_DEBUG(
+#ifndef NDEBUG
+    SPDLOG_INFO(
         FMT_STRING("Submitting inference over {:d} tokens to the executor ({:d} already in-flight)"),
         tokens.size(),
         executor.getLatestIterationStats().back().numActiveRequests
     );
+#else
+    SPDLOG_INFO(
+        FMT_STRING("Submitting inference [{}] to the executor ({:d} already in-flight)"),
+        fmt::join(tokens, ", "),
+        executor.getLatestIterationStats().back().numActiveRequests
+    );
+#endif
 
     const auto sampling = tle::SamplingConfig{
             1,
@@ -100,40 +111,13 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
             std::nullopt,
     };
     const auto output = tle::OutputConfig{false, false, false};
-    return executor.enqueueRequest(tle::Request{tokens, maxNewTokens, true, sampling, output});
-}
-
-uint32_t huggingface::tgi::backends::TensorRtLlmBackend::Stream(const tle::IdType reqId,
-                                                                std::function &cb) {
-    bool isFinal = false;
-    size_t generatedTokens = 0;
-
-    do {
-        const auto responses = executor.awaitResponses(reqId);
-        for (const auto &response: responses) {
-            if (response.hasError()) {
-                SPDLOG_WARN("Caught error during generation: {}", response.getErrorMsg());
-                isFinal = true;
-            } else {
-                const auto generation = response.getResult();
-                const auto token = generation.outputTokenIds[0][0];
-
-                // Update the end of stream detection and overall number of generated tokens
-                isFinal = generation.isFinal;
-                ++generatedTokens;
-
-                // Send the token back through the callback function for further processing
-                cb(token);
-            }
-        }
-
-    } while (!isFinal);
-
-    // Return the number of generated tokens
-    return generatedTokens;
+    return executor.enqueueRequest(
+            tle::Request{tokens, std::numeric_limits::max(), true, sampling, output});
 }
 
+[[nodiscard("Generated tokens result must be used")]]
 std::vector huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType requestId) {
+    SPDLOG_INFO("Polling status for request {}", requestId);
     return executor.awaitResponses(requestId);
 }
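
Note for reviewers: with Stream() removed, the generation loop now has to be driven by the caller through NumResponsesReady() and Poll(). The following is a minimal, single-request sketch of what that loop could look like, written only against the methods and response accessors visible in this diff and assuming the tle namespace alias from backend.h; the function name DrainRequest, the polling interval, and the include list are illustrative assumptions, not part of the change.

#include <chrono>
#include <thread>

#include <spdlog/spdlog.h>

#include "backend.h"

// Illustrative only: pull every response for `requestId` without blocking inside
// the backend, mirroring the behaviour of the deleted Stream() helper.
size_t DrainRequest(huggingface::tgi::backends::TensorRtLlmBackend &backend,
                    const tle::IdType requestId) {
    size_t generatedTokens = 0;
    bool isFinal = false;

    while (!isFinal) {
        // Only call Poll() once the executor reports responses ready for pulling.
        if (backend.NumResponsesReady() == 0) {
            std::this_thread::sleep_for(std::chrono::microseconds(100));
            continue;
        }

        for (const auto &response: backend.Poll(requestId)) {
            if (response.hasError()) {
                SPDLOG_WARN("Caught error during generation: {}", response.getErrorMsg());
                isFinal = true;
            } else {
                const auto generation = response.getResult();
                // generation.outputTokenIds[0][0] holds the newly generated token.
                isFinal = generation.isFinal;
                ++generatedTokens;
            }
        }
    }

    return generatedTokens;
}

The trade-off is that the backend itself no longer owns a blocking per-token callback loop: NumResponsesReady() lets the caller decide when pulling is worthwhile, and Poll() simply forwards awaitResponses() for a single request id.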