feat(trtllm): detect stop_words from generation_config.json

This commit is contained in:
Morgan Funtowicz 2024-10-23 16:05:59 +02:00
parent 6376fecc6c
commit 9cee00eec3
2 changed files with 35 additions and 15 deletions

View File

@ -82,6 +82,14 @@ namespace huggingface::tgi::backends {
uint64_t seed
) noexcept;
/**
* Attempt to retrieve the
* @param generationConfigPath
* @return
*/
std::optional<std::list<std::vector<TokenId>>>
GetStopWordsFromConfig(const std::filesystem::path &generationConfigPath) noexcept;
/**
*
*/

View File

@ -103,6 +103,31 @@ tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
);
}
std::optional<std::list<std::vector<huggingface::tgi::backends::TokenId>>>
huggingface::tgi::backends::GetStopWordsFromConfig(
const std::filesystem::path &generationConfigPath) noexcept {
if (exists(generationConfigPath)) {
const auto generationConfig = json::parse(std::ifstream(generationConfigPath));
if (const auto eosTokenIds = generationConfig["/eos_token_id"_json_pointer]; eosTokenIds.is_array()) {
SPDLOG_INFO(FMT_STRING("Found {:d} EOS tokens"), eosTokenIds.size());
std::list<std::vector<huggingface::tgi::backends::TokenId>> stopWords(eosTokenIds.size());
const auto to_single_token = [](const auto tokenIdObj) -> decltype(stopWords)::value_type {
return {tokenIdObj.template get<tle::TokenIdType>()};
};
std::transform(eosTokenIds.cbegin(), eosTokenIds.cend(), stopWords.begin(), to_single_token);
return stopWords;
} else {
SPDLOG_INFO("Invalid EOS tokens entry found (not an array)");
}
} else {
SPDLOG_INFO("No EOS tokens found, generation_config.json doesn't exist");
}
return std::nullopt;
}
huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
const std::filesystem::path &enginesFolder,
const std::filesystem::path &executorWorker
@ -125,21 +150,8 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();
// Attempt to discover stopWords from the generation_config.json
if (auto generationConfigPath = enginesFolder / "generation_config.json"; exists(generationConfigPath)) {
const auto generationConfig = json::parse(std::ifstream(generationConfigPath));
if (const auto eosTokenIds = generationConfig["/eos_token_ids"_json_pointer]; eosTokenIds.is_array()) {
SPDLOG_INFO(FMT_STRING("Found {:d} EOS tokens"), eosTokenIds.size());
stopWords = std::list<decltype(stopWords)::value_type>(eosTokenIds.size());
const auto to_single_token = [](const auto tokenIdObj) -> decltype(stopWords)::value_type {
return {tokenIdObj.template get<tle::TokenIdType>()};
};
std::transform(eosTokenIds.cbegin(), eosTokenIds.cend(), stopWords.begin(), to_single_token);
}
} else {
SPDLOG_INFO("No EOS tokens found, generation_config.json doesn't exist");
stopWords = {};
}
const auto generationConfigPath = enginesFolder / "generation_config.json";
stopWords = GetStopWordsFromConfig(generationConfigPath).value_or(std::list<std::vector<TokenId>>());
}
[[nodiscard("Returned number of requests needs to be consumed")]]