feat(trtllm): detect stop_words from generation_config.json
This commit is contained in:
parent
6376fecc6c
commit
9cee00eec3
|
@ -82,6 +82,14 @@ namespace huggingface::tgi::backends {
|
||||||
uint64_t seed
|
uint64_t seed
|
||||||
) noexcept;
|
) noexcept;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Attempt to retrieve the stop words from the `eos_token_id` entry of a generation_config.json file
|
||||||
|
* @param generationConfigPath Path to the generation_config.json file to read
|
||||||
|
* @return One single-token stop word per element of the `eos_token_id` array, std::nullopt otherwise (e.g. the file does not exist or the entry is not an array)
|
||||||
|
*/
|
||||||
|
std::optional<std::list<std::vector<TokenId>>>
|
||||||
|
GetStopWordsFromConfig(const std::filesystem::path &generationConfigPath) noexcept;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -103,6 +103,31 @@ tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Attempt to retrieve the stop words declared in a generation_config.json file.
 *
 * Every element of the "eos_token_id" array is turned into a single-token stop word.
 *
 * This function is declared noexcept, so no exception may escape it (that would call
 * std::terminate). Every failure mode is therefore logged and reported as std::nullopt:
 * filesystem errors, malformed JSON, a missing "eos_token_id" key, an entry which is not
 * an array, or an array element which cannot be converted to a token id.
 *
 * @param generationConfigPath Path to the generation_config.json file to read
 * @return One single-token stop word per EOS token id, std::nullopt if none could be read
 */
std::optional<std::list<std::vector<huggingface::tgi::backends::TokenId>>>
huggingface::tgi::backends::GetStopWordsFromConfig(
        const std::filesystem::path &generationConfigPath) noexcept {
    try {
        if (!exists(generationConfigPath)) {
            SPDLOG_INFO("No EOS tokens found, generation_config.json doesn't exist");
            return std::nullopt;
        }

        // allow_exceptions = false: malformed content yields a "discarded" value instead of throwing
        std::ifstream generationConfigFile(generationConfigPath);
        const auto generationConfig = json::parse(generationConfigFile, nullptr, false);
        if (generationConfig.is_discarded()) {
            SPDLOG_INFO("No EOS tokens found, generation_config.json is not a valid JSON document");
            return std::nullopt;
        }

        // Unchecked operator[] access on a const JSON value is not safe when the key is absent,
        // so check for its presence first and use the checked accessor afterwards.
        const auto eosTokenIdPointer = "/eos_token_id"_json_pointer;
        if (!generationConfig.contains(eosTokenIdPointer)) {
            SPDLOG_INFO("No EOS tokens found, generation_config.json has no eos_token_id entry");
            return std::nullopt;
        }

        // Bind by reference: no need to deep-copy the JSON array just to read it
        const auto &eosTokenIds = generationConfig.at(eosTokenIdPointer);
        if (!eosTokenIds.is_array()) {
            SPDLOG_INFO("Invalid EOS tokens entry found (not an array)");
            return std::nullopt;
        }

        SPDLOG_INFO(FMT_STRING("Found {:d} EOS tokens"), eosTokenIds.size());

        std::list<std::vector<huggingface::tgi::backends::TokenId>> stopWords;
        for (const auto &tokenIdObj: eosTokenIds) {
            // get<>() throws if the element is not a number; handled by the catch below
            stopWords.push_back({tokenIdObj.get<tle::TokenIdType>()});
        }
        return stopWords;
    } catch (const std::exception &e) {
        SPDLOG_INFO(FMT_STRING("Failed to read EOS tokens from generation_config.json: {}"), e.what());
    }

    return std::nullopt;
}
|
||||||
|
|
||||||
huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
|
huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
|
||||||
const std::filesystem::path &enginesFolder,
|
const std::filesystem::path &enginesFolder,
|
||||||
const std::filesystem::path &executorWorker
|
const std::filesystem::path &executorWorker
|
||||||
|
@ -125,21 +150,8 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
|
||||||
maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();
|
maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();
|
||||||
|
|
||||||
// Attempt to discover stopWords from the generation_config.json
|
// Attempt to discover stopWords from the generation_config.json
|
||||||
if (auto generationConfigPath = enginesFolder / "generation_config.json"; exists(generationConfigPath)) {
|
const auto generationConfigPath = enginesFolder / "generation_config.json";
|
||||||
const auto generationConfig = json::parse(std::ifstream(generationConfigPath));
|
stopWords = GetStopWordsFromConfig(generationConfigPath).value_or(std::list<std::vector<TokenId>>());
|
||||||
if (const auto eosTokenIds = generationConfig["/eos_token_ids"_json_pointer]; eosTokenIds.is_array()) {
|
|
||||||
SPDLOG_INFO(FMT_STRING("Found {:d} EOS tokens"), eosTokenIds.size());
|
|
||||||
stopWords = std::list<decltype(stopWords)::value_type>(eosTokenIds.size());
|
|
||||||
|
|
||||||
const auto to_single_token = [](const auto tokenIdObj) -> decltype(stopWords)::value_type {
|
|
||||||
return {tokenIdObj.template get<tle::TokenIdType>()};
|
|
||||||
};
|
|
||||||
std::transform(eosTokenIds.cbegin(), eosTokenIds.cend(), stopWords.begin(), to_single_token);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
SPDLOG_INFO("No EOS tokens found, generation_config.json doesn't exist");
|
|
||||||
stopWords = {};
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard("Returned number of requests needs to be consumed")]]
|
[[nodiscard("Returned number of requests needs to be consumed")]]
|
||||||
|
|
Loading…
Reference in New Issue