feat(backend): make eog clearer on c++ side

Morgan Funtowicz 2024-11-04 00:11:55 +01:00
parent 06424aa9ff
commit 11c593dc69
1 changed file with 13 additions and 10 deletions


@@ -95,7 +95,7 @@ namespace huggingface::tgi::backends::llamacpp {
 
         // Decode
         auto n_decoded_tokens = 0;
-        for (bool generating = true; generating && n_decoded_tokens < max_new_tokens; ++n_decoded_tokens) {
+        for (bool generating = true; generating; ++n_decoded_tokens) {
            const auto callback_ = callback.value_or(llama_void_callback);
 
 #ifdef TGI_LLAMACPP_BACKEND_DEBUG
@@ -108,24 +108,27 @@ namespace huggingface::tgi::backends::llamacpp {
            const auto status = llama_decode(context, batch);
 #endif
            batch.n_tokens = 0;
-            if (LLAMA_SUCCESS(status)) {
+            if (LLAMA_SUCCESS(status)) [[likely]] {
                // Sample the new token
                auto new_token_id = llama_sampler_sample(sampler.get(), context, -1);
+                auto is_eog = llama_token_is_eog(mModel_.get(), new_token_id);
                auto new_token_logits = 0.0f; // TODO: return logit
-                auto is_eos = llama_token_is_eog(mModel_.get(), new_token_id);
-                auto effective_n_decoded_tokens = n_decoded_tokens + 1;
-                if (!generation_context.generation_params.ignore_eos_token) {
-                    generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
-                    generating = !is_eos;
-                }
+
+                // Push the token to the generated vector on Rust side
+                generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
+
+                // Handle termination cases
+                const auto has_reach_max_tokens = n_decoded_tokens >= max_new_tokens - 1;
+                const auto has_reach_eog = !generation_context.generation_params.ignore_eos_token & is_eog;
+                generating = !(has_reach_max_tokens | has_reach_eog);
 
                // Bubble up the generated token if a callback is provided
                std::invoke(std::forward<const llama_decode_callback>(callback_),
                            new_token_id,
                            new_token_logits,
-                            is_eos || effective_n_decoded_tokens == max_new_tokens,
-                            effective_n_decoded_tokens);
+                            !generating,
+                            n_decoded_tokens + 1);
 
                batch = llama_batch_get_one(&new_token_id, 1);
            }
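
For readers skimming the diff: the net effect is that every sampled token is now stored unconditionally, and the loop's single generating flag is derived from two explicit stop conditions, the token budget (max_new_tokens) and the end-of-generation token (the latter only when ignore_eos_token is unset); the callback then receives !generating as its "last token" argument. The snippet below is a minimal, self-contained sketch of that termination logic only; keep_generating is a hypothetical helper written for illustration, not part of the backend's API.

#include <cstddef>
#include <cstdio>

// Standalone sketch (assumed helper, not the backend's code): derive the
// single `generating` flag from both stop conditions after each decode step.
static bool keep_generating(std::size_t n_decoded_tokens,
                            std::size_t max_new_tokens,
                            bool is_eog,
                            bool ignore_eos_token) {
    const bool has_reached_max_tokens = n_decoded_tokens >= max_new_tokens - 1;
    const bool has_reached_eog = !ignore_eos_token && is_eog;
    return !(has_reached_max_tokens || has_reached_eog);
}

int main() {
    // The callback would receive the negation of this flag as its "last token" argument.
    std::printf("%d\n", keep_generating(0, 16, false, false));   // 1 -> keep decoding
    std::printf("%d\n", keep_generating(3, 16, true, false));    // 0 -> EOG token reached
    std::printf("%d\n", keep_generating(3, 16, true, true));     // 1 -> EOG ignored on request
    std::printf("%d\n", keep_generating(15, 16, false, false));  // 0 -> token budget exhausted
    return 0;
}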