feat(backend): add number of generated tokens in the callback
parent 188442f67d · commit 05ff551950
@@ -120,8 +120,8 @@ namespace huggingface::tgi::backends::llamacpp {
             }

             // Bubble up the generated token if a callback is provided
-            std::invoke(
-                    std::forward<const llama_decode_callback>(callback_), new_token_id, new_token_logits, is_eos);
+            std::invoke(std::forward<const llama_decode_callback>(callback_),
+                        new_token_id, new_token_logits, is_eos, n_decoded_tokens);

             batch = llama_batch_get_one(&new_token_id, 1);
         }
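For orientation, the n_decoded_tokens value forwarded above has to be maintained by the surrounding decode loop. A minimal self-contained sketch of that pattern, with stub types and a hypothetical sample_next_token helper standing in for the real backend code:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <functional>

using llama_token = int32_t;  // stand-in for llama.cpp's token typedef

// Same shape as the updated callback: token id, logit, eos flag, and the
// number of tokens generated so far.
using decode_callback = std::function<void(llama_token, float, bool, size_t)>;

// Hypothetical sampler stub so the sketch compiles on its own.
static llama_token sample_next_token() { static llama_token t = 0; return ++t; }

void decode_loop(size_t max_new_tokens, const decode_callback &callback) {
    size_t n_decoded_tokens = 0;
    for (;;) {
        const llama_token new_token_id = sample_next_token();
        const float new_token_logits = 0.0f;  // placeholder logit
        const bool is_eos = false;            // a real loop compares against the EOS id
        n_decoded_tokens += 1;

        // Bubble up the generated token along with the running count.
        callback(new_token_id, new_token_logits, is_eos, n_decoded_tokens);

        if (is_eos || n_decoded_tokens >= max_new_tokens) break;
    }
}

int main() {
    decode_loop(3, [](llama_token id, float logit, bool eos, size_t n) {
        std::printf("token=%d logit=%.2f eos=%d n=%zu\n", (int) id, (double) logit, (int) eos, n);
    });
}

Incrementing before invoking the callback means the first token is reported as 1, which matches reporting "how many tokens have been produced so far".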
@@ -29,8 +29,8 @@ namespace huggingface::tgi::backends::llamacpp {
     static constexpr auto llama_sampler_deleter = [](llama_sampler *pSampler) { llama_sampler_free(pSampler); };
     typedef std::unique_ptr<llama_sampler, decltype(llama_sampler_deleter)> llama_sampler_ptr;

-    typedef std::function<void(llama_token, float_t, bool)> llama_decode_callback;
-    static constexpr auto llama_void_callback = [](llama_token, float_t, bool) {};
+    typedef std::function<void(llama_token, float_t, bool, size_t)> llama_decode_callback;
+    static constexpr auto llama_void_callback = [](llama_token, float_t, bool, size_t) {};

     /**
      *
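With the widened typedef, any consumer-supplied llama_decode_callback can observe generation progress directly. A small illustrative sketch of the consumer side (the lambda bodies are assumptions, not taken from this commit):

#include <cmath>      // float_t
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <functional>

using llama_token = int32_t;  // stand-in for llama.cpp's token typedef

// Mirrors the updated typedef: the fourth parameter is the running count.
typedef std::function<void(llama_token, float_t, bool, size_t)> llama_decode_callback;

// Mirrors the updated no-op default, which must now also accept the count.
static constexpr auto llama_void_callback = [](llama_token, float_t, bool, size_t) {};

int main() {
    // Illustrative consumer: report progress from the new size_t argument.
    llama_decode_callback progress = [](llama_token id, float_t logit, bool is_eos, size_t n_generated) {
        std::fprintf(stderr, "[%zu] token=%d logit=%f eos=%d\n",
                     n_generated, (int) id, (double) logit, (int) is_eos);
    };
    progress(42, 0.5f, false, 1);
    llama_void_callback(42, 0.5f, false, 1);  // compiles, does nothing
}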
@@ -68,14 +68,14 @@ namespace huggingface::tgi::backends::llamacpp {
             const generation_params_t generation_params,
             const sampling_params_t &sampling_params,
             OpaqueStream *stream,
-            rust::Fn<void(OpaqueStream *, uint32_t, float_t, bool)> callback
+            rust::Fn<void(OpaqueStream *, uint32_t, float_t, bool, size_t)> callback
     ) {
         // Define the visitor lambda function which requires the has_emplace_generate constraint on T
         auto inner_fw = [=, &sampling_params, &stream, &callback]<has_emplace_generate T>(T &&backend)
                 -> std::expected<size_t, backend_error_t> {

-            auto context_forwarding_callback = [=, &stream](uint32_t new_token_id, float_t logits, bool is_eos){
-                callback(stream, new_token_id, logits, is_eos);
+            auto context_forwarding_callback = [=, &stream](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens){
+                callback(stream, new_token_id, logits, is_eos, n_generated_tokens);
             };

             // Ask the compiler to create view over Rust slice transmuting from uint32_t* to int32_t*
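The bridging pattern in this hunk is worth spelling out: the Rust-side callback carries no state of its own, so the C++ side captures the opaque stream pointer in a lambda and forwards every argument, including the new count, unchanged. A reduced sketch of that shape, where a plain function pointer stands in for rust::Fn (an assumption of this sketch):

#include <cmath>      // float_t
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <functional>

struct OpaqueStream { int id; };  // stand-in for the real opaque stream type

// Plain function pointer standing in for
// rust::Fn<void(OpaqueStream *, uint32_t, float_t, bool, size_t)>.
using rust_callback_t = void (*)(OpaqueStream *, uint32_t, float_t, bool, size_t);

using decode_callback_t = std::function<void(uint32_t, float_t, bool, size_t)>;

// Pair the context-free callback with its stream pointer and forward all
// arguments unchanged; n_generated_tokens simply flows through.
decode_callback_t make_forwarding_callback(OpaqueStream *stream, rust_callback_t callback) {
    return [stream, callback](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens) {
        callback(stream, new_token_id, logits, is_eos, n_generated_tokens);
    };
}

int main() {
    OpaqueStream stream{7};
    auto cb = make_forwarding_callback(&stream,
        [](OpaqueStream *s, uint32_t id, float_t logit, bool eos, size_t n) {
            std::printf("stream=%d token=%u n=%zu eos=%d logit=%f\n",
                        s->id, id, n, (int) eos, (double) logit);
        });
    cb(101, 0.25f, false, 1);
}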
@@ -102,6 +102,7 @@ fn llama_generate_callback(
     new_token_id: u32,
     new_token_logit: f32,
     is_eos: bool,
+    n_generated_tokens: usize,
 ) {
     let response = InferStreamResponse::Intermediate {
         token: Token {
@@ -112,7 +113,7 @@ fn llama_generate_callback(
         },
         top_tokens: vec![],
     };
-    debug!("Generated token: {new_token_id} -> logits={new_token_logit}, is_eos={is_eos}");
+    info!("Generated token: {new_token_id} -> logits={new_token_logit}, is_eos={is_eos} ({n_generated_tokens})");

     unsafe {
         if let Err(ref err) = (*channel).0.send(Ok(response)) {
@@ -70,7 +70,7 @@ mod ffi {
             generation_params: GenerationParams,
             sampling_params: &SamplingParams,
             stream: *mut OpaqueStream,
-            callback: unsafe fn(*mut OpaqueStream, u32, f32, bool),
+            callback: unsafe fn(*mut OpaqueStream, u32, f32, bool, usize),
         ) -> Result<usize>;
     }
 }