feat(backend): add early stopping criteria from TGI stream callback

Morgan Funtowicz 2024-11-04 17:01:22 +01:00
parent 958c72a44a
commit 1473259f84
5 changed files with 23 additions and 18 deletions
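Summary of the change: the per-token decode callback now returns a bool (true meaning "stop generating"), and that value is propagated from the Rust stream callback through the FFI layer into the llama.cpp decode loop, so the backend can stop as soon as the client stream goes away. The sketch below only illustrates the contract; it is not code from this repository, and the names decode_callback_t and decode_loop are invented for the example.

// Minimal sketch of the early-stopping contract introduced by this commit
// (illustrative only; types and function names below are not from the repository).
#include <cstddef>
#include <cstdint>
#include <functional>

using token_id_t = std::int32_t;

// After this commit the decode callback returns a bool: true means "stop now".
using decode_callback_t = std::function<bool(token_id_t, float, bool, std::size_t)>;

// Hypothetical decode loop showing how the callback's return value ends generation
// early, in addition to the usual max-tokens / end-of-generation checks.
std::size_t decode_loop(std::size_t max_new_tokens, const decode_callback_t &callback) {
    std::size_t n_decoded = 0;
    bool generating = true;
    while (generating && n_decoded < max_new_tokens) {
        const token_id_t new_token = 0;  // placeholder: a real backend samples a token here
        const float logit = 0.0f;        // placeholder logit for the sampled token
        ++n_decoded;
        const bool is_final = (n_decoded == max_new_tokens);
        generating = !is_final;
        // Bubble the token up; a true return value requests early stopping.
        const bool should_stop = callback(new_token, logit, is_final, n_decoded);
        if (should_stop) {
            generating = false;
        }
    }
    return n_decoded;
}

In the actual diff below, the same idea appears as the bool-returning llama_decode_callback typedef and the `generating ^= should_stop;` update in the decode loop.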


@@ -121,11 +121,12 @@ namespace huggingface::tgi::backends::llamacpp {
                 generating = !(has_reach_max_tokens | has_reach_eog);
                 // Bubble up the generated token if a callback is provided
-                std::invoke(std::forward<const llama_decode_callback>(callback_),
-                            new_token_id,
-                            new_token_logits,
-                            !generating,
-                            n_decoded_tokens + 1);
+                const auto should_stop = std::invoke(std::forward<const llama_decode_callback>(callback_),
+                                                     new_token_id,
+                                                     new_token_logits,
+                                                     !generating,
+                                                     n_decoded_tokens + 1);
+                generating ^= should_stop;
                 batch = llama_batch_get_one(&new_token_id, 1);
             }
@@ -148,11 +149,12 @@ namespace huggingface::tgi::backends::llamacpp {
         // TODO: Should we provide a way to change this value?
         auto generated = std::vector<llama_token>(2 << 8);
         auto inner_callback = [&](uint32_t new_token_id, float_t new_token_logit, bool is_eos,
-                                  size_t num_generated_tokens) {
+                                  size_t num_generated_tokens) -> bool {
             generated.emplace_back(new_token_id);
             if (callback.has_value())
-                (*callback)(new_token_id, new_token_logit, is_eos, num_generated_tokens);
+                return (*callback)(new_token_id, new_token_logit, is_eos, num_generated_tokens);
+            return true;
         };
         auto nTokensGenerated = stream(tokens, generation_params, sampling_params, inner_callback);


@@ -29,8 +29,8 @@ namespace huggingface::tgi::backends::llamacpp {
     static constexpr auto llama_sampler_deleter = [](llama_sampler *pSampler) { llama_sampler_free(pSampler); };
     typedef std::unique_ptr<llama_sampler, decltype(llama_sampler_deleter)> llama_sampler_ptr;
-    typedef std::function<void(llama_token, float_t, bool, size_t)> llama_decode_callback;
-    static constexpr auto llama_void_callback = [](llama_token, float_t, bool, size_t) {};
+    typedef std::function<bool(llama_token, float_t, bool, size_t)> llama_decode_callback;
+    static constexpr auto llama_void_callback = [](llama_token, float_t, bool, size_t) -> bool { return false; };
     /**
      *
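As a hedged usage illustration of the new bool-returning callback type (hypothetical caller code, not from this repository): a caller can now request early stopping from its own callback, while returning false keeps the previous keep-generating behaviour, which is what the updated llama_void_callback default does.

// Hedged usage sketch for the bool-returning callback type above. The alias is
// restated locally so the snippet is self-contained; llama_token stands in for
// llama.cpp's int32_t token id, float stands in for float_t, and the 16-token
// cap is an arbitrary example, not repository code.
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

using llama_token = std::int32_t;
using llama_decode_callback = std::function<bool(llama_token, float, bool, std::size_t)>;

int main() {
    std::vector<llama_token> collected;
    // Collect every generated token and request early stopping after 16 of them.
    llama_decode_callback stop_after_16 =
            [&collected](llama_token id, float /*logit*/, bool /*is_final*/, std::size_t n_generated) -> bool {
                collected.push_back(id);
                return n_generated >= 16;  // true == should_stop
            };
    // The backend would invoke this once per decoded token; returning false keeps generating.
    return stop_after_16(42, 0.0f, false, 1) ? 1 : 0;
}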


@@ -64,14 +64,14 @@ namespace huggingface::tgi::backends::llamacpp {
             const generation_params_t generation_params,
             const sampling_params_t &sampling_params,
             InferContext *ctx,
-            rust::Fn<void(InferContext *, uint32_t, float_t, bool, size_t)> callback
+            rust::Fn<bool(InferContext *, uint32_t, float_t, bool, size_t)> callback
     ) {
         // Define the visitor lambda function which requires the has_emplace_generate constraint on T
         auto inner_fw = [=, &sampling_params, &ctx, &callback]<has_stream_method T>(T &&backend)
                 -> std::expected<size_t, backend_error_t> {
-            auto context_forwarding_callback = [=, &ctx](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens){
-                callback(ctx, new_token_id, logits, is_eos, n_generated_tokens);
+            auto context_forwarding_callback = [=, &ctx](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens) -> bool {
+                return callback(ctx, new_token_id, logits, is_eos, n_generated_tokens);
             };
             // Ask the compiler to create view over Rust slice transmuting from uint32_t* to int32_t*


@@ -13,11 +13,10 @@ use text_generation_router::validation::{
 };
 use text_generation_router::{FinishReason, Token};
 use thiserror::Error;
-use tokio::sync::mpsc::error::SendError;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{debug, error, info};
+use tracing::{error, info};
 type InferResult = Result<InferStreamResponse, InferError>;
@@ -45,7 +44,7 @@ impl From<&ValidStoppingParameters> for GenerationParams {
 }
 #[cfg_attr(debug_assertions, derive(Debug))]
-struct GenerationContext {
+pub(crate) struct GenerationContext {
     pub(crate) input_tokens: Arc<Vec<u32>>,
     pub(crate) generated_tokens: Vec<u32>,
     pub(crate) generation_params: GenerationParams,
@@ -108,7 +107,7 @@ fn llama_generate_callback(
     new_token_logit: f32,
     is_final: bool,
     n_generated_tokens: usize,
-) {
+) -> bool {
     info!("Generated token: {new_token_id} -> logits={new_token_logit}, is_final={is_final} ({n_generated_tokens})");
     // Decode token
@@ -151,10 +150,14 @@ fn llama_generate_callback(
     };
     // Send back to the client
-    if let Err(ref err) = ctx.stream.send(Ok(response)) {
+    if let Err(ref _err) = ctx.stream.send(Ok(response)) {
         error!("Failed to send back the response to the client, cancelling request");
         // TODO: cancel the request
+        return true; // should_stop
     }
+    // should_stop
+    false
 }
 unsafe fn scheduler_loop(


@@ -58,7 +58,7 @@ mod ffi {
             generation_params: GenerationParams,
             sampling_params: &SamplingParams,
             stream: *mut InferContext,
-            callback: unsafe fn(*mut InferContext, u32, f32, bool, usize),
+            callback: unsafe fn(*mut InferContext, u32, f32, bool, usize) -> bool,
         ) -> Result<usize>;
     }
 }