feat(backend): add mapping for ignore_eos_token stopping criteria

This commit is contained in:
Morgan Funtowicz 2024-10-31 21:32:29 +01:00
parent 3af2c6837c
commit f39edc72ff
3 changed files with 7 additions and 3 deletions

View File

@ -113,8 +113,10 @@ namespace huggingface::tgi::backends::llamacpp {
// Sample the next token id from the sampler chain for the current context.
auto new_token_id = llama_sampler_sample(sampler.get(), context, -1);
// Ask the model whether the sampled token is an end-of-generation token.
auto is_eos = llama_token_is_eog(mModel_.get(), new_token_id);
// (pre-change lines retained by the diff rendering: token stored and
// `generating` updated unconditionally)
generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
generating = !is_eos;
// NOTE(review): this gate looks inverted. When ignore_eos_token is TRUE the
// sampled token is never stored and `generating` is never updated — the
// opposite of "ignore EOS and keep generating". Expected mapping is to store
// the token unconditionally and compute:
//   generating = ignore_eos_token ? true : !is_eos;
// Confirm intended semantics against the Rust-side stopping criteria.
if (!generation_context.generation_params.ignore_eos_token) {
generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
generating = !is_eos;
}
// Bubble up the generated token if a callback is provided
std::invoke(std::forward<const llama_decode_callback>(callback_), new_token_id, is_eos);

View File

@ -27,7 +27,7 @@ namespace huggingface::tgi::backends::llamacpp {
typedef std::unique_ptr<llama_context, decltype(llama_context_deleter)> llama_context_smart_ptr;
typedef std::function<void(llama_token, bool)> llama_decode_callback;
static constexpr auto llama_void_callback = [](llama_token token_id, bool is_eos) {};
static constexpr auto llama_void_callback = [](llama_token, bool) {};
/**
*
@ -59,6 +59,7 @@ namespace huggingface::tgi::backends::llamacpp {
*/
// Stopping criteria for a generation request; mirrors the Rust-side
// `GenerationParams` FFI struct field-for-field.
struct generation_params_t {
// Hard cap on the number of tokens to generate; defaults to "effectively unbounded".
uint32_t max_new_tokens = std::numeric_limits<uint32_t>::max();
// When true, an end-of-generation token should not stop decoding
// (see the backend decode loop for how this flag is consumed).
bool ignore_eos_token = false;
};
struct generation_context_t {

View File

@ -18,6 +18,7 @@ impl Default for SamplingParams {
mod ffi {
// FFI mirror of the C++ `generation_params_t` stopping criteria; fields must
// stay in sync with the C++ declaration.
struct GenerationParams {
// Maximum number of new tokens to generate.
max_new_tokens: u32,
// If true, sampling an EOS token should not terminate generation.
ignore_eos_token: bool,
}
struct SamplingParams {