From 05ff551950dad2948f5f8fa10234496179dffd42 Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Sun, 3 Nov 2024 23:07:22 +0100
Subject: [PATCH] feat(backend): add number of generated tokens in the callback

---
 backends/llamacpp/csrc/backend.cpp | 4 ++--
 backends/llamacpp/csrc/backend.hpp | 4 ++--
 backends/llamacpp/csrc/ffi.hpp     | 6 +++---
 backends/llamacpp/src/backend.rs   | 3 ++-
 backends/llamacpp/src/lib.rs       | 2 +-
 5 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/backends/llamacpp/csrc/backend.cpp b/backends/llamacpp/csrc/backend.cpp
index 4b608620..54e41a14 100644
--- a/backends/llamacpp/csrc/backend.cpp
+++ b/backends/llamacpp/csrc/backend.cpp
@@ -120,8 +120,8 @@ namespace huggingface::tgi::backends::llamacpp {
             }
 
             // Bubble up the generated token if a callback is provided
-            std::invoke(
-                    std::forward(callback_), new_token_id, new_token_logits, is_eos);
+            std::invoke(std::forward(callback_),
+                        new_token_id, new_token_logits, is_eos, n_decoded_tokens);
 
             batch = llama_batch_get_one(&new_token_id, 1);
         }
diff --git a/backends/llamacpp/csrc/backend.hpp b/backends/llamacpp/csrc/backend.hpp
index 70f99268..ebae7fb0 100644
--- a/backends/llamacpp/csrc/backend.hpp
+++ b/backends/llamacpp/csrc/backend.hpp
@@ -29,8 +29,8 @@ namespace huggingface::tgi::backends::llamacpp {
     static constexpr auto llama_sampler_deleter = [](llama_sampler *pSampler) { llama_sampler_free(pSampler); };
     typedef std::unique_ptr<llama_sampler, decltype(llama_sampler_deleter)> llama_sampler_ptr;
 
-    typedef std::function<void(llama_token, float_t, bool)> llama_decode_callback;
-    static constexpr auto llama_void_callback = [](llama_token, float_t, bool) {};
+    typedef std::function<void(llama_token, float_t, bool, size_t)> llama_decode_callback;
+    static constexpr auto llama_void_callback = [](llama_token, float_t, bool, size_t) {};
 
     /**
      *
diff --git a/backends/llamacpp/csrc/ffi.hpp b/backends/llamacpp/csrc/ffi.hpp
index 63f8d3b6..df924cb7 100644
--- a/backends/llamacpp/csrc/ffi.hpp
+++ b/backends/llamacpp/csrc/ffi.hpp
@@ -68,14 +68,14 @@ namespace huggingface::tgi::backends::llamacpp {
             const generation_params_t generation_params,
             const sampling_params_t &sampling_params,
             OpaqueStream *stream,
-            rust::Fn<void(OpaqueStream *, uint32_t, float_t, bool)> callback
+            rust::Fn<void(OpaqueStream *, uint32_t, float_t, bool, size_t)> callback
     ) {
         // Define the visitor lambda function which requires the has_emplace_generate constraint on T
         auto inner_fw = [=, &sampling_params, &stream, &callback](T &&backend)
                 -> std::expected {
 
-            auto context_forwarding_callback = [=, &stream](uint32_t new_token_id, float_t logits, bool is_eos){
-                callback(stream, new_token_id, logits, is_eos);
+            auto context_forwarding_callback = [=, &stream](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens){
+                callback(stream, new_token_id, logits, is_eos, n_generated_tokens);
             };
 
             // Ask the compiler to create view over Rust slice transmuting from uint32_t* to int32_t*
diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
index bfdac34b..c3fff697 100644
--- a/backends/llamacpp/src/backend.rs
+++ b/backends/llamacpp/src/backend.rs
@@ -102,6 +102,7 @@ fn llama_generate_callback(
     new_token_id: u32,
     new_token_logit: f32,
     is_eos: bool,
+    n_generated_tokens: usize,
 ) {
     let response = InferStreamResponse::Intermediate {
         token: Token {
@@ -112,7 +113,7 @@ fn llama_generate_callback(
         },
         top_tokens: vec![],
     };
-    debug!("Generated token: {new_token_id} -> logits={new_token_logit}, is_eos={is_eos}");
+    info!("Generated token: {new_token_id} -> logits={new_token_logit}, is_eos={is_eos} ({n_generated_tokens})");
 
     unsafe {
         if let Err(ref err) = (*channel).0.send(Ok(response)) {
diff --git a/backends/llamacpp/src/lib.rs b/backends/llamacpp/src/lib.rs
index f923526f..277f77cb 100644
--- a/backends/llamacpp/src/lib.rs
+++ b/backends/llamacpp/src/lib.rs
@@ -70,7 +70,7 @@ mod ffi {
             generation_params: GenerationParams,
            sampling_params: &SamplingParams,
            stream: *mut OpaqueStream,
-            callback: unsafe fn(*mut OpaqueStream, u32, f32, bool),
+            callback: unsafe fn(*mut OpaqueStream, u32, f32, bool, usize),
        ) -> Result;
    }
}
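
Note on usage (not part of the patch itself): the new trailing usize argument carries the number of tokens the worker has decoded so far, so the Rust receiver can log progress or enforce a generation budget without keeping its own counter. The standalone Rust sketch below illustrates a receiver for the extended callback shape; the names on_token and MAX_NEW_TOKENS are hypothetical and do not exist in this codebase.

// Standalone sketch (hypothetical names, not TGI code): a receiver for the
// extended callback signature (token_id, logit, is_eos, n_generated_tokens).
const MAX_NEW_TOKENS: usize = 8; // illustrative generation budget

// Returns true when the caller should stop feeding tokens.
fn on_token(new_token_id: u32, new_token_logit: f32, is_eos: bool, n_generated_tokens: usize) -> bool {
    println!(
        "Generated token: {new_token_id} -> logits={new_token_logit}, is_eos={is_eos} ({n_generated_tokens})"
    );
    // The counter arrives with the token, so the budget check needs no local state.
    is_eos || n_generated_tokens >= MAX_NEW_TOKENS
}

fn main() {
    // Fake decode loop standing in for the C++ worker that drives the callback.
    let fake_stream = [(42_u32, -0.31_f32, false), (7, -1.20, false), (2, -0.05, true)];
    for (i, (token_id, logit, is_eos)) in fake_stream.into_iter().enumerate() {
        if on_token(token_id, logit, is_eos, i + 1) {
            break;
        }
    }
}

Run as-is, it prints one line per fake token and stops on the EOS entry, mirroring what llama_generate_callback can now report directly from n_generated_tokens instead of deriving the count on the Rust side.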