feat(backend): make eog clearer on c++ side
This commit is contained in:
parent
06424aa9ff
commit
11c593dc69
|
@ -95,7 +95,7 @@ namespace huggingface::tgi::backends::llamacpp {
|
|||
|
||||
// Decode
|
||||
auto n_decoded_tokens = 0;
|
||||
for (bool generating = true; generating && n_decoded_tokens < max_new_tokens; ++n_decoded_tokens) {
|
||||
for (bool generating = true; generating; ++n_decoded_tokens) {
|
||||
const auto callback_ = callback.value_or(llama_void_callback);
|
||||
|
||||
#ifdef TGI_LLAMACPP_BACKEND_DEBUG
|
||||
|
@ -108,24 +108,27 @@ namespace huggingface::tgi::backends::llamacpp {
|
|||
const auto status = llama_decode(context, batch);
|
||||
#endif
|
||||
batch.n_tokens = 0;
|
||||
if (LLAMA_SUCCESS(status)) {
|
||||
if (LLAMA_SUCCESS(status)) [[likely]] {
|
||||
// Sample the new token
|
||||
auto new_token_id = llama_sampler_sample(sampler.get(), context, -1);
|
||||
auto is_eog = llama_token_is_eog(mModel_.get(), new_token_id);
|
||||
auto new_token_logits = 0.0f; // TODO: return logit
|
||||
auto is_eos = llama_token_is_eog(mModel_.get(), new_token_id);
|
||||
auto effective_n_decoded_tokens = n_decoded_tokens + 1;
|
||||
|
||||
if (!generation_context.generation_params.ignore_eos_token) {
|
||||
generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
|
||||
generating = !is_eos;
|
||||
}
|
||||
// Push the token to the generated vector on Rust side
|
||||
generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
|
||||
|
||||
// Handle termination cases
|
||||
const auto has_reach_max_tokens = n_decoded_tokens >= max_new_tokens - 1;
|
||||
const auto has_reach_eog = !generation_context.generation_params.ignore_eos_token & is_eog;
|
||||
|
||||
generating = !(has_reach_max_tokens | has_reach_eog);
|
||||
|
||||
// Bubble up the generated token if a callback is provided
|
||||
std::invoke(std::forward<const llama_decode_callback>(callback_),
|
||||
new_token_id,
|
||||
new_token_logits,
|
||||
is_eos || effective_n_decoded_tokens == max_new_tokens,
|
||||
effective_n_decoded_tokens);
|
||||
!generating,
|
||||
n_decoded_tokens + 1);
|
||||
|
||||
batch = llama_batch_get_one(&new_token_id, 1);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue