wip
This commit is contained in:
parent
be2d38032a
commit
504754861f
|
@ -1,511 +1,84 @@
|
|||
/// Batching and inference logic
|
||||
use crate::infer::v3::queue::{Entry, Queue};
|
||||
use crate::infer::{
|
||||
GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler,
|
||||
};
|
||||
use crate::validation::ValidGenerateRequest;
|
||||
use crate::{FinishReason, PrefillToken, Token};
|
||||
use nohash_hasher::IntMap;
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
};
|
||||
use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient};
|
||||
use text_generation_client::ClientError;
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};
|
||||
use tokio::time::Instant;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
use crate::infer::InferError;
|
||||
use crate::{ChatTemplateInputs, GrammarType, Message, MessageChunk, Text, TextMessage};
|
||||
use minijinja::{Environment, ErrorKind, Template};
|
||||
use minijinja_contrib::pycompat;
|
||||
|
||||
pub(crate) struct SchedulerV3 {
|
||||
/// Request queue
|
||||
queue: Queue,
|
||||
/// Notify batcher on queue appends
|
||||
batching_task_notifier: Arc<Notify>,
|
||||
/// Raise a exception (custom function) used in the chat templates
|
||||
pub(crate) fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
|
||||
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
|
||||
}
|
||||
|
||||
impl SchedulerV3 {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ChatTemplate {
|
||||
template: Template<'static, 'static>,
|
||||
bos_token: Option<String>,
|
||||
eos_token: Option<String>,
|
||||
use_default_tool_template: bool,
|
||||
}
|
||||
|
||||
impl ChatTemplate {
|
||||
pub(crate) fn new(
|
||||
client: ShardedClient,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
requires_padding: bool,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
template: String,
|
||||
bos_token: Option<String>,
|
||||
eos_token: Option<String>,
|
||||
) -> Self {
|
||||
let queue = Queue::new(
|
||||
requires_padding,
|
||||
16,
|
||||
window_size,
|
||||
speculate,
|
||||
max_batch_total_tokens,
|
||||
);
|
||||
let batching_task_notifier = Arc::new(Notify::new());
|
||||
let mut env = Box::new(Environment::new());
|
||||
// enable things like .strip() or .capitalize()
|
||||
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
||||
let template_str = template.into_boxed_str();
|
||||
env.add_function("raise_exception", raise_exception);
|
||||
|
||||
// Spawn batching background task that contains all the inference logic
|
||||
tokio::spawn(batching_task(
|
||||
client,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
queue.clone(),
|
||||
batching_task_notifier.clone(),
|
||||
generation_health,
|
||||
));
|
||||
// check if contains the tools variable within the template
|
||||
let use_default_tool_template =
|
||||
!template_str.as_ref().replace(' ', "").contains("{{tools}}");
|
||||
// leaking env and template_str as read-only, static resources for performance.
|
||||
let template = Box::leak(env)
|
||||
.template_from_str(Box::leak(template_str))
|
||||
.unwrap();
|
||||
|
||||
Self {
|
||||
queue,
|
||||
batching_task_notifier,
|
||||
template,
|
||||
bos_token,
|
||||
eos_token,
|
||||
use_default_tool_template,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Scheduler for SchedulerV3 {
|
||||
#[instrument(skip_all)]
|
||||
fn schedule(
|
||||
pub(crate) fn apply(
|
||||
&self,
|
||||
request: ValidGenerateRequest,
|
||||
permit: OwnedSemaphorePermit,
|
||||
) -> Result<GenerateStreamResponse, InferError> {
|
||||
// MPSC channel to communicate with the background batching task
|
||||
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||
let input_length = request.input_length;
|
||||
|
||||
// Append the request to the queue
|
||||
self.queue.append(Entry {
|
||||
request,
|
||||
response_tx,
|
||||
span: Span::current(),
|
||||
temp_span: None,
|
||||
queue_time: Instant::now(),
|
||||
batch_time: None,
|
||||
block_allocation: None,
|
||||
});
|
||||
|
||||
// Notify the background task that we have a new entry in the queue that needs
|
||||
// to be batched
|
||||
self.batching_task_notifier.notify_one();
|
||||
|
||||
// Return stream
|
||||
Ok((
|
||||
permit,
|
||||
input_length,
|
||||
UnboundedReceiverStream::new(response_rx),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Batching logic
|
||||
/// Will be launched in a background Tokio task
|
||||
///
|
||||
/// Batches requests and sends them to the inference server
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn batching_task(
|
||||
mut client: ShardedClient,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
queue: Queue,
|
||||
notifier: Arc<Notify>,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
) {
|
||||
// Infinite loop
|
||||
loop {
|
||||
// Wait for a notification from the Infer struct
|
||||
notifier.notified().await;
|
||||
|
||||
// Get the next batch from the queue
|
||||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||
// waiting in the queue
|
||||
while let Some((mut entries, batch, span)) = queue
|
||||
.next_batch(
|
||||
None,
|
||||
max_batch_size,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
)
|
||||
.await
|
||||
{
|
||||
let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
|
||||
.instrument(span)
|
||||
.await;
|
||||
let mut waiting_tokens = 1;
|
||||
|
||||
// We loop until we do not receive any cached batch from the inference server (== until
|
||||
// all requests have met their stopping criteria)
|
||||
while let Some(batch) = cached_batch {
|
||||
// Get current batch info
|
||||
let batch_size = batch.size;
|
||||
let batch_max_tokens = batch.max_tokens;
|
||||
let mut batches = vec![batch];
|
||||
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64);
|
||||
|
||||
let min_size = if waiting_tokens >= max_waiting_tokens {
|
||||
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
||||
// to add a new batch even though its size might be small
|
||||
None
|
||||
} else {
|
||||
// Minimum batch size
|
||||
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
||||
};
|
||||
|
||||
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||
let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
|
||||
|
||||
// Try to get a new batch
|
||||
if let Some((mut new_entries, new_batch, span)) = queue
|
||||
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
|
||||
.await
|
||||
{
|
||||
// Tracking metrics
|
||||
if min_size.is_some() {
|
||||
metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
|
||||
} else {
|
||||
metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
|
||||
}
|
||||
|
||||
entries.iter_mut().for_each(|(_, entry)| {
|
||||
// Create a new span to add the info that this entry is waiting
|
||||
// because a new batch is being computed
|
||||
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
|
||||
// Add relationships
|
||||
span.follows_from(&entry_waiting_span);
|
||||
entry_waiting_span.follows_from(&span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_waiting_span);
|
||||
});
|
||||
|
||||
// Generate one token for this new batch to have the attention past in cache
|
||||
let new_cached_batch =
|
||||
prefill(&mut client, new_batch, &mut new_entries, &generation_health)
|
||||
.instrument(span)
|
||||
.await;
|
||||
// Reset waiting counter
|
||||
waiting_tokens = 1;
|
||||
// Extend current batch with the new batch
|
||||
if let Some(new_cached_batch) = new_cached_batch {
|
||||
entries.extend(new_entries);
|
||||
batches.push(new_cached_batch);
|
||||
}
|
||||
mut messages: Vec<Message>,
|
||||
grammar_with_prompt: Option<(GrammarType, String)>,
|
||||
) -> Result<String, InferError> {
|
||||
if self.use_default_tool_template {
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
|
||||
last_message.content.push(MessageChunk::Text(Text {
|
||||
text: format!("\n---\n{}\n{}", tool_prompt, tools),
|
||||
}));
|
||||
}
|
||||
|
||||
// Create span for this batch to add context to inference calls
|
||||
let next_batch_size = entries.len();
|
||||
let next_batch_span =
|
||||
info_span!(parent: None, "batch", batch_size = next_batch_size);
|
||||
entries.iter_mut().for_each(|(_, entry)| {
|
||||
// Create a new span to link the batch back to this entry
|
||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||
// Add relationships
|
||||
next_batch_span.follows_from(&entry_batch_span);
|
||||
entry_batch_span.follows_from(&next_batch_span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_batch_span);
|
||||
});
|
||||
|
||||
cached_batch = decode(&mut client, batches, &mut entries, &generation_health)
|
||||
.instrument(next_batch_span)
|
||||
.await;
|
||||
waiting_tokens += 1;
|
||||
}
|
||||
metrics::gauge!("tgi_batch_current_size", 0.0);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens", 0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn prefill(
|
||||
client: &mut ShardedClient,
|
||||
batch: Batch,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
generation_health: &Arc<AtomicBool>,
|
||||
) -> Option<CachedBatch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_id = batch.id;
|
||||
metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
|
||||
|
||||
match client.prefill(batch).await {
|
||||
Ok((generations, next_batch, timings)) => {
|
||||
// Update health
|
||||
generation_health.store(true, Ordering::SeqCst);
|
||||
|
||||
let start_filtering_time = Instant::now();
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
|
||||
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
|
||||
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
|
||||
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
|
||||
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
// Update health
|
||||
generation_health.store(false, Ordering::SeqCst);
|
||||
let _ = client.clear_cache(Some(batch_id)).await;
|
||||
send_errors(err, entries);
|
||||
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn decode(
|
||||
client: &mut ShardedClient,
|
||||
batches: Vec<CachedBatch>,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
generation_health: &Arc<AtomicBool>,
|
||||
) -> Option<CachedBatch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
|
||||
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
|
||||
|
||||
match client.decode(batches).await {
|
||||
Ok((generations, next_batch, timings)) => {
|
||||
// Update health
|
||||
generation_health.store(true, Ordering::SeqCst);
|
||||
|
||||
let start_filtering_time = Instant::now();
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
if let Some(concat_duration) = timings.concat {
|
||||
metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
|
||||
}
|
||||
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
|
||||
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
|
||||
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
|
||||
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
|
||||
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
generation_health.store(false, Ordering::SeqCst);
|
||||
for id in batch_ids {
|
||||
let _ = client.clear_cache(Some(id)).await;
|
||||
}
|
||||
send_errors(err, entries);
|
||||
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Filter a `batch` and remove all requests not present in `entries`
|
||||
#[instrument(skip_all)]
|
||||
async fn filter_batch(
|
||||
client: &mut ShardedClient,
|
||||
next_batch: Option<CachedBatch>,
|
||||
entries: &IntMap<u64, Entry>,
|
||||
) -> Option<CachedBatch> {
|
||||
let mut batch = next_batch?;
|
||||
|
||||
// No need to filter
|
||||
if batch.size as usize == entries.len() {
|
||||
return Some(batch);
|
||||
}
|
||||
|
||||
let id = batch.id;
|
||||
|
||||
// Retain only requests that are still in entries
|
||||
batch.request_ids.retain(|id| entries.contains_key(id));
|
||||
|
||||
if batch.request_ids.is_empty() {
|
||||
// All requests have been filtered out
|
||||
// Next batch is now empty
|
||||
// Clear it from the Python shards cache
|
||||
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||
client.clear_cache(Some(id)).await.unwrap();
|
||||
None
|
||||
} else {
|
||||
// Filter Python shard cache
|
||||
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||
client.filter_batch(id, batch.request_ids).await.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
|
||||
/// and filter entries
|
||||
#[instrument(skip_all)]
|
||||
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
|
||||
generations.into_iter().for_each(|generation| {
|
||||
let id = generation.request_id;
|
||||
// Get entry
|
||||
// We can `expect` here as the request id should always be in the entries
|
||||
let entry = entries
|
||||
.get(&id)
|
||||
.expect("ID not found in entries. This is a bug.");
|
||||
|
||||
// Create and enter a span to link this function back to the entry
|
||||
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
|
||||
// Send generation responses back to the infer task
|
||||
// If the receive an error from the Flume channel, it means that the client dropped the
|
||||
// request and we need to stop generating hence why we unwrap_or(true)
|
||||
let stopped = send_responses(generation, entry).map_err(|err| {
|
||||
tracing::error!("Entry response channel error.");
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
||||
err
|
||||
}).unwrap_or(true);
|
||||
if stopped {
|
||||
entries.remove(&id).expect("ID not found in entries. This is a bug.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Send responses through the `entry` response channel
|
||||
fn send_responses(
|
||||
generation: Generation,
|
||||
entry: &Entry,
|
||||
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
|
||||
// Return directly if the channel is disconnected
|
||||
if entry.response_tx.is_closed() {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
let mut stopped = false;
|
||||
|
||||
if let Some(prefill_tokens) = generation.prefill_tokens {
|
||||
// Create Token objects
|
||||
// We do that here instead of in the Python code as Rust for loops are faster
|
||||
let prefill_tokens = prefill_tokens
|
||||
.ids
|
||||
.into_iter()
|
||||
.zip(prefill_tokens.logprobs)
|
||||
.zip(prefill_tokens.texts)
|
||||
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
|
||||
.collect();
|
||||
|
||||
// Send message
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
|
||||
}
|
||||
|
||||
// Create last Token
|
||||
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
|
||||
let n = tokens_.ids.len();
|
||||
metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
|
||||
let mut iterator = tokens_
|
||||
.ids
|
||||
.into_iter()
|
||||
.zip(tokens_.logprobs)
|
||||
.zip(tokens_.texts)
|
||||
.zip(tokens_.is_special)
|
||||
.enumerate()
|
||||
.peekable();
|
||||
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
|
||||
let token = Token {
|
||||
id,
|
||||
text,
|
||||
logprob,
|
||||
special,
|
||||
};
|
||||
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
|
||||
top_tokens_
|
||||
.ids
|
||||
.iter()
|
||||
.zip(top_tokens_.logprobs.iter())
|
||||
.zip(top_tokens_.texts.iter())
|
||||
.zip(top_tokens_.is_special.iter())
|
||||
.map(|(((&id, &logprob), text), &special)| Token {
|
||||
id,
|
||||
text: text.to_string(),
|
||||
logprob,
|
||||
special,
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
match (&generation.generated_text, iterator.peek()) {
|
||||
(Some(generated_text), None) => {
|
||||
// Generation has ended
|
||||
stopped = true;
|
||||
// Send message
|
||||
entry.response_tx.send(Ok(InferStreamResponse::End {
|
||||
token,
|
||||
top_tokens,
|
||||
generated_text: GeneratedText::from(generated_text.clone()),
|
||||
queued: entry.queue_time,
|
||||
start: entry.batch_time.unwrap(),
|
||||
}))?;
|
||||
}
|
||||
_ => {
|
||||
// Send message
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(stopped)
|
||||
}
|
||||
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
||||
|
||||
/// Send errors to Infer for all `entries`
|
||||
#[instrument(skip_all)]
|
||||
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
||||
entries.drain().for_each(|(_, entry)| {
|
||||
// Create and enter a span to link this function back to the entry
|
||||
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
||||
let err = InferError::GenerationError(error.to_string());
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "generation");
|
||||
tracing::error!("{err}");
|
||||
|
||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||
entry
|
||||
.response_tx
|
||||
.send(Err(err))
|
||||
.unwrap_or(());
|
||||
});
|
||||
}
|
||||
|
||||
impl From<text_generation_client::v3::GeneratedText> for GeneratedText {
|
||||
fn from(value: text_generation_client::v3::GeneratedText) -> Self {
|
||||
let v3_finish_reason =
|
||||
text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap();
|
||||
let finish_reason = match v3_finish_reason {
|
||||
text_generation_client::v3::FinishReason::Length => FinishReason::Length,
|
||||
text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
||||
text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence,
|
||||
};
|
||||
|
||||
Self {
|
||||
text: value.text,
|
||||
generated_tokens: value.generated_tokens,
|
||||
finish_reason,
|
||||
seed: value.seed,
|
||||
}
|
||||
self.template
|
||||
.render(ChatTemplateInputs {
|
||||
messages,
|
||||
bos_token: self.bos_token.as_deref(),
|
||||
eos_token: self.eos_token.as_deref(),
|
||||
add_generation_prompt: true,
|
||||
tools: None,
|
||||
tools_prompt: None,
|
||||
})
|
||||
.map_err(InferError::TemplateError)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// tests
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::infer::raise_exception;
|
||||
use crate::infer::chat_template::raise_exception;
|
||||
use crate::{ChatTemplateInputs, TextMessage};
|
||||
use minijinja::Environment;
|
||||
|
|
@ -1,34 +0,0 @@
|
|||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::Health;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct HealthCheck {
|
||||
client: Arc<dyn Health + Send + Sync>,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl HealthCheck {
|
||||
pub(crate) fn new(
|
||||
client: Arc<dyn Health + Send + Sync>,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
) -> Self {
|
||||
Self {
|
||||
client,
|
||||
generation_health,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn check(&mut self) -> bool {
|
||||
let value = if self.generation_health.load(Ordering::SeqCst) {
|
||||
// Generation is healthy, we only check that the shards can allocate on device
|
||||
self.client.device_health().await
|
||||
} else {
|
||||
self.client.model_health().await
|
||||
}
|
||||
.is_ok();
|
||||
// Update generation health
|
||||
self.generation_health.store(value, Ordering::SeqCst);
|
||||
value
|
||||
}
|
||||
}
|
|
@ -1,6 +1,8 @@
|
|||
mod health;
|
||||
pub(crate) mod v2;
|
||||
pub(crate) mod v3;
|
||||
mod chat_template;
|
||||
mod tool_grammar;
|
||||
|
||||
pub(crate) use health::HealthCheck;
|
||||
|
||||
|
@ -23,6 +25,7 @@ use tokio::time::Instant;
|
|||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::instrument;
|
||||
use chat_template::ChatTemplate;
|
||||
|
||||
pub(crate) trait Scheduler {
|
||||
fn schedule(
|
||||
|
@ -37,18 +40,20 @@ pub(crate) trait Scheduler {
|
|||
pub struct Infer {
|
||||
/// Validation
|
||||
validation: Validation,
|
||||
/// Request scheduler
|
||||
scheduler: Arc<dyn Scheduler + Send + Sync>,
|
||||
/// Request backend
|
||||
backend: Arc<dyn Backend + Send + Sync>,
|
||||
/// Chat template
|
||||
chat_template: Option<ChatTemplate>,
|
||||
/// Inference limit
|
||||
limit_concurrent_requests: Arc<Semaphore>,
|
||||
/// Backend health
|
||||
backend_health: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl Infer {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn new(
|
||||
scheduler: Arc<dyn Scheduler + Send + Sync>,
|
||||
backend: impl Backend + Send + Sync + 'static,
|
||||
validation: Validation,
|
||||
max_concurrent_requests: usize,
|
||||
tokenizer_config: HubTokenizerConfig,
|
||||
|
@ -69,20 +74,31 @@ impl Infer {
|
|||
// Inference limit with a semaphore
|
||||
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
||||
|
||||
// Backend health
|
||||
let backend_health = Arc::new(AtomicBool::new(false));
|
||||
|
||||
Self {
|
||||
validation,
|
||||
scheduler,
|
||||
backend: Arc::new(backend),
|
||||
chat_template,
|
||||
limit_concurrent_requests: semaphore,
|
||||
backend_health,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a new request to the queue and return a stream of InferStreamResponse
|
||||
#[instrument(skip_all)]
|
||||
pub(crate) async fn generate_stream(
|
||||
&self,
|
||||
pub(crate) async fn generate_stream<'a>(
|
||||
&'a self,
|
||||
request: GenerateRequest,
|
||||
) -> Result<GenerateStreamResponse, InferError> {
|
||||
) -> Result<
|
||||
(
|
||||
OwnedSemaphorePermit,
|
||||
u32, // input_length
|
||||
impl Stream<Item=Result<InferStreamResponse, InferError>> + 'a,
|
||||
),
|
||||
InferError,
|
||||
> {
|
||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
||||
let permit = self
|
||||
.clone()
|
||||
|
@ -101,7 +117,36 @@ impl Infer {
|
|||
err
|
||||
})?;
|
||||
|
||||
self.scheduler.schedule(valid_request, permit)
|
||||
let input_length = valid_request.input_length;
|
||||
let mut generation_stream = self
|
||||
.backend
|
||||
.schedule(valid_request)
|
||||
.map_err(InferError::Backend)?;
|
||||
|
||||
let stream = stream! {
|
||||
while let Some(generation) = generation_stream.next().await {
|
||||
self.backend_health.store(generation.is_ok(), Ordering::SeqCst);
|
||||
|
||||
yield generation.map(|generation| match generation {
|
||||
types::TokenStreamResponse::Prefill(prefill_tokens) => InferStreamResponse::Prefill(
|
||||
prefill_tokens.into_iter().map(PrefillToken::from).collect()
|
||||
),
|
||||
types::TokenStreamResponse::Intermediate { token, top_tokens } => InferStreamResponse::Intermediate {
|
||||
token: Token::from(token),
|
||||
top_tokens: top_tokens.into_iter().map(Token::from).collect(),
|
||||
},
|
||||
types::TokenStreamResponse::End { token, top_tokens, generated_text, start, queued } => InferStreamResponse::End {
|
||||
token: Token::from(token),
|
||||
top_tokens: top_tokens.into_iter().map(Token::from).collect(),
|
||||
generated_text,
|
||||
start,
|
||||
queued,
|
||||
}
|
||||
}).map_err(InferError::GenerationError)
|
||||
}
|
||||
};
|
||||
|
||||
Ok((permit, input_length, stream))
|
||||
}
|
||||
|
||||
/// Tokenizer the input
|
||||
|
@ -153,7 +198,7 @@ impl Infer {
|
|||
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
|
||||
|
||||
// Create stream and keep semaphore permit as long as generate lives
|
||||
let (_permit, _input_length, mut stream) = self.generate_stream(request).await?;
|
||||
let (_permit, _input_length, stream) = self.generate_stream(request).await?;
|
||||
|
||||
// Return values
|
||||
let mut result_prefill = Vec::new();
|
||||
|
@ -163,6 +208,8 @@ impl Infer {
|
|||
let mut result_start = None;
|
||||
let mut result_queued = None;
|
||||
|
||||
let mut stream = Box::pin(stream);
|
||||
|
||||
// Iterate on stream
|
||||
while let Some(response) = stream.next().await {
|
||||
match response? {
|
||||
|
@ -254,190 +301,15 @@ impl Infer {
|
|||
let best_response = infer_responses.remove(max_index);
|
||||
Ok((best_response, infer_responses))
|
||||
}
|
||||
}
|
||||
|
||||
/// Raise a exception (custom function) used in the chat templates
|
||||
fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
|
||||
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ChatTemplate {
|
||||
template: Template<'static, 'static>,
|
||||
bos_token: Option<String>,
|
||||
eos_token: Option<String>,
|
||||
use_default_tool_template: bool,
|
||||
}
|
||||
|
||||
impl ChatTemplate {
|
||||
fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
|
||||
let mut env = Box::new(Environment::new());
|
||||
// enable things like .strip() or .capitalize()
|
||||
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
||||
let template_str = template.into_boxed_str();
|
||||
env.add_function("raise_exception", raise_exception);
|
||||
|
||||
// check if contains the tools variable within the template
|
||||
let use_default_tool_template =
|
||||
!template_str.as_ref().replace(' ', "").contains("{{tools}}");
|
||||
// leaking env and template_str as read-only, static resources for performance.
|
||||
let template = Box::leak(env)
|
||||
.template_from_str(Box::leak(template_str))
|
||||
.unwrap();
|
||||
|
||||
Self {
|
||||
template,
|
||||
bos_token,
|
||||
eos_token,
|
||||
use_default_tool_template,
|
||||
}
|
||||
}
|
||||
|
||||
fn apply(
|
||||
&self,
|
||||
mut messages: Vec<Message>,
|
||||
grammar_with_prompt: Option<(GrammarType, String)>,
|
||||
) -> Result<String, InferError> {
|
||||
if self.use_default_tool_template {
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
|
||||
last_message.content.push(MessageChunk::Text(Text {
|
||||
text: format!("\n---\n{}\n{}", tool_prompt, tools),
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
||||
|
||||
self.template
|
||||
.render(ChatTemplateInputs {
|
||||
messages,
|
||||
bos_token: self.bos_token.as_deref(),
|
||||
eos_token: self.eos_token.as_deref(),
|
||||
add_generation_prompt: true,
|
||||
tools: None,
|
||||
tools_prompt: None,
|
||||
})
|
||||
.map_err(InferError::TemplateError)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ToolGrammar {}
|
||||
|
||||
impl ToolGrammar {
|
||||
pub fn apply(
|
||||
tools: Option<Vec<Tool>>,
|
||||
tool_choice: Option<ToolType>,
|
||||
) -> Result<Option<Tools>, InferError> {
|
||||
if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) {
|
||||
// let tool_prompt = tool_prompt.unwrap_or_default();
|
||||
let tools_to_use = match tool_choice {
|
||||
ToolType::FunctionName(name) => {
|
||||
vec![req_tools
|
||||
.iter()
|
||||
.find(|tool| tool.function.name == *name)
|
||||
.unwrap_or_else(|| panic!("Tool with name {} not found", name))
|
||||
.clone()]
|
||||
}
|
||||
ToolType::OneOf => req_tools.to_owned(),
|
||||
};
|
||||
|
||||
// adds the error notification function for LLM feedback if required
|
||||
let mut text_response_properties = Map::new();
|
||||
text_response_properties.insert(
|
||||
"error".to_string(),
|
||||
serde_json::json!({
|
||||
"type": "string",
|
||||
"description": "The error or issue to notify"
|
||||
}),
|
||||
);
|
||||
text_response_properties.insert(
|
||||
"_name".to_string(),
|
||||
serde_json::json!({
|
||||
"type": "string",
|
||||
"const": "notify_error"
|
||||
}),
|
||||
);
|
||||
|
||||
let functions: HashMap<String, serde_json::Value> = tools_to_use
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
let func = tool.function.clone();
|
||||
|
||||
// Clone the existing parameters, which are expected to be a JSON object
|
||||
let mut params = if let Value::Object(params) = &func.arguments {
|
||||
params.clone()
|
||||
} else {
|
||||
Map::new()
|
||||
};
|
||||
|
||||
// Insert the function's description at the top level, outside of properties
|
||||
params.insert(
|
||||
"description".to_string(),
|
||||
Value::String(func.description.clone().unwrap_or_default()),
|
||||
);
|
||||
|
||||
// Ensure 'properties' exists and is an object
|
||||
let properties = params
|
||||
.entry("properties".to_string())
|
||||
.or_insert_with(|| json!({}))
|
||||
.as_object_mut()
|
||||
.unwrap();
|
||||
|
||||
// Insert the constant for the function name inside 'properties'
|
||||
properties.insert(
|
||||
"_name".to_string(),
|
||||
json!({
|
||||
"type": "string",
|
||||
"const": func.name.clone(),
|
||||
// "description": "The name of the function"
|
||||
}),
|
||||
);
|
||||
|
||||
// Check if 'required' exists, and it is an array. If not, create an empty array.
|
||||
let required = params
|
||||
.entry("required".to_string())
|
||||
.or_insert_with(|| json!([]))
|
||||
.as_array_mut()
|
||||
.unwrap();
|
||||
|
||||
// Add 'name' to the 'required' array if it is not already present
|
||||
if !required.iter().any(|r| r == "_name") {
|
||||
required.push(json!("_name"));
|
||||
}
|
||||
|
||||
(func.name, Value::Object(params))
|
||||
})
|
||||
.chain([(
|
||||
"notify_error".to_string(),
|
||||
serde_json::json!({
|
||||
"properties": text_response_properties,
|
||||
"required": ["error", "_name"],
|
||||
"type": "object"
|
||||
}),
|
||||
)])
|
||||
.collect();
|
||||
|
||||
let tools = Tools {
|
||||
functions_map: FunctionsMap { functions },
|
||||
properties: Properties {
|
||||
function: tools_to_use
|
||||
.iter()
|
||||
.map(|tool| FunctionRef {
|
||||
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
||||
})
|
||||
.chain(std::iter::once(FunctionRef {
|
||||
ref_path: "#/$functions/notify_error".to_string(),
|
||||
}))
|
||||
.collect(),
|
||||
},
|
||||
};
|
||||
|
||||
return Ok(Some(tools));
|
||||
}
|
||||
// Err(InferError::ToolError("No tools provided".to_string()))
|
||||
Ok(None)
|
||||
#[instrument(skip(self))]
|
||||
pub(crate) async fn health(&self) -> bool {
|
||||
let health = self
|
||||
.backend
|
||||
.health(self.backend_health.load(Ordering::SeqCst))
|
||||
.await;
|
||||
self.backend_health.store(health, Ordering::SeqCst);
|
||||
health
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -491,8 +363,10 @@ pub(crate) struct InferResponse {
|
|||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum InferError {
|
||||
#[error("Request failed during scheduling: {0}")]
|
||||
Backend(BackendError),
|
||||
#[error("Request failed during generation: {0}")]
|
||||
GenerationError(String),
|
||||
GenerationError(BackendError),
|
||||
#[error("Model is overloaded")]
|
||||
Overloaded(#[from] TryAcquireError),
|
||||
#[error("Input validation error: {0}")]
|
||||
|
@ -508,6 +382,7 @@ pub enum InferError {
|
|||
impl InferError {
|
||||
pub(crate) fn error_type(&self) -> &str {
|
||||
match self {
|
||||
InferError::Backend(_) => "backend",
|
||||
InferError::GenerationError(_) => "generation",
|
||||
InferError::Overloaded(_) => "overloaded",
|
||||
InferError::ValidationError(_) => "validation",
|
||||
|
|
|
@ -0,0 +1,122 @@
|
|||
use crate::infer::InferError;
|
||||
use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolType, Tools};
|
||||
use serde_json::{json, Map, Value};
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub(crate) struct ToolGrammar {}
|
||||
|
||||
impl ToolGrammar {
|
||||
pub fn apply(
|
||||
tools: Option<Vec<Tool>>,
|
||||
tool_choice: Option<ToolType>,
|
||||
) -> Result<Option<Tools>, InferError> {
|
||||
if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) {
|
||||
// let tool_prompt = tool_prompt.unwrap_or_default();
|
||||
let tools_to_use = match tool_choice {
|
||||
ToolType::FunctionName(name) => {
|
||||
vec![req_tools
|
||||
.iter()
|
||||
.find(|tool| tool.function.name == *name)
|
||||
.unwrap_or_else(|| panic!("Tool with name {} not found", name))
|
||||
.clone()]
|
||||
}
|
||||
ToolType::OneOf => req_tools.to_owned(),
|
||||
};
|
||||
|
||||
// adds the error notification function for LLM feedback if required
|
||||
let mut text_response_properties = Map::new();
|
||||
text_response_properties.insert(
|
||||
"error".to_string(),
|
||||
json!({
|
||||
"type": "string",
|
||||
"description": "The error or issue to notify"
|
||||
}),
|
||||
);
|
||||
text_response_properties.insert(
|
||||
"_name".to_string(),
|
||||
json!({
|
||||
"type": "string",
|
||||
"const": "notify_error"
|
||||
}),
|
||||
);
|
||||
|
||||
let functions: HashMap<String, Value> = tools_to_use
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
let func = tool.function.clone();
|
||||
|
||||
// Clone the existing parameters, which are expected to be a JSON object
|
||||
let mut params = if let Value::Object(params) = &func.arguments {
|
||||
params.clone()
|
||||
} else {
|
||||
Map::new()
|
||||
};
|
||||
|
||||
// Insert the function's description at the top level, outside of properties
|
||||
params.insert(
|
||||
"description".to_string(),
|
||||
Value::String(func.description.clone().unwrap_or_default()),
|
||||
);
|
||||
|
||||
// Ensure 'properties' exists and is an object
|
||||
let properties = params
|
||||
.entry("properties".to_string())
|
||||
.or_insert_with(|| json!({}))
|
||||
.as_object_mut()
|
||||
.unwrap();
|
||||
|
||||
// Insert the constant for the function name inside 'properties'
|
||||
properties.insert(
|
||||
"_name".to_string(),
|
||||
json!({
|
||||
"type": "string",
|
||||
"const": func.name.clone(),
|
||||
// "description": "The name of the function"
|
||||
}),
|
||||
);
|
||||
|
||||
// Check if 'required' exists, and it is an array. If not, create an empty array.
|
||||
let required = params
|
||||
.entry("required".to_string())
|
||||
.or_insert_with(|| json!([]))
|
||||
.as_array_mut()
|
||||
.unwrap();
|
||||
|
||||
// Add 'name' to the 'required' array if it is not already present
|
||||
if !required.iter().any(|r| r == "_name") {
|
||||
required.push(json!("_name"));
|
||||
}
|
||||
|
||||
(func.name, Value::Object(params))
|
||||
})
|
||||
.chain([(
|
||||
"notify_error".to_string(),
|
||||
json!({
|
||||
"properties": text_response_properties,
|
||||
"required": ["error", "_name"],
|
||||
"type": "object"
|
||||
}),
|
||||
)])
|
||||
.collect();
|
||||
|
||||
let tools = Tools {
|
||||
functions_map: FunctionsMap { functions },
|
||||
properties: Properties {
|
||||
function: tools_to_use
|
||||
.iter()
|
||||
.map(|tool| FunctionRef {
|
||||
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
||||
})
|
||||
.chain(std::iter::once(FunctionRef {
|
||||
ref_path: "#/$functions/notify_error".to_string(),
|
||||
}))
|
||||
.collect(),
|
||||
},
|
||||
};
|
||||
|
||||
return Ok(Some(tools));
|
||||
}
|
||||
// Err(InferError::ToolError("No tools provided".to_string()))
|
||||
Ok(None)
|
||||
}
|
||||
}
|
|
@ -121,12 +121,10 @@ responses(
|
|||
example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
|
||||
)
|
||||
)]
|
||||
#[instrument(skip(health))]
|
||||
#[instrument(skip(infer))]
|
||||
/// Health check method
|
||||
async fn health(
|
||||
mut health: Extension<HealthCheck>,
|
||||
) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
||||
match health.check().await {
|
||||
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
||||
match infer.health().await {
|
||||
true => Ok(()),
|
||||
false => Err((
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
|
@ -437,8 +435,9 @@ async fn generate_stream_internal(
|
|||
} else {
|
||||
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
|
||||
// Keep permit as long as generate_stream lives
|
||||
Ok((_permit, _input_length, mut response_stream)) => {
|
||||
Ok((_permit, _input_length, response_stream)) => {
|
||||
let mut index = 0;
|
||||
let mut response_stream = Box::pin(response_stream);
|
||||
// Server-Sent Event stream
|
||||
while let Some(response) = response_stream.next().await {
|
||||
index += 1;
|
||||
|
@ -1960,16 +1959,8 @@ impl From<InferError> for Event {
|
|||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum WebServerError {
|
||||
#[error("Unable to connect to the Python model shards: {0}")]
|
||||
Connection(ClientError),
|
||||
#[error("Unable to clear the Python model shards cache: {0}")]
|
||||
Cache(ClientError),
|
||||
#[error("Unable to get the Python model shards info: {0}")]
|
||||
Info(ClientError),
|
||||
#[error("Unable to warmup the Python model shards: {0}")]
|
||||
Warmup(ClientError),
|
||||
#[error("Not enough memory to handle `max_total_tokens={0}`")]
|
||||
NotEnoughMemory(usize),
|
||||
#[error("Backend error: {0}")]
|
||||
Backend(#[from] BackendError),
|
||||
#[error("Axum error: {0}")]
|
||||
Axum(#[from] axum::BoxError),
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue