feat(backend): add mapping for ignore_eos_token stopping criteria

This commit is contained in:
Morgan Funtowicz 2024-10-31 21:32:29 +01:00
parent 3af2c6837c
commit f39edc72ff
3 changed files with 7 additions and 3 deletions

View File

@@ -113,8 +113,10 @@ namespace huggingface::tgi::backends::llamacpp {
auto new_token_id = llama_sampler_sample(sampler.get(), context, -1); auto new_token_id = llama_sampler_sample(sampler.get(), context, -1);
auto is_eos = llama_token_is_eog(mModel_.get(), new_token_id); auto is_eos = llama_token_is_eog(mModel_.get(), new_token_id);
generation_context.generated_tokens[n_decoded_tokens] = new_token_id; if (!generation_context.generation_params.ignore_eos_token) {
generating = !is_eos; generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
generating = !is_eos;
}
// Bubble up the generated token if a callback is provided // Bubble up the generated token if a callback is provided
std::invoke(std::forward<const llama_decode_callback>(callback_), new_token_id, is_eos); std::invoke(std::forward<const llama_decode_callback>(callback_), new_token_id, is_eos);

View File

@@ -27,7 +27,7 @@ namespace huggingface::tgi::backends::llamacpp {
typedef std::unique_ptr<llama_context, decltype(llama_context_deleter)> llama_context_smart_ptr; typedef std::unique_ptr<llama_context, decltype(llama_context_deleter)> llama_context_smart_ptr;
typedef std::function<void(llama_token, bool)> llama_decode_callback; typedef std::function<void(llama_token, bool)> llama_decode_callback;
static constexpr auto llama_void_callback = [](llama_token token_id, bool is_eos) {}; static constexpr auto llama_void_callback = [](llama_token, bool) {};
/** /**
* *
@@ -59,6 +59,7 @@ namespace huggingface::tgi::backends::llamacpp {
*/ */
struct generation_params_t { struct generation_params_t {
uint32_t max_new_tokens = std::numeric_limits<uint32_t>::max(); uint32_t max_new_tokens = std::numeric_limits<uint32_t>::max();
bool ignore_eos_token = false;
}; };
struct generation_context_t { struct generation_context_t {

View File

@@ -18,6 +18,7 @@ impl Default for SamplingParams {
mod ffi { mod ffi {
struct GenerationParams { struct GenerationParams {
max_new_tokens: u32, max_new_tokens: u32,
ignore_eos_token: bool,
} }
struct SamplingParams { struct SamplingParams {