feat(trtllm): detect stop_words from generation_config.json
This commit is contained in:
parent
6376fecc6c
commit
9cee00eec3
|
@ -82,6 +82,14 @@ namespace huggingface::tgi::backends {
|
|||
uint64_t seed
|
||||
) noexcept;
|
||||
|
||||
/**
|
||||
* Attempt to retrieve the
|
||||
* @param generationConfigPath
|
||||
* @return
|
||||
*/
|
||||
std::optional<std::list<std::vector<TokenId>>>
|
||||
GetStopWordsFromConfig(const std::filesystem::path &generationConfigPath) noexcept;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
|
|
|
@ -103,6 +103,31 @@ tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
|
|||
);
|
||||
}
|
||||
|
||||
std::optional<std::list<std::vector<huggingface::tgi::backends::TokenId>>>
|
||||
huggingface::tgi::backends::GetStopWordsFromConfig(
|
||||
const std::filesystem::path &generationConfigPath) noexcept {
|
||||
if (exists(generationConfigPath)) {
|
||||
const auto generationConfig = json::parse(std::ifstream(generationConfigPath));
|
||||
if (const auto eosTokenIds = generationConfig["/eos_token_id"_json_pointer]; eosTokenIds.is_array()) {
|
||||
SPDLOG_INFO(FMT_STRING("Found {:d} EOS tokens"), eosTokenIds.size());
|
||||
std::list<std::vector<huggingface::tgi::backends::TokenId>> stopWords(eosTokenIds.size());
|
||||
|
||||
const auto to_single_token = [](const auto tokenIdObj) -> decltype(stopWords)::value_type {
|
||||
return {tokenIdObj.template get<tle::TokenIdType>()};
|
||||
};
|
||||
|
||||
std::transform(eosTokenIds.cbegin(), eosTokenIds.cend(), stopWords.begin(), to_single_token);
|
||||
return stopWords;
|
||||
} else {
|
||||
SPDLOG_INFO("Invalid EOS tokens entry found (not an array)");
|
||||
}
|
||||
} else {
|
||||
SPDLOG_INFO("No EOS tokens found, generation_config.json doesn't exist");
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
|
||||
const std::filesystem::path &enginesFolder,
|
||||
const std::filesystem::path &executorWorker
|
||||
|
@ -125,21 +150,8 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
|
|||
maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();
|
||||
|
||||
// Attempt to discover stopWords from the generation_config.json
|
||||
if (auto generationConfigPath = enginesFolder / "generation_config.json"; exists(generationConfigPath)) {
|
||||
const auto generationConfig = json::parse(std::ifstream(generationConfigPath));
|
||||
if (const auto eosTokenIds = generationConfig["/eos_token_ids"_json_pointer]; eosTokenIds.is_array()) {
|
||||
SPDLOG_INFO(FMT_STRING("Found {:d} EOS tokens"), eosTokenIds.size());
|
||||
stopWords = std::list<decltype(stopWords)::value_type>(eosTokenIds.size());
|
||||
|
||||
const auto to_single_token = [](const auto tokenIdObj) -> decltype(stopWords)::value_type {
|
||||
return {tokenIdObj.template get<tle::TokenIdType>()};
|
||||
};
|
||||
std::transform(eosTokenIds.cbegin(), eosTokenIds.cend(), stopWords.begin(), to_single_token);
|
||||
}
|
||||
} else {
|
||||
SPDLOG_INFO("No EOS tokens found, generation_config.json doesn't exist");
|
||||
stopWords = {};
|
||||
}
|
||||
const auto generationConfigPath = enginesFolder / "generation_config.json";
|
||||
stopWords = GetStopWordsFromConfig(generationConfigPath).value_or(std::list<std::vector<TokenId>>());
|
||||
}
|
||||
|
||||
[[nodiscard("Returned number of requests needs to be consumed")]]
|
||||
|
|
Loading…
Reference in New Issue