feat(backend): correctly handle the max_new_tokens case for is_eos
This commit is contained in:
parent
05ff551950
commit
06424aa9ff
|
@ -113,6 +113,7 @@ namespace huggingface::tgi::backends::llamacpp {
|
|||
auto new_token_id = llama_sampler_sample(sampler.get(), context, -1);
|
||||
auto new_token_logits = 0.0f; // TODO: return logit
|
||||
auto is_eos = llama_token_is_eog(mModel_.get(), new_token_id);
|
||||
auto effective_n_decoded_tokens = n_decoded_tokens + 1;
|
||||
|
||||
if (!generation_context.generation_params.ignore_eos_token) {
|
||||
generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
|
||||
|
@ -121,7 +122,10 @@ namespace huggingface::tgi::backends::llamacpp {
|
|||
|
||||
// Bubble up the generated token if a callback is provided
|
||||
std::invoke(std::forward<const llama_decode_callback>(callback_),
|
||||
new_token_id, new_token_logits, is_eos, n_decoded_tokens);
|
||||
new_token_id,
|
||||
new_token_logits,
|
||||
is_eos || effective_n_decoded_tokens == max_new_tokens,
|
||||
effective_n_decoded_tokens);
|
||||
|
||||
batch = llama_batch_get_one(&new_token_id, 1);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue