feat(backend): add early stopping criteria from TGI stream callback
This commit is contained in:
parent
958c72a44a
commit
1473259f84
|
@ -121,11 +121,12 @@ namespace huggingface::tgi::backends::llamacpp {
|
|||
generating = !(has_reach_max_tokens | has_reach_eog);
|
||||
|
||||
// Bubble up the generated token if a callback is provided
|
||||
std::invoke(std::forward<const llama_decode_callback>(callback_),
|
||||
new_token_id,
|
||||
new_token_logits,
|
||||
!generating,
|
||||
n_decoded_tokens + 1);
|
||||
const auto should_stop = std::invoke(std::forward<const llama_decode_callback>(callback_),
|
||||
new_token_id,
|
||||
new_token_logits,
|
||||
!generating,
|
||||
n_decoded_tokens + 1);
|
||||
generating ^= should_stop;
|
||||
|
||||
batch = llama_batch_get_one(&new_token_id, 1);
|
||||
}
|
||||
|
@ -148,11 +149,12 @@ namespace huggingface::tgi::backends::llamacpp {
|
|||
// TODO: Should we provide a way to change this value?
|
||||
auto generated = std::vector<llama_token>(2 << 8);
|
||||
auto inner_callback = [&](uint32_t new_token_id, float_t new_token_logit, bool is_eos,
|
||||
size_t num_generated_tokens) {
|
||||
size_t num_generated_tokens) -> bool {
|
||||
generated.emplace_back(new_token_id);
|
||||
|
||||
if (callback.has_value())
|
||||
(*callback)(new_token_id, new_token_logit, is_eos, num_generated_tokens);
|
||||
return (*callback)(new_token_id, new_token_logit, is_eos, num_generated_tokens);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto nTokensGenerated = stream(tokens, generation_params, sampling_params, inner_callback);
|
||||
|
|
|
@ -29,8 +29,8 @@ namespace huggingface::tgi::backends::llamacpp {
|
|||
static constexpr auto llama_sampler_deleter = [](llama_sampler *pSampler) { llama_sampler_free(pSampler); };
|
||||
typedef std::unique_ptr<llama_sampler, decltype(llama_sampler_deleter)> llama_sampler_ptr;
|
||||
|
||||
typedef std::function<void(llama_token, float_t, bool, size_t)> llama_decode_callback;
|
||||
static constexpr auto llama_void_callback = [](llama_token, float_t, bool, size_t) {};
|
||||
typedef std::function<bool(llama_token, float_t, bool, size_t)> llama_decode_callback;
|
||||
static constexpr auto llama_void_callback = [](llama_token, float_t, bool, size_t) -> bool { return false; };
|
||||
|
||||
/**
|
||||
*
|
||||
|
|
|
@ -64,14 +64,14 @@ namespace huggingface::tgi::backends::llamacpp {
|
|||
const generation_params_t generation_params,
|
||||
const sampling_params_t &sampling_params,
|
||||
InferContext *ctx,
|
||||
rust::Fn<void(InferContext *, uint32_t, float_t, bool, size_t)> callback
|
||||
rust::Fn<bool(InferContext *, uint32_t, float_t, bool, size_t)> callback
|
||||
) {
|
||||
// Define the visitor lambda function which requires the has_emplace_generate constraint on T
|
||||
auto inner_fw = [=, &sampling_params, &ctx, &callback]<has_stream_method T>(T &&backend)
|
||||
-> std::expected<size_t, backend_error_t> {
|
||||
|
||||
auto context_forwarding_callback = [=, &ctx](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens){
|
||||
callback(ctx, new_token_id, logits, is_eos, n_generated_tokens);
|
||||
auto context_forwarding_callback = [=, &ctx](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens) -> bool {
|
||||
return callback(ctx, new_token_id, logits, is_eos, n_generated_tokens);
|
||||
};
|
||||
|
||||
// Ask the compiler to create view over Rust slice transmuting from uint32_t* to int32_t*
|
||||
|
|
|
@ -13,11 +13,10 @@ use text_generation_router::validation::{
|
|||
};
|
||||
use text_generation_router::{FinishReason, Token};
|
||||
use thiserror::Error;
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
|
||||
use tokio::time::Instant;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{debug, error, info};
|
||||
use tracing::{error, info};
|
||||
|
||||
type InferResult = Result<InferStreamResponse, InferError>;
|
||||
|
||||
|
@ -45,7 +44,7 @@ impl From<&ValidStoppingParameters> for GenerationParams {
|
|||
}
|
||||
|
||||
#[cfg_attr(debug_assertions, derive(Debug))]
|
||||
struct GenerationContext {
|
||||
pub(crate) struct GenerationContext {
|
||||
pub(crate) input_tokens: Arc<Vec<u32>>,
|
||||
pub(crate) generated_tokens: Vec<u32>,
|
||||
pub(crate) generation_params: GenerationParams,
|
||||
|
@ -108,7 +107,7 @@ fn llama_generate_callback(
|
|||
new_token_logit: f32,
|
||||
is_final: bool,
|
||||
n_generated_tokens: usize,
|
||||
) {
|
||||
) -> bool {
|
||||
info!("Generated token: {new_token_id} -> logits={new_token_logit}, is_final={is_final} ({n_generated_tokens})");
|
||||
|
||||
// Decode token
|
||||
|
@ -151,10 +150,14 @@ fn llama_generate_callback(
|
|||
};
|
||||
|
||||
// Send back to the client
|
||||
if let Err(ref err) = ctx.stream.send(Ok(response)) {
|
||||
if let Err(ref _err) = ctx.stream.send(Ok(response)) {
|
||||
error!("Failed to send back the response to the client, cancelling request");
|
||||
// TODO: cancel the request
|
||||
return true; // should_stop
|
||||
}
|
||||
|
||||
// should_stop
|
||||
false
|
||||
}
|
||||
|
||||
unsafe fn scheduler_loop(
|
||||
|
|
|
@ -58,7 +58,7 @@ mod ffi {
|
|||
generation_params: GenerationParams,
|
||||
sampling_params: &SamplingParams,
|
||||
stream: *mut InferContext,
|
||||
callback: unsafe fn(*mut InferContext, u32, f32, bool, usize),
|
||||
callback: unsafe fn(*mut InferContext, u32, f32, bool, usize) -> bool,
|
||||
) -> Result<usize>;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue