From 504754861fc21796362d57d7c54594aa8578b95b Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 26 Jun 2024 12:08:56 +0200 Subject: [PATCH] wip --- .../{v3/scheduler.rs => chat_template.rs} | 543 ++---------------- router/src/infer/health.rs | 34 -- router/src/infer/mod.rs | 261 +++------ router/src/infer/tool_grammar.rs | 122 ++++ router/src/server.rs | 23 +- 5 files changed, 255 insertions(+), 728 deletions(-) rename router/src/infer/{v3/scheduler.rs => chat_template.rs} (68%) delete mode 100644 router/src/infer/health.rs create mode 100644 router/src/infer/tool_grammar.rs diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/chat_template.rs similarity index 68% rename from router/src/infer/v3/scheduler.rs rename to router/src/infer/chat_template.rs index ad03dd83..c1b65158 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/chat_template.rs @@ -1,511 +1,84 @@ -/// Batching and inference logic -use crate::infer::v3::queue::{Entry, Queue}; -use crate::infer::{ - GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler, -}; -use crate::validation::ValidGenerateRequest; -use crate::{FinishReason, PrefillToken, Token}; -use nohash_hasher::IntMap; -use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, -}; -use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient}; -use text_generation_client::ClientError; -use tokio::sync::mpsc::error::SendError; -use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit}; -use tokio::time::Instant; -use tokio_stream::wrappers::UnboundedReceiverStream; -use tracing::{info_span, instrument, Instrument, Span}; +use crate::infer::InferError; +use crate::{ChatTemplateInputs, GrammarType, Message, MessageChunk, Text, TextMessage}; +use minijinja::{Environment, ErrorKind, Template}; +use minijinja_contrib::pycompat; -pub(crate) struct SchedulerV3 { - /// Request queue - queue: Queue, - /// Notify batcher on queue appends - batching_task_notifier: Arc, +/// Raise a exception (custom function) used in the chat templates +pub(crate) fn raise_exception(err_text: String) -> Result { + Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text)) } -impl SchedulerV3 { - #[allow(clippy::too_many_arguments)] +#[derive(Clone)] +pub(crate) struct ChatTemplate { + template: Template<'static, 'static>, + bos_token: Option, + eos_token: Option, + use_default_tool_template: bool, +} + +impl ChatTemplate { pub(crate) fn new( - client: ShardedClient, - waiting_served_ratio: f32, - max_batch_prefill_tokens: u32, - max_batch_total_tokens: u32, - max_waiting_tokens: usize, - max_batch_size: Option, - requires_padding: bool, - window_size: Option, - speculate: u32, - generation_health: Arc, + template: String, + bos_token: Option, + eos_token: Option, ) -> Self { - let queue = Queue::new( - requires_padding, - 16, - window_size, - speculate, - max_batch_total_tokens, - ); - let batching_task_notifier = Arc::new(Notify::new()); + let mut env = Box::new(Environment::new()); + // enable things like .strip() or .capitalize() + env.set_unknown_method_callback(pycompat::unknown_method_callback); + let template_str = template.into_boxed_str(); + env.add_function("raise_exception", raise_exception); - // Spawn batching background task that contains all the inference logic - tokio::spawn(batching_task( - client, - waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, - queue.clone(), - batching_task_notifier.clone(), - generation_health, - )); + // check if contains the tools variable within the template + let use_default_tool_template = + !template_str.as_ref().replace(' ', "").contains("{{tools}}"); + // leaking env and template_str as read-only, static resources for performance. + let template = Box::leak(env) + .template_from_str(Box::leak(template_str)) + .unwrap(); Self { - queue, - batching_task_notifier, + template, + bos_token, + eos_token, + use_default_tool_template, } } -} -impl Scheduler for SchedulerV3 { - #[instrument(skip_all)] - fn schedule( + pub(crate) fn apply( &self, - request: ValidGenerateRequest, - permit: OwnedSemaphorePermit, - ) -> Result { - // MPSC channel to communicate with the background batching task - let (response_tx, response_rx) = mpsc::unbounded_channel(); - let input_length = request.input_length; - - // Append the request to the queue - self.queue.append(Entry { - request, - response_tx, - span: Span::current(), - temp_span: None, - queue_time: Instant::now(), - batch_time: None, - block_allocation: None, - }); - - // Notify the background task that we have a new entry in the queue that needs - // to be batched - self.batching_task_notifier.notify_one(); - - // Return stream - Ok(( - permit, - input_length, - UnboundedReceiverStream::new(response_rx), - )) - } -} - -/// Batching logic -/// Will be launched in a background Tokio task -/// -/// Batches requests and sends them to the inference server -#[allow(clippy::too_many_arguments)] -pub(crate) async fn batching_task( - mut client: ShardedClient, - waiting_served_ratio: f32, - max_batch_prefill_tokens: u32, - max_batch_total_tokens: u32, - max_waiting_tokens: usize, - max_batch_size: Option, - queue: Queue, - notifier: Arc, - generation_health: Arc, -) { - // Infinite loop - loop { - // Wait for a notification from the Infer struct - notifier.notified().await; - - // Get the next batch from the queue - // This batch might be smaller than the maximum batch size if there are not enough requests - // waiting in the queue - while let Some((mut entries, batch, span)) = queue - .next_batch( - None, - max_batch_size, - max_batch_prefill_tokens, - max_batch_total_tokens, - ) - .await - { - let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health) - .instrument(span) - .await; - let mut waiting_tokens = 1; - - // We loop until we do not receive any cached batch from the inference server (== until - // all requests have met their stopping criteria) - while let Some(batch) = cached_batch { - // Get current batch info - let batch_size = batch.size; - let batch_max_tokens = batch.max_tokens; - let mut batches = vec![batch]; - metrics::gauge!("tgi_batch_current_size", batch_size as f64); - metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64); - - let min_size = if waiting_tokens >= max_waiting_tokens { - // If we didn't onboard any new requests since >= max_waiting_tokens, we try - // to add a new batch even though its size might be small - None - } else { - // Minimum batch size - Some((batch_size as f32 * waiting_served_ratio).floor() as usize) - }; - - let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); - let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize); - - // Try to get a new batch - if let Some((mut new_entries, new_batch, span)) = queue - .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget) - .await - { - // Tracking metrics - if min_size.is_some() { - metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure"); - } else { - metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded"); - } - - entries.iter_mut().for_each(|(_, entry)| { - // Create a new span to add the info that this entry is waiting - // because a new batch is being computed - let entry_waiting_span = info_span!(parent: &entry.span, "waiting"); - // Add relationships - span.follows_from(&entry_waiting_span); - entry_waiting_span.follows_from(&span); - // Update entry - entry.temp_span = Some(entry_waiting_span); - }); - - // Generate one token for this new batch to have the attention past in cache - let new_cached_batch = - prefill(&mut client, new_batch, &mut new_entries, &generation_health) - .instrument(span) - .await; - // Reset waiting counter - waiting_tokens = 1; - // Extend current batch with the new batch - if let Some(new_cached_batch) = new_cached_batch { - entries.extend(new_entries); - batches.push(new_cached_batch); - } + mut messages: Vec, + grammar_with_prompt: Option<(GrammarType, String)>, + ) -> Result { + if self.use_default_tool_template { + if let Some(last_message) = messages.last_mut() { + if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { + last_message.content.push(MessageChunk::Text(Text { + text: format!("\n---\n{}\n{}", tool_prompt, tools), + })); } - - // Create span for this batch to add context to inference calls - let next_batch_size = entries.len(); - let next_batch_span = - info_span!(parent: None, "batch", batch_size = next_batch_size); - entries.iter_mut().for_each(|(_, entry)| { - // Create a new span to link the batch back to this entry - let entry_batch_span = info_span!(parent: &entry.span, "infer"); - // Add relationships - next_batch_span.follows_from(&entry_batch_span); - entry_batch_span.follows_from(&next_batch_span); - // Update entry - entry.temp_span = Some(entry_batch_span); - }); - - cached_batch = decode(&mut client, batches, &mut entries, &generation_health) - .instrument(next_batch_span) - .await; - waiting_tokens += 1; - } - metrics::gauge!("tgi_batch_current_size", 0.0); - metrics::gauge!("tgi_batch_current_max_tokens", 0.0); - } - } -} - -#[instrument(skip_all)] -async fn prefill( - client: &mut ShardedClient, - batch: Batch, - entries: &mut IntMap, - generation_health: &Arc, -) -> Option { - let start_time = Instant::now(); - let batch_id = batch.id; - metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill"); - - match client.prefill(batch).await { - Ok((generations, next_batch, timings)) => { - // Update health - generation_health.store(true, Ordering::SeqCst); - - let start_filtering_time = Instant::now(); - // Send generated tokens and filter stopped entries - filter_send_generations(generations, entries); - - // Filter next batch and remove requests that were stopped - let next_batch = filter_batch(client, next_batch, entries).await; - - metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill"); - metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill"); - metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill"); - metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill"); - metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill"); - next_batch - } - // If we have an error, we discard the whole batch - Err(err) => { - // Update health - generation_health.store(false, Ordering::SeqCst); - let _ = client.clear_cache(Some(batch_id)).await; - send_errors(err, entries); - metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill"); - None - } - } -} - -#[instrument(skip_all)] -async fn decode( - client: &mut ShardedClient, - batches: Vec, - entries: &mut IntMap, - generation_health: &Arc, -) -> Option { - let start_time = Instant::now(); - let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); - metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode"); - - match client.decode(batches).await { - Ok((generations, next_batch, timings)) => { - // Update health - generation_health.store(true, Ordering::SeqCst); - - let start_filtering_time = Instant::now(); - // Send generated tokens and filter stopped entries - filter_send_generations(generations, entries); - - // Filter next batch and remove requests that were stopped - let next_batch = filter_batch(client, next_batch, entries).await; - - if let Some(concat_duration) = timings.concat { - metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode"); - } - metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode"); - metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode"); - next_batch - } - // If we have an error, we discard the whole batch - Err(err) => { - generation_health.store(false, Ordering::SeqCst); - for id in batch_ids { - let _ = client.clear_cache(Some(id)).await; - } - send_errors(err, entries); - metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode"); - None - } - } -} - -/// Filter a `batch` and remove all requests not present in `entries` -#[instrument(skip_all)] -async fn filter_batch( - client: &mut ShardedClient, - next_batch: Option, - entries: &IntMap, -) -> Option { - let mut batch = next_batch?; - - // No need to filter - if batch.size as usize == entries.len() { - return Some(batch); - } - - let id = batch.id; - - // Retain only requests that are still in entries - batch.request_ids.retain(|id| entries.contains_key(id)); - - if batch.request_ids.is_empty() { - // All requests have been filtered out - // Next batch is now empty - // Clear it from the Python shards cache - // We unwrap here as we need to panic since we cannot recover if this method fails - client.clear_cache(Some(id)).await.unwrap(); - None - } else { - // Filter Python shard cache - // We unwrap here as we need to panic since we cannot recover if this method fails - client.filter_batch(id, batch.request_ids).await.unwrap() - } -} - -/// Send one or multiple `InferStreamResponse` to Infer for all `entries` -/// and filter entries -#[instrument(skip_all)] -fn filter_send_generations(generations: Vec, entries: &mut IntMap) { - generations.into_iter().for_each(|generation| { - let id = generation.request_id; - // Get entry - // We can `expect` here as the request id should always be in the entries - let entry = entries - .get(&id) - .expect("ID not found in entries. This is a bug."); - - // Create and enter a span to link this function back to the entry - let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered(); - // Send generation responses back to the infer task - // If the receive an error from the Flume channel, it means that the client dropped the - // request and we need to stop generating hence why we unwrap_or(true) - let stopped = send_responses(generation, entry).map_err(|err| { - tracing::error!("Entry response channel error."); - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); - err - }).unwrap_or(true); - if stopped { - entries.remove(&id).expect("ID not found in entries. This is a bug."); - } - }); -} - -/// Send responses through the `entry` response channel -fn send_responses( - generation: Generation, - entry: &Entry, -) -> Result>>> { - // Return directly if the channel is disconnected - if entry.response_tx.is_closed() { - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); - return Ok(true); - } - - let mut stopped = false; - - if let Some(prefill_tokens) = generation.prefill_tokens { - // Create Token objects - // We do that here instead of in the Python code as Rust for loops are faster - let prefill_tokens = prefill_tokens - .ids - .into_iter() - .zip(prefill_tokens.logprobs) - .zip(prefill_tokens.texts) - .map(|((id, logprob), text)| PrefillToken { id, text, logprob }) - .collect(); - - // Send message - entry - .response_tx - .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?; - } - - // Create last Token - let tokens_ = generation.tokens.expect("Non empty tokens in generation"); - let n = tokens_.ids.len(); - metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64); - let mut iterator = tokens_ - .ids - .into_iter() - .zip(tokens_.logprobs) - .zip(tokens_.texts) - .zip(tokens_.is_special) - .enumerate() - .peekable(); - while let Some((i, (((id, logprob), text), special))) = iterator.next() { - let token = Token { - id, - text, - logprob, - special, - }; - let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) { - top_tokens_ - .ids - .iter() - .zip(top_tokens_.logprobs.iter()) - .zip(top_tokens_.texts.iter()) - .zip(top_tokens_.is_special.iter()) - .map(|(((&id, &logprob), text), &special)| Token { - id, - text: text.to_string(), - logprob, - special, - }) - .collect() - } else { - vec![] - }; - match (&generation.generated_text, iterator.peek()) { - (Some(generated_text), None) => { - // Generation has ended - stopped = true; - // Send message - entry.response_tx.send(Ok(InferStreamResponse::End { - token, - top_tokens, - generated_text: GeneratedText::from(generated_text.clone()), - queued: entry.queue_time, - start: entry.batch_time.unwrap(), - }))?; - } - _ => { - // Send message - entry - .response_tx - .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?; } } - } - Ok(stopped) -} + let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); -/// Send errors to Infer for all `entries` -#[instrument(skip_all)] -fn send_errors(error: ClientError, entries: &mut IntMap) { - entries.drain().for_each(|(_, entry)| { - // Create and enter a span to link this function back to the entry - let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); - let err = InferError::GenerationError(error.to_string()); - metrics::increment_counter!("tgi_request_failure", "err" => "generation"); - tracing::error!("{err}"); - - // unwrap_or is valid here as we don't care if the receiver is gone. - entry - .response_tx - .send(Err(err)) - .unwrap_or(()); - }); -} - -impl From for GeneratedText { - fn from(value: text_generation_client::v3::GeneratedText) -> Self { - let v3_finish_reason = - text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap(); - let finish_reason = match v3_finish_reason { - text_generation_client::v3::FinishReason::Length => FinishReason::Length, - text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken, - text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence, - }; - - Self { - text: value.text, - generated_tokens: value.generated_tokens, - finish_reason, - seed: value.seed, - } + self.template + .render(ChatTemplateInputs { + messages, + bos_token: self.bos_token.as_deref(), + eos_token: self.eos_token.as_deref(), + add_generation_prompt: true, + tools: None, + tools_prompt: None, + }) + .map_err(InferError::TemplateError) } } + // tests #[cfg(test)] mod tests { - use crate::infer::raise_exception; + use crate::infer::chat_template::raise_exception; use crate::{ChatTemplateInputs, TextMessage}; use minijinja::Environment; diff --git a/router/src/infer/health.rs b/router/src/infer/health.rs deleted file mode 100644 index 4320c1a4..00000000 --- a/router/src/infer/health.rs +++ /dev/null @@ -1,34 +0,0 @@ -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; -use text_generation_client::Health; - -#[derive(Clone)] -pub(crate) struct HealthCheck { - client: Arc, - generation_health: Arc, -} - -impl HealthCheck { - pub(crate) fn new( - client: Arc, - generation_health: Arc, - ) -> Self { - Self { - client, - generation_health, - } - } - - pub(crate) async fn check(&mut self) -> bool { - let value = if self.generation_health.load(Ordering::SeqCst) { - // Generation is healthy, we only check that the shards can allocate on device - self.client.device_health().await - } else { - self.client.model_health().await - } - .is_ok(); - // Update generation health - self.generation_health.store(value, Ordering::SeqCst); - value - } -} diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 07c334a3..8eadcabb 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -1,6 +1,8 @@ mod health; pub(crate) mod v2; pub(crate) mod v3; +mod chat_template; +mod tool_grammar; pub(crate) use health::HealthCheck; @@ -23,6 +25,7 @@ use tokio::time::Instant; use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::StreamExt; use tracing::instrument; +use chat_template::ChatTemplate; pub(crate) trait Scheduler { fn schedule( @@ -37,18 +40,20 @@ pub(crate) trait Scheduler { pub struct Infer { /// Validation validation: Validation, - /// Request scheduler - scheduler: Arc, + /// Request backend + backend: Arc, /// Chat template chat_template: Option, /// Inference limit limit_concurrent_requests: Arc, + /// Backend health + backend_health: Arc, } impl Infer { #[allow(clippy::too_many_arguments)] pub(crate) fn new( - scheduler: Arc, + backend: impl Backend + Send + Sync + 'static, validation: Validation, max_concurrent_requests: usize, tokenizer_config: HubTokenizerConfig, @@ -69,20 +74,31 @@ impl Infer { // Inference limit with a semaphore let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); + // Backend health + let backend_health = Arc::new(AtomicBool::new(false)); + Self { validation, - scheduler, + backend: Arc::new(backend), chat_template, limit_concurrent_requests: semaphore, + backend_health, } } /// Add a new request to the queue and return a stream of InferStreamResponse #[instrument(skip_all)] - pub(crate) async fn generate_stream( - &self, + pub(crate) async fn generate_stream<'a>( + &'a self, request: GenerateRequest, - ) -> Result { + ) -> Result< + ( + OwnedSemaphorePermit, + u32, // input_length + impl Stream> + 'a, + ), + InferError, + > { // Limit concurrent requests by acquiring a permit from the semaphore let permit = self .clone() @@ -101,7 +117,36 @@ impl Infer { err })?; - self.scheduler.schedule(valid_request, permit) + let input_length = valid_request.input_length; + let mut generation_stream = self + .backend + .schedule(valid_request) + .map_err(InferError::Backend)?; + + let stream = stream! { + while let Some(generation) = generation_stream.next().await { + self.backend_health.store(generation.is_ok(), Ordering::SeqCst); + + yield generation.map(|generation| match generation { + types::TokenStreamResponse::Prefill(prefill_tokens) => InferStreamResponse::Prefill( + prefill_tokens.into_iter().map(PrefillToken::from).collect() + ), + types::TokenStreamResponse::Intermediate { token, top_tokens } => InferStreamResponse::Intermediate { + token: Token::from(token), + top_tokens: top_tokens.into_iter().map(Token::from).collect(), + }, + types::TokenStreamResponse::End { token, top_tokens, generated_text, start, queued } => InferStreamResponse::End { + token: Token::from(token), + top_tokens: top_tokens.into_iter().map(Token::from).collect(), + generated_text, + start, + queued, + } + }).map_err(InferError::GenerationError) + } + }; + + Ok((permit, input_length, stream)) } /// Tokenizer the input @@ -153,7 +198,7 @@ impl Infer { let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0); // Create stream and keep semaphore permit as long as generate lives - let (_permit, _input_length, mut stream) = self.generate_stream(request).await?; + let (_permit, _input_length, stream) = self.generate_stream(request).await?; // Return values let mut result_prefill = Vec::new(); @@ -163,6 +208,8 @@ impl Infer { let mut result_start = None; let mut result_queued = None; + let mut stream = Box::pin(stream); + // Iterate on stream while let Some(response) = stream.next().await { match response? { @@ -254,190 +301,15 @@ impl Infer { let best_response = infer_responses.remove(max_index); Ok((best_response, infer_responses)) } -} -/// Raise a exception (custom function) used in the chat templates -fn raise_exception(err_text: String) -> Result { - Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text)) -} - -#[derive(Clone)] -struct ChatTemplate { - template: Template<'static, 'static>, - bos_token: Option, - eos_token: Option, - use_default_tool_template: bool, -} - -impl ChatTemplate { - fn new(template: String, bos_token: Option, eos_token: Option) -> Self { - let mut env = Box::new(Environment::new()); - // enable things like .strip() or .capitalize() - env.set_unknown_method_callback(pycompat::unknown_method_callback); - let template_str = template.into_boxed_str(); - env.add_function("raise_exception", raise_exception); - - // check if contains the tools variable within the template - let use_default_tool_template = - !template_str.as_ref().replace(' ', "").contains("{{tools}}"); - // leaking env and template_str as read-only, static resources for performance. - let template = Box::leak(env) - .template_from_str(Box::leak(template_str)) - .unwrap(); - - Self { - template, - bos_token, - eos_token, - use_default_tool_template, - } - } - - fn apply( - &self, - mut messages: Vec, - grammar_with_prompt: Option<(GrammarType, String)>, - ) -> Result { - if self.use_default_tool_template { - if let Some(last_message) = messages.last_mut() { - if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { - last_message.content.push(MessageChunk::Text(Text { - text: format!("\n---\n{}\n{}", tool_prompt, tools), - })); - } - } - } - - let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); - - self.template - .render(ChatTemplateInputs { - messages, - bos_token: self.bos_token.as_deref(), - eos_token: self.eos_token.as_deref(), - add_generation_prompt: true, - tools: None, - tools_prompt: None, - }) - .map_err(InferError::TemplateError) - } -} - -pub struct ToolGrammar {} - -impl ToolGrammar { - pub fn apply( - tools: Option>, - tool_choice: Option, - ) -> Result, InferError> { - if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) { - // let tool_prompt = tool_prompt.unwrap_or_default(); - let tools_to_use = match tool_choice { - ToolType::FunctionName(name) => { - vec![req_tools - .iter() - .find(|tool| tool.function.name == *name) - .unwrap_or_else(|| panic!("Tool with name {} not found", name)) - .clone()] - } - ToolType::OneOf => req_tools.to_owned(), - }; - - // adds the error notification function for LLM feedback if required - let mut text_response_properties = Map::new(); - text_response_properties.insert( - "error".to_string(), - serde_json::json!({ - "type": "string", - "description": "The error or issue to notify" - }), - ); - text_response_properties.insert( - "_name".to_string(), - serde_json::json!({ - "type": "string", - "const": "notify_error" - }), - ); - - let functions: HashMap = tools_to_use - .iter() - .map(|tool| { - let func = tool.function.clone(); - - // Clone the existing parameters, which are expected to be a JSON object - let mut params = if let Value::Object(params) = &func.arguments { - params.clone() - } else { - Map::new() - }; - - // Insert the function's description at the top level, outside of properties - params.insert( - "description".to_string(), - Value::String(func.description.clone().unwrap_or_default()), - ); - - // Ensure 'properties' exists and is an object - let properties = params - .entry("properties".to_string()) - .or_insert_with(|| json!({})) - .as_object_mut() - .unwrap(); - - // Insert the constant for the function name inside 'properties' - properties.insert( - "_name".to_string(), - json!({ - "type": "string", - "const": func.name.clone(), - // "description": "The name of the function" - }), - ); - - // Check if 'required' exists, and it is an array. If not, create an empty array. - let required = params - .entry("required".to_string()) - .or_insert_with(|| json!([])) - .as_array_mut() - .unwrap(); - - // Add 'name' to the 'required' array if it is not already present - if !required.iter().any(|r| r == "_name") { - required.push(json!("_name")); - } - - (func.name, Value::Object(params)) - }) - .chain([( - "notify_error".to_string(), - serde_json::json!({ - "properties": text_response_properties, - "required": ["error", "_name"], - "type": "object" - }), - )]) - .collect(); - - let tools = Tools { - functions_map: FunctionsMap { functions }, - properties: Properties { - function: tools_to_use - .iter() - .map(|tool| FunctionRef { - ref_path: format!("#/$functions/{}", tool.function.name.clone()), - }) - .chain(std::iter::once(FunctionRef { - ref_path: "#/$functions/notify_error".to_string(), - })) - .collect(), - }, - }; - - return Ok(Some(tools)); - } - // Err(InferError::ToolError("No tools provided".to_string())) - Ok(None) + #[instrument(skip(self))] + pub(crate) async fn health(&self) -> bool { + let health = self + .backend + .health(self.backend_health.load(Ordering::SeqCst)) + .await; + self.backend_health.store(health, Ordering::SeqCst); + health } } @@ -491,8 +363,10 @@ pub(crate) struct InferResponse { #[derive(Debug, Error)] pub enum InferError { + #[error("Request failed during scheduling: {0}")] + Backend(BackendError), #[error("Request failed during generation: {0}")] - GenerationError(String), + GenerationError(BackendError), #[error("Model is overloaded")] Overloaded(#[from] TryAcquireError), #[error("Input validation error: {0}")] @@ -508,6 +382,7 @@ pub enum InferError { impl InferError { pub(crate) fn error_type(&self) -> &str { match self { + InferError::Backend(_) => "backend", InferError::GenerationError(_) => "generation", InferError::Overloaded(_) => "overloaded", InferError::ValidationError(_) => "validation", diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs new file mode 100644 index 00000000..4e0b6e60 --- /dev/null +++ b/router/src/infer/tool_grammar.rs @@ -0,0 +1,122 @@ +use crate::infer::InferError; +use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolType, Tools}; +use serde_json::{json, Map, Value}; +use std::collections::HashMap; + +pub(crate) struct ToolGrammar {} + +impl ToolGrammar { + pub fn apply( + tools: Option>, + tool_choice: Option, + ) -> Result, InferError> { + if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) { + // let tool_prompt = tool_prompt.unwrap_or_default(); + let tools_to_use = match tool_choice { + ToolType::FunctionName(name) => { + vec![req_tools + .iter() + .find(|tool| tool.function.name == *name) + .unwrap_or_else(|| panic!("Tool with name {} not found", name)) + .clone()] + } + ToolType::OneOf => req_tools.to_owned(), + }; + + // adds the error notification function for LLM feedback if required + let mut text_response_properties = Map::new(); + text_response_properties.insert( + "error".to_string(), + json!({ + "type": "string", + "description": "The error or issue to notify" + }), + ); + text_response_properties.insert( + "_name".to_string(), + json!({ + "type": "string", + "const": "notify_error" + }), + ); + + let functions: HashMap = tools_to_use + .iter() + .map(|tool| { + let func = tool.function.clone(); + + // Clone the existing parameters, which are expected to be a JSON object + let mut params = if let Value::Object(params) = &func.arguments { + params.clone() + } else { + Map::new() + }; + + // Insert the function's description at the top level, outside of properties + params.insert( + "description".to_string(), + Value::String(func.description.clone().unwrap_or_default()), + ); + + // Ensure 'properties' exists and is an object + let properties = params + .entry("properties".to_string()) + .or_insert_with(|| json!({})) + .as_object_mut() + .unwrap(); + + // Insert the constant for the function name inside 'properties' + properties.insert( + "_name".to_string(), + json!({ + "type": "string", + "const": func.name.clone(), + // "description": "The name of the function" + }), + ); + + // Check if 'required' exists, and it is an array. If not, create an empty array. + let required = params + .entry("required".to_string()) + .or_insert_with(|| json!([])) + .as_array_mut() + .unwrap(); + + // Add 'name' to the 'required' array if it is not already present + if !required.iter().any(|r| r == "_name") { + required.push(json!("_name")); + } + + (func.name, Value::Object(params)) + }) + .chain([( + "notify_error".to_string(), + json!({ + "properties": text_response_properties, + "required": ["error", "_name"], + "type": "object" + }), + )]) + .collect(); + + let tools = Tools { + functions_map: FunctionsMap { functions }, + properties: Properties { + function: tools_to_use + .iter() + .map(|tool| FunctionRef { + ref_path: format!("#/$functions/{}", tool.function.name.clone()), + }) + .chain(std::iter::once(FunctionRef { + ref_path: "#/$functions/notify_error".to_string(), + })) + .collect(), + }, + }; + + return Ok(Some(tools)); + } + // Err(InferError::ToolError("No tools provided".to_string())) + Ok(None) + } +} diff --git a/router/src/server.rs b/router/src/server.rs index 7f15bfdd..8e48f586 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -121,12 +121,10 @@ responses( example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})), ) )] -#[instrument(skip(health))] +#[instrument(skip(infer))] /// Health check method -async fn health( - mut health: Extension, -) -> Result<(), (StatusCode, Json)> { - match health.check().await { +async fn health(infer: Extension) -> Result<(), (StatusCode, Json)> { + match infer.health().await { true => Ok(()), false => Err(( StatusCode::SERVICE_UNAVAILABLE, @@ -437,8 +435,9 @@ async fn generate_stream_internal( } else { match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await { // Keep permit as long as generate_stream lives - Ok((_permit, _input_length, mut response_stream)) => { + Ok((_permit, _input_length, response_stream)) => { let mut index = 0; + let mut response_stream = Box::pin(response_stream); // Server-Sent Event stream while let Some(response) = response_stream.next().await { index += 1; @@ -1960,16 +1959,8 @@ impl From for Event { #[derive(Debug, Error)] pub enum WebServerError { - #[error("Unable to connect to the Python model shards: {0}")] - Connection(ClientError), - #[error("Unable to clear the Python model shards cache: {0}")] - Cache(ClientError), - #[error("Unable to get the Python model shards info: {0}")] - Info(ClientError), - #[error("Unable to warmup the Python model shards: {0}")] - Warmup(ClientError), - #[error("Not enough memory to handle `max_total_tokens={0}`")] - NotEnoughMemory(usize), + #[error("Backend error: {0}")] + Backend(#[from] BackendError), #[error("Axum error: {0}")] Axum(#[from] axum::BoxError), }