From d4aee42fd8dc16113c42c1d6032f405717c5794b Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Fri, 1 Nov 2024 00:49:50 +0100 Subject: [PATCH] feat(backend): add logit parameter in the callback fn --- backends/llamacpp/csrc/backend.cpp | 4 +++- backends/llamacpp/csrc/backend.hpp | 4 ++-- backends/llamacpp/csrc/ffi.hpp | 2 +- backends/llamacpp/src/lib.rs | 2 +- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/backends/llamacpp/csrc/backend.cpp b/backends/llamacpp/csrc/backend.cpp index 665f78df..50d5897c 100644 --- a/backends/llamacpp/csrc/backend.cpp +++ b/backends/llamacpp/csrc/backend.cpp @@ -111,6 +111,7 @@ namespace huggingface::tgi::backends::llamacpp { if (LLAMA_SUCCESS(status)) { // Sample the new token auto new_token_id = llama_sampler_sample(sampler.get(), context, -1); + auto new_token_logits = 0.0f; // TODO: return logit auto is_eos = llama_token_is_eog(mModel_.get(), new_token_id); if (!generation_context.generation_params.ignore_eos_token) { @@ -119,7 +120,8 @@ namespace huggingface::tgi::backends::llamacpp { } // Bubble up the generated token if a callback is provided - std::invoke(std::forward(callback_), new_token_id, is_eos); + std::invoke( + std::forward(callback_), new_token_id, new_token_logits, is_eos); batch = llama_batch_get_one(&new_token_id, 1); } diff --git a/backends/llamacpp/csrc/backend.hpp b/backends/llamacpp/csrc/backend.hpp index 44952a5d..288bf36a 100644 --- a/backends/llamacpp/csrc/backend.hpp +++ b/backends/llamacpp/csrc/backend.hpp @@ -26,8 +26,8 @@ namespace huggingface::tgi::backends::llamacpp { static constexpr auto llama_context_deleter = [](llama_context *pContext) { llama_free(pContext); }; typedef std::unique_ptr<llama_context, decltype(llama_context_deleter)> llama_context_smart_ptr; - typedef std::function<void(llama_token, bool)> llama_decode_callback; - static constexpr auto llama_void_callback = [](llama_token, bool) {}; + typedef std::function<void(llama_token, float_t, bool)> llama_decode_callback; + static constexpr auto llama_void_callback = [](llama_token, float_t, bool) {}; /** * diff --git 
a/backends/llamacpp/csrc/ffi.hpp b/backends/llamacpp/csrc/ffi.hpp index 18254114..5c404b01 100644 --- a/backends/llamacpp/csrc/ffi.hpp +++ b/backends/llamacpp/csrc/ffi.hpp @@ -66,7 +66,7 @@ namespace huggingface::tgi::backends::llamacpp { rust::Slice<uint32_t> generated_tokens, const generation_params_t &generation_params, const sampling_params_t &sampling_params, - rust::Fn<void(uint32_t, bool)> callback + rust::Fn<void(uint32_t, float_t, bool)> callback ) { // Define the visitor lambda function which requires the has_emplace_generate constraint on T static auto inner_fw = [=, &generation_params, &sampling_params](T &&backend) diff --git a/backends/llamacpp/src/lib.rs b/backends/llamacpp/src/lib.rs index 33088d54..8d51a15a 100644 --- a/backends/llamacpp/src/lib.rs +++ b/backends/llamacpp/src/lib.rs @@ -51,7 +51,7 @@ mod ffi { generated: &mut [u32], generation_params: &GenerationParams, sampling_params: &SamplingParams, - callback: fn(u32, f32, bool), + callback: fn(u32, f32, bool), ) -> Result<usize>; } }