feat(backend): add mapping for ignore_eos_token stopping criteria
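When the flag is set, sampling an end-of-generation token no longer terminates the decode loop; the per-token callback still receives is_eos so callers can observe it. The flag is carried in generation_params_t on the C++ side and in the FFI GenerationParams struct on the Rust side.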
parent 3af2c6837c
commit f39edc72ff
@@ -113,8 +113,10 @@ namespace huggingface::tgi::backends::llamacpp {
             auto new_token_id = llama_sampler_sample(sampler.get(), context, -1);
             auto is_eos = llama_token_is_eog(mModel_.get(), new_token_id);

-            generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
-            generating = !is_eos;
+            if (!generation_context.generation_params.ignore_eos_token) {
+                generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
+                generating = !is_eos;
+            }

             // Bubble up the generated token if a callback is provided
             std::invoke(std::forward<const llama_decode_callback>(callback_), new_token_id, is_eos);
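Note that the callback is still invoked with is_eos on every decoding step, so callers can observe end-of-generation tokens even while ignore_eos_token keeps the loop running. A minimal caller-side sketch (the helper name and counters are illustrative and not part of this commit; it only assumes llama_token from llama.h and the llama_decode_callback signature declared in the next hunk):

#include <cstddef>
#include <functional>
#include <vector>
#include <llama.h>  // llama_token

// Hypothetical helper: builds a callback compatible with
// std::function<void(llama_token, bool)> that records every sampled token
// and counts how many were flagged as end-of-generation. With
// ignore_eos_token set, eos_seen can grow past 1 because decoding continues.
std::function<void(llama_token, bool)> make_counting_callback(
        std::vector<llama_token> &streamed, std::size_t &eos_seen) {
    return [&streamed, &eos_seen](llama_token token_id, bool is_eos) {
        streamed.push_back(token_id);
        if (is_eos) ++eos_seen;
    };
}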
@@ -27,7 +27,7 @@ namespace huggingface::tgi::backends::llamacpp {
    typedef std::unique_ptr<llama_context, decltype(llama_context_deleter)> llama_context_smart_ptr;

    typedef std::function<void(llama_token, bool)> llama_decode_callback;
-   static constexpr auto llama_void_callback = [](llama_token token_id, bool is_eos) {};
+   static constexpr auto llama_void_callback = [](llama_token, bool) {};

    /**
     *
@@ -59,6 +59,7 @@ namespace huggingface::tgi::backends::llamacpp {
     */
    struct generation_params_t {
        uint32_t max_new_tokens = std::numeric_limits<uint32_t>::max();
+       bool ignore_eos_token = false;
    };

    struct generation_context_t {
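As a usage sketch for the struct above (values are illustrative): once EOS is ignored, max_new_tokens becomes the only remaining stopping criterion, so it should be set to a finite value rather than its std::numeric_limits<uint32_t>::max() default.

#include <cstdint>
#include <limits>

// Local mirror of generation_params_t, repeated only to keep this sketch
// self-contained.
struct generation_params_t {
    uint32_t max_new_tokens = std::numeric_limits<uint32_t>::max();
    bool ignore_eos_token = false;
};

// Illustrative configuration: EOS never ends the request, so decoding runs
// until max_new_tokens tokens have been produced.
generation_params_t make_eos_free_params() {
    generation_params_t params;
    params.max_new_tokens = 128;     // a finite bound is required once EOS is ignored
    params.ignore_eos_token = true;  // do not stop on end-of-generation tokens
    return params;
}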
@@ -18,6 +18,7 @@ impl Default for SamplingParams {
mod ffi {
    struct GenerationParams {
        max_new_tokens: u32,
+       ignore_eos_token: bool,
    }

    struct SamplingParams {
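On the Rust side, GenerationParams mirrors the C++ generation_params_t field for field (max_new_tokens, ignore_eos_token), which lets the new flag be forwarded across the FFI boundary as-is.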