From 9cee00eec356a5b073bba7f113c8a981f20ae573 Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz <morgan@huggingface.co>
Date: Wed, 23 Oct 2024 16:05:59 +0200
Subject: [PATCH] feat(trtllm): detect stop_words from generation_config.json

---
 backends/trtllm/include/backend.h |  8 ++++++++
 backends/trtllm/lib/backend.cpp   | 42 +++++++++++++++++++++++---------------
 2 files changed, 35 insertions(+), 15 deletions(-)

diff --git a/backends/trtllm/include/backend.h b/backends/trtllm/include/backend.h
index dee83e22..d23f6288 100644
--- a/backends/trtllm/include/backend.h
+++ b/backends/trtllm/include/backend.h
@@ -82,6 +82,14 @@ namespace huggingface::tgi::backends {
             uint64_t seed
     ) noexcept;
 
+    /**
+     * Attempt to retrieve the stop words (EOS token ids) from the generation_config.json file, if present.
+     * @param generationConfigPath Path to the generation_config.json file to parse
+     * @return List of single-token stop sequences; std::nullopt if the file is missing or has no valid eos_token_id array
+     */
+    std::optional<std::list<std::vector<TokenId>>>
+    GetStopWordsFromConfig(const std::filesystem::path &generationConfigPath) noexcept;
+
     /**
      *
      */
diff --git a/backends/trtllm/lib/backend.cpp b/backends/trtllm/lib/backend.cpp
index 3fda9d62..ad22b0c7 100644
--- a/backends/trtllm/lib/backend.cpp
+++ b/backends/trtllm/lib/backend.cpp
@@ -103,6 +103,31 @@ tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
     );
 }
 
+std::optional<std::list<std::vector<huggingface::tgi::backends::TokenId>>>
+huggingface::tgi::backends::GetStopWordsFromConfig(
+        const std::filesystem::path &generationConfigPath) noexcept {
+    if (exists(generationConfigPath)) {
+        const auto generationConfig = json::parse(std::ifstream(generationConfigPath));
+        if (const auto eosTokenIds = generationConfig["/eos_token_id"_json_pointer]; eosTokenIds.is_array()) {
+            SPDLOG_INFO(FMT_STRING("Found {:d} EOS tokens"), eosTokenIds.size());
+            std::list<std::vector<TokenId>> stopWords(eosTokenIds.size());
+
+            const auto to_single_token = [](const auto tokenIdObj) -> decltype(stopWords)::value_type {
+                return {tokenIdObj.template get<TokenId>()};
+            };
+
+            std::transform(eosTokenIds.cbegin(), eosTokenIds.cend(), stopWords.begin(), to_single_token);
+            return stopWords;
+        } else {
+            SPDLOG_INFO("Invalid EOS tokens entry found (not an array)");
+        }
+    } else {
+        SPDLOG_INFO("No EOS tokens found, generation_config.json doesn't exist");
+    }
+
+    return std::nullopt;
+}
+
 huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
         const std::filesystem::path &enginesFolder,
         const std::filesystem::path &executorWorker
@@ -125,21 +150,8 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
     maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();
 
     // Attempt to discover stopWords from the generation_config.json
-    if (auto generationConfigPath = enginesFolder / "generation_config.json"; exists(generationConfigPath)) {
-        const auto generationConfig = json::parse(std::ifstream(generationConfigPath));
-        if (const auto eosTokenIds = generationConfig["/eos_token_ids"_json_pointer]; eosTokenIds.is_array()) {
-            SPDLOG_INFO(FMT_STRING("Found {:d} EOS tokens"), eosTokenIds.size());
-            stopWords = std::list<decltype(stopWords)::value_type>(eosTokenIds.size());
-
-            const auto to_single_token = [](const auto tokenIdObj) -> decltype(stopWords)::value_type {
-                return {tokenIdObj.template get<TokenId>()};
-            };
-            std::transform(eosTokenIds.cbegin(), eosTokenIds.cend(), stopWords.begin(), to_single_token);
-        }
-    } else {
-        SPDLOG_INFO("No EOS tokens found, generation_config.json doesn't exist");
-        stopWords = {};
-    }
+    const auto generationConfigPath = enginesFolder / "generation_config.json";
+    stopWords = GetStopWordsFromConfig(generationConfigPath).value_or(std::list<std::vector<TokenId>>());
 }
 
 [[nodiscard("Returned number of requests needs to be consumed")]]