feat(backend): add logit parameter in the callback fn

Morgan Funtowicz 2024-11-01 00:49:50 +01:00
parent f39edc72ff
commit d4aee42fd8
4 changed files with 7 additions and 5 deletions

View File

@@ -111,6 +111,7 @@ namespace huggingface::tgi::backends::llamacpp {
         if (LLAMA_SUCCESS(status)) {
             // Sample the new token
             auto new_token_id = llama_sampler_sample(sampler.get(), context, -1);
+            auto new_token_logits = 0.0f; // TODO: return logit
             auto is_eos = llama_token_is_eog(mModel_.get(), new_token_id);

             if (!generation_context.generation_params.ignore_eos_token) {
@@ -119,7 +120,8 @@ namespace huggingface::tgi::backends::llamacpp {
             }

             // Bubble up the generated token if a callback is provided
-            std::invoke(std::forward<const llama_decode_callback>(callback_), new_token_id, is_eos);
+            std::invoke(
+                    std::forward<const llama_decode_callback>(callback_), new_token_id, new_token_logits, is_eos);

             batch = llama_batch_get_one(&new_token_id, 1);
         }
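
The hunk above hard-codes the logit to 0.0f and leaves a TODO. A minimal sketch of one way to resolve it, assuming the sampling index stays at -1 (the last decoded position): llama_get_logits_ith exposes the vocabulary-sized logit row for that position, so indexing it with the sampled id yields the raw logit of the chosen token.

    // Sketch only: fetch the logit row for the position sampled above (-1)
    // and read the raw logit of the token the sampler picked.
    const float *logits = llama_get_logits_ith(context, -1);
    const auto new_token_logits = logits[new_token_id];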

View File

@@ -26,8 +26,8 @@ namespace huggingface::tgi::backends::llamacpp {
     static constexpr auto llama_context_deleter = [](llama_context *pContext) { llama_free(pContext); };
     typedef std::unique_ptr<llama_context, decltype(llama_context_deleter)> llama_context_smart_ptr;

-    typedef std::function<void(llama_token, bool)> llama_decode_callback;
-    static constexpr auto llama_void_callback = [](llama_token, bool) {};
+    typedef std::function<void(llama_token, float_t, bool)> llama_decode_callback;
+    static constexpr auto llama_void_callback = [](llama_token, float_t, bool) {};

     /**
      *
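
Any consumer of llama_decode_callback now has to accept the extra float_t argument. A minimal caller-side sketch against the widened (token, logit, eos) signature; the printf body is illustrative only, not code from this commit:

    // Illustrative callback matching the new three-argument shape.
    const llama_decode_callback log_token = [](llama_token id, float_t logit, bool is_eos) {
        std::printf("token=%d logit=%.4f eos=%d\n", id, logit, is_eos);
    };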

View File

@@ -66,7 +66,7 @@ namespace huggingface::tgi::backends::llamacpp {
             rust::Slice<uint32_t> generated_tokens,
             const generation_params_t &generation_params,
             const sampling_params_t &sampling_params,
-            rust::Fn<void(uint32_t, bool)> callback
+            rust::Fn<void(uint32_t, float_t, bool)> callback
     ) {
         // Define the visitor lambda function which requires the has_emplace_generate constraint on T
         static auto inner_fw = [=, &generation_params, &sampling_params]<has_emplace_generate T>(T &&backend)
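
The shim receives a rust::Fn, so inside the visitor it has to be adapted to the C++-side llama_decode_callback. A hypothetical bridge (the name adapted_callback and the capture list are assumptions, not the commit's code):

    // Hypothetical bridge: wrap the Rust-provided callback into the
    // three-argument C++ callback type introduced above.
    auto adapted_callback = [callback](llama_token token, float_t logit, bool is_eos) {
        callback(static_cast<uint32_t>(token), logit, is_eos);
    };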

View File

@@ -51,7 +51,7 @@ mod ffi {
             generated: &mut [u32],
             generation_params: &GenerationParams,
             sampling_params: &SamplingParams,
-            callback: fn(u32, bool),
+            callback: fn(u32, f32, bool),
         ) -> Result<usize>;
     }
 }