From f39edc72ff4eaa3226d3ea469ebad6c107dfd5cb Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Thu, 31 Oct 2024 21:32:29 +0100 Subject: [PATCH] feat(backend): add mapping for ignore_eos_token stopping criteria --- backends/llamacpp/csrc/backend.cpp | 6 ++++-- backends/llamacpp/csrc/backend.hpp | 3 ++- backends/llamacpp/src/lib.rs | 1 + 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/backends/llamacpp/csrc/backend.cpp b/backends/llamacpp/csrc/backend.cpp index f2f5d4c6..665f78df 100644 --- a/backends/llamacpp/csrc/backend.cpp +++ b/backends/llamacpp/csrc/backend.cpp @@ -113,8 +113,10 @@ namespace huggingface::tgi::backends::llamacpp { auto new_token_id = llama_sampler_sample(sampler.get(), context, -1); auto is_eos = llama_token_is_eog(mModel_.get(), new_token_id); - generation_context.generated_tokens[n_decoded_tokens] = new_token_id; - generating = !is_eos; + if (!generation_context.generation_params.ignore_eos_token) { + generation_context.generated_tokens[n_decoded_tokens] = new_token_id; + generating = !is_eos; + } // Bubble up the generated token if a callback is provided std::invoke(std::forward(callback_), new_token_id, is_eos); diff --git a/backends/llamacpp/csrc/backend.hpp b/backends/llamacpp/csrc/backend.hpp index 871490f2..44952a5d 100644 --- a/backends/llamacpp/csrc/backend.hpp +++ b/backends/llamacpp/csrc/backend.hpp @@ -27,7 +27,7 @@ namespace huggingface::tgi::backends::llamacpp { typedef std::unique_ptr llama_context_smart_ptr; typedef std::function llama_decode_callback; - static constexpr auto llama_void_callback = [](llama_token token_id, bool is_eos) {}; + static constexpr auto llama_void_callback = [](llama_token, bool) {}; /** * @@ -59,6 +59,7 @@ namespace huggingface::tgi::backends::llamacpp { */ struct generation_params_t { uint32_t max_new_tokens = std::numeric_limits::max(); + bool ignore_eos_token = false; }; struct generation_context_t { diff --git a/backends/llamacpp/src/lib.rs b/backends/llamacpp/src/lib.rs index 9fb79501..33088d54 100644 --- a/backends/llamacpp/src/lib.rs +++ b/backends/llamacpp/src/lib.rs @@ -18,6 +18,7 @@ impl Default for SamplingParams { mod ffi { struct GenerationParams { max_new_tokens: u32, + ignore_eos_token: bool, } struct SamplingParams {