feat(backend): correctly handle the max_new_tokens case for is_eos

Morgan Funtowicz 2024-11-03 23:50:46 +01:00
parent 05ff551950
commit 06424aa9ff
1 changed file with 5 additions and 1 deletion


@@ -113,6 +113,7 @@ namespace huggingface::tgi::backends::llamacpp {
                 auto new_token_id = llama_sampler_sample(sampler.get(), context, -1);
                 auto new_token_logits = 0.0f; // TODO: return logit
                 auto is_eos = llama_token_is_eog(mModel_.get(), new_token_id);
+                auto effective_n_decoded_tokens = n_decoded_tokens + 1;

                 if (!generation_context.generation_params.ignore_eos_token) {
                     generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
@@ -121,7 +122,10 @@ namespace huggingface::tgi::backends::llamacpp {

                 // Bubble up the generated token if a callback is provided
                 std::invoke(std::forward<const llama_decode_callback>(callback_),
-                            new_token_id, new_token_logits, is_eos, n_decoded_tokens);
+                            new_token_id,
+                            new_token_logits,
+                            is_eos || effective_n_decoded_tokens == max_new_tokens,
+                            effective_n_decoded_tokens);

                 batch = llama_batch_get_one(&new_token_id, 1);
             }
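
The net effect for a callback consumer: the boolean forwarded in the is_eos position is now true either when the sampler emits an end-of-generation token or when the max_new_tokens budget is exhausted, and the reported token count is the 1-based effective count. Below is a minimal standalone sketch of that contract; all helper names (decode_one_token, on_token, generate) are hypothetical and not part of the TGI backend or llama.cpp.

    #include <cstdint>
    #include <cstdio>
    #include <functional>
    #include <utility>

    using token_id = int32_t;

    // Callback shape mirroring the diff: (token, logit, is_last, n_decoded_tokens).
    using decode_callback = std::function<void(token_id, float, bool, std::size_t)>;

    // Stand-in for llama_sampler_sample + llama_token_is_eog: returns the next
    // token id and whether it is an end-of-generation token. Purely illustrative.
    static std::pair<token_id, bool> decode_one_token(std::size_t step) {
        return {static_cast<token_id>(1000 + step), step == 7};
    }

    static void generate(std::size_t max_new_tokens, const decode_callback &on_token) {
        std::size_t n_decoded_tokens = 0;
        while (n_decoded_tokens < max_new_tokens) {
            const auto [new_token_id, is_eos] = decode_one_token(n_decoded_tokens);
            const float new_token_logits = 0.0f;  // logits not surfaced yet, as in the diff
            const std::size_t effective_n_decoded_tokens = n_decoded_tokens + 1;

            // The point of the patch: the "last token" flag also fires when the
            // max_new_tokens budget is spent, not only on an end-of-generation token.
            const bool is_last = is_eos || effective_n_decoded_tokens == max_new_tokens;
            on_token(new_token_id, new_token_logits, is_last, effective_n_decoded_tokens);

            if (is_last) break;
            n_decoded_tokens = effective_n_decoded_tokens;
        }
    }

    int main() {
        generate(4, [](token_id tok, float, bool is_last, std::size_t n) {
            std::printf("token %d (#%zu)%s\n", static_cast<int>(tok), n, is_last ? " <last>" : "");
        });
        return 0;
    }

With a budget of 4 tokens the fourth callback invocation arrives with the flag already set, so a streaming consumer can close the response on that single signal instead of separately tracking the budget.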