feat(backend): add mapping for ignore_eos_token stopping criteria
This commit is contained in:
parent
3af2c6837c
commit
f39edc72ff
|
@ -113,8 +113,10 @@ namespace huggingface::tgi::backends::llamacpp {
|
||||||
auto new_token_id = llama_sampler_sample(sampler.get(), context, -1);
|
auto new_token_id = llama_sampler_sample(sampler.get(), context, -1);
|
||||||
auto is_eos = llama_token_is_eog(mModel_.get(), new_token_id);
|
auto is_eos = llama_token_is_eog(mModel_.get(), new_token_id);
|
||||||
|
|
||||||
generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
|
if (!generation_context.generation_params.ignore_eos_token) {
|
||||||
generating = !is_eos;
|
generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
|
||||||
|
generating = !is_eos;
|
||||||
|
}
|
||||||
|
|
||||||
// Bubble up the generated token if a callback is provided
|
// Bubble up the generated token if a callback is provided
|
||||||
std::invoke(std::forward<const llama_decode_callback>(callback_), new_token_id, is_eos);
|
std::invoke(std::forward<const llama_decode_callback>(callback_), new_token_id, is_eos);
|
||||||
|
|
|
@ -27,7 +27,7 @@ namespace huggingface::tgi::backends::llamacpp {
|
||||||
typedef std::unique_ptr<llama_context, decltype(llama_context_deleter)> llama_context_smart_ptr;
|
typedef std::unique_ptr<llama_context, decltype(llama_context_deleter)> llama_context_smart_ptr;
|
||||||
|
|
||||||
typedef std::function<void(llama_token, bool)> llama_decode_callback;
|
typedef std::function<void(llama_token, bool)> llama_decode_callback;
|
||||||
static constexpr auto llama_void_callback = [](llama_token token_id, bool is_eos) {};
|
static constexpr auto llama_void_callback = [](llama_token, bool) {};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -59,6 +59,7 @@ namespace huggingface::tgi::backends::llamacpp {
|
||||||
*/
|
*/
|
||||||
struct generation_params_t {
|
struct generation_params_t {
|
||||||
uint32_t max_new_tokens = std::numeric_limits<uint32_t>::max();
|
uint32_t max_new_tokens = std::numeric_limits<uint32_t>::max();
|
||||||
|
bool ignore_eos_token = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct generation_context_t {
|
struct generation_context_t {
|
||||||
|
|
|
@ -18,6 +18,7 @@ impl Default for SamplingParams {
|
||||||
mod ffi {
|
mod ffi {
|
||||||
struct GenerationParams {
|
struct GenerationParams {
|
||||||
max_new_tokens: u32,
|
max_new_tokens: u32,
|
||||||
|
ignore_eos_token: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SamplingParams {
|
struct SamplingParams {
|
||||||
|
|
Loading…
Reference in New Issue