From abf56b75a4f9d416b2e5a7e9bc51ff70dd726b86 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 20 Jun 2024 12:40:36 +0200 Subject: [PATCH] refactor schedulers --- Cargo.lock | 1 + router/Cargo.toml | 1 + router/src/infer/chat_template.rs | 75 +++++ router/src/infer/health.rs | 34 --- router/src/infer/mod.rs | 277 ++++-------------- router/src/infer/schedulers/mod.rs | 54 ++++ router/src/infer/{ => schedulers}/v2/mod.rs | 0 router/src/infer/{ => schedulers}/v2/queue.rs | 0 .../infer/{ => schedulers}/v2/scheduler.rs | 4 +- .../{ => schedulers}/v3/block_allocator.rs | 0 router/src/infer/schedulers/v3/mod.rs | 109 +++++++ router/src/infer/{ => schedulers}/v3/queue.rs | 8 +- .../infer/{ => schedulers}/v3/scheduler.rs | 75 +++-- router/src/infer/tool_grammar.rs | 122 ++++++++ router/src/infer/v3/mod.rs | 5 - router/src/server.rs | 174 ++--------- 16 files changed, 485 insertions(+), 454 deletions(-) create mode 100644 router/src/infer/chat_template.rs delete mode 100644 router/src/infer/health.rs create mode 100644 router/src/infer/schedulers/mod.rs rename router/src/infer/{ => schedulers}/v2/mod.rs (100%) rename router/src/infer/{ => schedulers}/v2/queue.rs (100%) rename router/src/infer/{ => schedulers}/v2/scheduler.rs (99%) rename router/src/infer/{ => schedulers}/v3/block_allocator.rs (100%) create mode 100644 router/src/infer/schedulers/v3/mod.rs rename router/src/infer/{ => schedulers}/v3/queue.rs (99%) rename router/src/infer/{ => schedulers}/v3/scheduler.rs (97%) create mode 100644 router/src/infer/tool_grammar.rs delete mode 100644 router/src/infer/v3/mod.rs diff --git a/Cargo.lock b/Cargo.lock index b9bd7363..0ec85025 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3602,6 +3602,7 @@ name = "text-generation-router" version = "2.0.5-dev0" dependencies = [ "async-stream", + "async-trait", "axum 0.7.5", "axum-tracing-opentelemetry", "base64 0.22.1", diff --git a/router/Cargo.toml b/router/Cargo.toml index 5bf4c00c..5faa0aa5 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -15,6 +15,7 @@ name = "text-generation-router" path = "src/main.rs" [dependencies] +async-trait = "^0.1" async-stream = "0.3.5" axum = { version = "0.7", features = ["json"] } axum-tracing-opentelemetry = "0.16" diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs new file mode 100644 index 00000000..ed112223 --- /dev/null +++ b/router/src/infer/chat_template.rs @@ -0,0 +1,75 @@ +use crate::infer::InferError; +use crate::{ChatTemplateInputs, GrammarType, Message, MessageChunk, Text, TextMessage}; +use minijinja::{Environment, ErrorKind, Template}; +use minijinja_contrib::pycompat; + +/// Raise a exception (custom function) used in the chat templates +pub(crate) fn raise_exception(err_text: String) -> Result { + Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text)) +} + +#[derive(Clone)] +pub(crate) struct ChatTemplate { + template: Template<'static, 'static>, + bos_token: Option, + eos_token: Option, + use_default_tool_template: bool, +} + +impl ChatTemplate { + pub(crate) fn new( + template: String, + bos_token: Option, + eos_token: Option, + ) -> Self { + let mut env = Box::new(Environment::new()); + // enable things like .strip() or .capitalize() + env.set_unknown_method_callback(pycompat::unknown_method_callback); + let template_str = template.into_boxed_str(); + env.add_function("raise_exception", raise_exception); + + // check if contains the tools variable within the template + let use_default_tool_template = + 
!template_str.as_ref().replace(' ', "").contains("{{tools}}"); + // leaking env and template_str as read-only, static resources for performance. + let template = Box::leak(env) + .template_from_str(Box::leak(template_str)) + .unwrap(); + + Self { + template, + bos_token, + eos_token, + use_default_tool_template, + } + } + + pub(crate) fn apply( + &self, + mut messages: Vec, + grammar_with_prompt: Option<(GrammarType, String)>, + ) -> Result { + if self.use_default_tool_template { + if let Some(last_message) = messages.last_mut() { + if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { + last_message.content.push(MessageChunk::Text(Text { + text: format!("\n---\n{}\n{}", tool_prompt, tools), + })); + } + } + } + + let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); + + self.template + .render(ChatTemplateInputs { + messages, + bos_token: self.bos_token.as_deref(), + eos_token: self.eos_token.as_deref(), + add_generation_prompt: true, + tools: None, + tools_prompt: None, + }) + .map_err(InferError::TemplateError) + } +} diff --git a/router/src/infer/health.rs b/router/src/infer/health.rs deleted file mode 100644 index 4320c1a4..00000000 --- a/router/src/infer/health.rs +++ /dev/null @@ -1,34 +0,0 @@ -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; -use text_generation_client::Health; - -#[derive(Clone)] -pub(crate) struct HealthCheck { - client: Arc, - generation_health: Arc, -} - -impl HealthCheck { - pub(crate) fn new( - client: Arc, - generation_health: Arc, - ) -> Self { - Self { - client, - generation_health, - } - } - - pub(crate) async fn check(&mut self) -> bool { - let value = if self.generation_health.load(Ordering::SeqCst) { - // Generation is healthy, we only check that the shards can allocate on device - self.client.device_health().await - } else { - self.client.model_health().await - } - .is_ok(); - // Update generation health - self.generation_health.store(value, Ordering::SeqCst); - value - } -} diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 07c334a3..e2950574 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -1,37 +1,31 @@ -mod health; -pub(crate) mod v2; -pub(crate) mod v3; +mod chat_template; +pub(crate) mod schedulers; +mod tool_grammar; -pub(crate) use health::HealthCheck; +pub(crate) use tool_grammar::ToolGrammar; -use crate::validation::{ValidGenerateRequest, Validation, ValidationError}; +use crate::infer::chat_template::ChatTemplate; +use crate::validation::{Validation, ValidationError}; +use crate::GrammarType; use crate::{ - ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, - HubTokenizerConfig, Message, MessageChunk, PrefillToken, Text, TextMessage, Token, + ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig, + Message, PrefillToken, Token, }; -use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools}; use futures::future::try_join_all; -use minijinja::{Environment, ErrorKind, Template}; -use minijinja_contrib::pycompat; +use minijinja::ErrorKind; +pub(crate) use schedulers::Scheduler; -use serde_json::{json, Map, Value}; -use std::collections::HashMap; +use crate::infer::schedulers::SchedulerError; +use async_stream::stream; +use futures::Stream; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use thiserror::Error; use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError}; use tokio::time::Instant; -use 
tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::StreamExt; use tracing::instrument; -pub(crate) trait Scheduler { - fn schedule( - &self, - request: ValidGenerateRequest, - permit: OwnedSemaphorePermit, - ) -> Result; -} - /// Inference struct #[derive(Clone)] pub struct Infer { @@ -43,6 +37,8 @@ pub struct Infer { chat_template: Option, /// Inference limit limit_concurrent_requests: Arc, + /// Backend health + backend_health: Arc, } impl Infer { @@ -69,20 +65,31 @@ impl Infer { // Inference limit with a semaphore let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); + // Backend health + let backend_health = Arc::new(AtomicBool::new(false)); + Self { validation, scheduler, chat_template, limit_concurrent_requests: semaphore, + backend_health, } } /// Add a new request to the queue and return a stream of InferStreamResponse #[instrument(skip_all)] - pub(crate) async fn generate_stream( - &self, + pub(crate) async fn generate_stream<'a>( + &'a self, request: GenerateRequest, - ) -> Result { + ) -> Result< + ( + OwnedSemaphorePermit, + u32, // input_length + impl Stream> + 'a, + ), + InferError, + > { // Limit concurrent requests by acquiring a permit from the semaphore let permit = self .clone() @@ -101,7 +108,20 @@ impl Infer { err })?; - self.scheduler.schedule(valid_request, permit) + let input_length = valid_request.input_length; + let mut generation_stream = self + .scheduler + .schedule(valid_request) + .map_err(InferError::Scheduler)?; + + let stream = stream! { + while let Some(generation) = generation_stream.next().await { + self.backend_health.store(generation.is_ok(), Ordering::SeqCst); + yield generation.map_err(InferError::GenerationError) + } + }; + + Ok((permit, input_length, stream)) } /// Tokenizer the input @@ -153,7 +173,7 @@ impl Infer { let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0); // Create stream and keep semaphore permit as long as generate lives - let (_permit, _input_length, mut stream) = self.generate_stream(request).await?; + let (_permit, _input_length, stream) = self.generate_stream(request).await?; // Return values let mut result_prefill = Vec::new(); @@ -163,6 +183,8 @@ impl Infer { let mut result_start = None; let mut result_queued = None; + let mut stream = Box::pin(stream); + // Iterate on stream while let Some(response) = stream.next().await { match response? { @@ -254,200 +276,18 @@ impl Infer { let best_response = infer_responses.remove(max_index); Ok((best_response, infer_responses)) } -} -/// Raise a exception (custom function) used in the chat templates -fn raise_exception(err_text: String) -> Result { - Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text)) -} - -#[derive(Clone)] -struct ChatTemplate { - template: Template<'static, 'static>, - bos_token: Option, - eos_token: Option, - use_default_tool_template: bool, -} - -impl ChatTemplate { - fn new(template: String, bos_token: Option, eos_token: Option) -> Self { - let mut env = Box::new(Environment::new()); - // enable things like .strip() or .capitalize() - env.set_unknown_method_callback(pycompat::unknown_method_callback); - let template_str = template.into_boxed_str(); - env.add_function("raise_exception", raise_exception); - - // check if contains the tools variable within the template - let use_default_tool_template = - !template_str.as_ref().replace(' ', "").contains("{{tools}}"); - // leaking env and template_str as read-only, static resources for performance. 
- let template = Box::leak(env) - .template_from_str(Box::leak(template_str)) - .unwrap(); - - Self { - template, - bos_token, - eos_token, - use_default_tool_template, - } - } - - fn apply( - &self, - mut messages: Vec, - grammar_with_prompt: Option<(GrammarType, String)>, - ) -> Result { - if self.use_default_tool_template { - if let Some(last_message) = messages.last_mut() { - if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { - last_message.content.push(MessageChunk::Text(Text { - text: format!("\n---\n{}\n{}", tool_prompt, tools), - })); - } - } - } - - let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); - - self.template - .render(ChatTemplateInputs { - messages, - bos_token: self.bos_token.as_deref(), - eos_token: self.eos_token.as_deref(), - add_generation_prompt: true, - tools: None, - tools_prompt: None, - }) - .map_err(InferError::TemplateError) + #[instrument(skip(self))] + pub(crate) async fn health(&self) -> bool { + let health = self + .scheduler + .health(self.backend_health.load(Ordering::SeqCst)) + .await; + self.backend_health.store(health, Ordering::SeqCst); + health } } -pub struct ToolGrammar {} - -impl ToolGrammar { - pub fn apply( - tools: Option>, - tool_choice: Option, - ) -> Result, InferError> { - if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) { - // let tool_prompt = tool_prompt.unwrap_or_default(); - let tools_to_use = match tool_choice { - ToolType::FunctionName(name) => { - vec![req_tools - .iter() - .find(|tool| tool.function.name == *name) - .unwrap_or_else(|| panic!("Tool with name {} not found", name)) - .clone()] - } - ToolType::OneOf => req_tools.to_owned(), - }; - - // adds the error notification function for LLM feedback if required - let mut text_response_properties = Map::new(); - text_response_properties.insert( - "error".to_string(), - serde_json::json!({ - "type": "string", - "description": "The error or issue to notify" - }), - ); - text_response_properties.insert( - "_name".to_string(), - serde_json::json!({ - "type": "string", - "const": "notify_error" - }), - ); - - let functions: HashMap = tools_to_use - .iter() - .map(|tool| { - let func = tool.function.clone(); - - // Clone the existing parameters, which are expected to be a JSON object - let mut params = if let Value::Object(params) = &func.arguments { - params.clone() - } else { - Map::new() - }; - - // Insert the function's description at the top level, outside of properties - params.insert( - "description".to_string(), - Value::String(func.description.clone().unwrap_or_default()), - ); - - // Ensure 'properties' exists and is an object - let properties = params - .entry("properties".to_string()) - .or_insert_with(|| json!({})) - .as_object_mut() - .unwrap(); - - // Insert the constant for the function name inside 'properties' - properties.insert( - "_name".to_string(), - json!({ - "type": "string", - "const": func.name.clone(), - // "description": "The name of the function" - }), - ); - - // Check if 'required' exists, and it is an array. If not, create an empty array. 
- let required = params - .entry("required".to_string()) - .or_insert_with(|| json!([])) - .as_array_mut() - .unwrap(); - - // Add 'name' to the 'required' array if it is not already present - if !required.iter().any(|r| r == "_name") { - required.push(json!("_name")); - } - - (func.name, Value::Object(params)) - }) - .chain([( - "notify_error".to_string(), - serde_json::json!({ - "properties": text_response_properties, - "required": ["error", "_name"], - "type": "object" - }), - )]) - .collect(); - - let tools = Tools { - functions_map: FunctionsMap { functions }, - properties: Properties { - function: tools_to_use - .iter() - .map(|tool| FunctionRef { - ref_path: format!("#/$functions/{}", tool.function.name.clone()), - }) - .chain(std::iter::once(FunctionRef { - ref_path: "#/$functions/notify_error".to_string(), - })) - .collect(), - }, - }; - - return Ok(Some(tools)); - } - // Err(InferError::ToolError("No tools provided".to_string())) - Ok(None) - } -} - -/// Type alias for generation responses -pub(crate) type GenerateStreamResponse = ( - OwnedSemaphorePermit, - u32, // input_length - UnboundedReceiverStream>, -); - #[derive(Debug)] pub(crate) struct GeneratedText { pub(crate) text: String, @@ -491,8 +331,10 @@ pub(crate) struct InferResponse { #[derive(Debug, Error)] pub enum InferError { + #[error("Request failed during scheduling: {0}")] + Scheduler(SchedulerError), #[error("Request failed during generation: {0}")] - GenerationError(String), + GenerationError(SchedulerError), #[error("Model is overloaded")] Overloaded(#[from] TryAcquireError), #[error("Input validation error: {0}")] @@ -508,6 +350,7 @@ pub enum InferError { impl InferError { pub(crate) fn error_type(&self) -> &str { match self { + InferError::Scheduler(_) => "scheduler", InferError::GenerationError(_) => "generation", InferError::Overloaded(_) => "overloaded", InferError::ValidationError(_) => "validation", diff --git a/router/src/infer/schedulers/mod.rs b/router/src/infer/schedulers/mod.rs new file mode 100644 index 00000000..3743b466 --- /dev/null +++ b/router/src/infer/schedulers/mod.rs @@ -0,0 +1,54 @@ +mod v3; + +use crate::infer::InferStreamResponse; +use crate::validation::ValidGenerateRequest; +use async_trait::async_trait; +use std::sync::Arc; +use text_generation_client::ShardInfo; +use thiserror::Error; +use tokio_stream::wrappers::UnboundedReceiverStream; + +#[async_trait] +pub(crate) trait Scheduler { + fn schedule( + &self, + request: ValidGenerateRequest, + ) -> Result>, SchedulerError>; + + async fn health(&self, current_health: bool) -> bool; +} + +#[allow(clippy::too_many_arguments)] +pub(crate) async fn connect_backend( + master_shard_uds_path: String, + max_input_tokens: usize, + max_total_tokens: usize, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: Option, + max_waiting_tokens: usize, + max_batch_size: Option, +) -> Result<(Arc, ShardInfo, u32), SchedulerError> { + v3::connect_backend( + master_shard_uds_path, + max_input_tokens, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + ) + .await + .map_err(|err| SchedulerError::Startup(Box::new(err))) +} + +#[derive(Debug, Error)] +pub enum SchedulerError { + #[error("Startup error: {0}")] + Startup(Box), + #[error("Request failed during generation: {0}")] + Generation(Box), + #[error("Backend error: {0}")] + Backend(Box), +} diff --git a/router/src/infer/v2/mod.rs b/router/src/infer/schedulers/v2/mod.rs similarity 
index 100% rename from router/src/infer/v2/mod.rs rename to router/src/infer/schedulers/v2/mod.rs diff --git a/router/src/infer/v2/queue.rs b/router/src/infer/schedulers/v2/queue.rs similarity index 100% rename from router/src/infer/v2/queue.rs rename to router/src/infer/schedulers/v2/queue.rs diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/schedulers/v2/scheduler.rs similarity index 99% rename from router/src/infer/v2/scheduler.rs rename to router/src/infer/schedulers/v2/scheduler.rs index ba6f520d..e7a207fc 100644 --- a/router/src/infer/v2/scheduler.rs +++ b/router/src/infer/schedulers/v2/scheduler.rs @@ -1,5 +1,5 @@ /// Batching and inference logic -use crate::infer::v2::queue::{Entry, Queue}; +use crate::infer::schedulers::v2::queue::{Entry, Queue}; use crate::infer::{ GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler, }; @@ -498,7 +498,7 @@ impl From for GeneratedText { // tests #[cfg(test)] mod tests { - use crate::infer::raise_exception; + use crate::infer::chat_template::raise_exception; use crate::{ChatTemplateInputs, TextMessage}; use minijinja::Environment; diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/schedulers/v3/block_allocator.rs similarity index 100% rename from router/src/infer/v3/block_allocator.rs rename to router/src/infer/schedulers/v3/block_allocator.rs diff --git a/router/src/infer/schedulers/v3/mod.rs b/router/src/infer/schedulers/v3/mod.rs new file mode 100644 index 00000000..2e1fead4 --- /dev/null +++ b/router/src/infer/schedulers/v3/mod.rs @@ -0,0 +1,109 @@ +mod block_allocator; +mod queue; +mod scheduler; + +use crate::infer::schedulers::v3::scheduler::SchedulerV3; +use crate::infer::schedulers::Scheduler; +use std::sync::Arc; +use text_generation_client::v3::ShardedClient; +use text_generation_client::{ClientError, ShardInfo}; +use thiserror::Error; + +#[allow(clippy::too_many_arguments)] +pub(crate) async fn connect_backend( + master_shard_uds_path: String, + max_input_tokens: usize, + max_total_tokens: usize, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: Option, + max_waiting_tokens: usize, + max_batch_size: Option, +) -> Result<(Arc, ShardInfo, u32), V3Error> { + // Helper function + let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option| { + match max_supported_batch_total_tokens { + // Older models do not support automatic max-batch-total-tokens + None => { + let max_batch_total_tokens = max_batch_total_tokens + .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))); + tracing::warn!("Model does not support automatic max batch total tokens"); + Ok(max_batch_total_tokens) + } + // Flash attention models return their max supported total tokens + Some(max_supported_batch_total_tokens) => { + // Warn if user added his own max-batch-total-tokens as we will ignore it + if max_batch_total_tokens.is_some() { + tracing::warn!( + "`--max-batch-total-tokens` is deprecated for Flash \ + Attention models." 
+ ); + tracing::warn!( + "Inferred max batch total tokens: {max_supported_batch_total_tokens}" + ); + } + if max_total_tokens as u32 > max_supported_batch_total_tokens { + return Err(V3Error::NotEnoughMemory(max_total_tokens)); + } + + Ok(max_supported_batch_total_tokens) + } + } + }; + + let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) + .await + .map_err(V3Error::Connection)?; + + // server is running on v3 + // Clear the cache; useful if the webserver rebooted + sharded_client + .clear_cache(None) + .await + .map_err(V3Error::Cache)?; + // Get info from the shard + let shard_info = sharded_client.info().await.map_err(V3Error::Info)?; + + // Warmup model + tracing::info!("Warming up model"); + let max_batch_total_tokens = check_max_batch_total_tokens( + sharded_client + .warmup( + max_input_tokens as u32, + max_batch_prefill_tokens, + max_total_tokens as u32, + max_batch_size, + ) + .await + .map_err(V3Error::Warmup)?, + )?; + + let scheduler = Arc::new(SchedulerV3::new( + sharded_client, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + shard_info.requires_padding, + shard_info.window_size, + shard_info.speculate, + )); + tracing::info!("Using scheduler V3"); + + Ok((scheduler, shard_info, max_batch_total_tokens)) +} + +#[derive(Debug, Error)] +pub(crate) enum V3Error { + #[error("Unable to clear the Python model shards cache: {0}")] + Cache(ClientError), + #[error("Unable to connect to the Python model shards: {0}")] + Connection(ClientError), + #[error("Unable to get the Python model shards info: {0}")] + Info(ClientError), + #[error("Unable to warmup the Python model shards: {0}")] + Warmup(ClientError), + #[error("Not enough memory to handle `max_total_tokens={0}`")] + NotEnoughMemory(usize), +} diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/schedulers/v3/queue.rs similarity index 99% rename from router/src/infer/v3/queue.rs rename to router/src/infer/schedulers/v3/queue.rs index 0b66142a..5b945a69 100644 --- a/router/src/infer/v3/queue.rs +++ b/router/src/infer/schedulers/v3/queue.rs @@ -1,5 +1,5 @@ -use crate::infer::v3::block_allocator::{BlockAllocation, BlockAllocator}; -use crate::infer::InferError; +use crate::infer::schedulers::v3::block_allocator::{BlockAllocation, BlockAllocator}; +use crate::infer::schedulers::SchedulerError; use crate::infer::InferStreamResponse; use crate::validation::{ ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters, @@ -22,7 +22,7 @@ pub(crate) struct Entry { /// Request pub request: ValidGenerateRequest, /// Response sender to communicate between the Infer struct and the batching_task - pub response_tx: mpsc::UnboundedSender>, + pub response_tx: mpsc::UnboundedSender>, /// Span that will live as long as entry pub span: Span, /// Temporary span used as a guard when logging inference, wait times... 
@@ -463,7 +463,7 @@ mod tests { fn default_entry() -> ( Entry, - mpsc::UnboundedReceiver>, + mpsc::UnboundedReceiver>, ) { let (response_tx, receiver_tx) = mpsc::unbounded_channel(); diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/schedulers/v3/scheduler.rs similarity index 97% rename from router/src/infer/v3/scheduler.rs rename to router/src/infer/schedulers/v3/scheduler.rs index ad03dd83..038ba397 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/schedulers/v3/scheduler.rs @@ -1,19 +1,16 @@ /// Batching and inference logic -use crate::infer::v3::queue::{Entry, Queue}; -use crate::infer::{ - GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler, -}; +use crate::infer::schedulers::v3::queue::{Entry, Queue}; +use crate::infer::schedulers::SchedulerError; +use crate::infer::{GeneratedText, InferStreamResponse, Scheduler}; use crate::validation::ValidGenerateRequest; use crate::{FinishReason, PrefillToken, Token}; +use async_trait::async_trait; use nohash_hasher::IntMap; -use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, -}; +use std::sync::Arc; use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient}; -use text_generation_client::ClientError; +use text_generation_client::{ClientError, Health}; use tokio::sync::mpsc::error::SendError; -use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit}; +use tokio::sync::{mpsc, Notify}; use tokio::time::Instant; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{info_span, instrument, Instrument, Span}; @@ -23,6 +20,8 @@ pub(crate) struct SchedulerV3 { queue: Queue, /// Notify batcher on queue appends batching_task_notifier: Arc, + /// Client, used for health checks to skip the queue + client: ShardedClient, } impl SchedulerV3 { @@ -37,7 +36,6 @@ impl SchedulerV3 { requires_padding: bool, window_size: Option, speculate: u32, - generation_health: Arc, ) -> Self { let queue = Queue::new( requires_padding, @@ -50,7 +48,7 @@ impl SchedulerV3 { // Spawn batching background task that contains all the inference logic tokio::spawn(batching_task( - client, + client.clone(), waiting_served_ratio, max_batch_prefill_tokens, max_batch_total_tokens, @@ -58,26 +56,26 @@ impl SchedulerV3 { max_batch_size, queue.clone(), batching_task_notifier.clone(), - generation_health, )); Self { queue, batching_task_notifier, + client, } } } +#[async_trait] impl Scheduler for SchedulerV3 { #[instrument(skip_all)] fn schedule( &self, request: ValidGenerateRequest, - permit: OwnedSemaphorePermit, - ) -> Result { + ) -> Result>, SchedulerError> + { // MPSC channel to communicate with the background batching task let (response_tx, response_rx) = mpsc::unbounded_channel(); - let input_length = request.input_length; // Append the request to the queue self.queue.append(Entry { @@ -95,11 +93,17 @@ impl Scheduler for SchedulerV3 { self.batching_task_notifier.notify_one(); // Return stream - Ok(( - permit, - input_length, - UnboundedReceiverStream::new(response_rx), - )) + Ok(UnboundedReceiverStream::new(response_rx)) + } + + async fn health(&self, current_health: bool) -> bool { + if current_health { + // Generation is healthy, we only check that the shards can allocate on device + self.client.device_health().await + } else { + self.client.model_health().await + } + .is_ok() } } @@ -117,7 +121,6 @@ pub(crate) async fn batching_task( max_batch_size: Option, queue: Queue, notifier: Arc, - generation_health: Arc, ) { // Infinite loop loop { @@ -136,7 +139,7 @@ pub(crate) async fn 
batching_task( ) .await { - let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health) + let mut cached_batch = prefill(&mut client, batch, &mut entries) .instrument(span) .await; let mut waiting_tokens = 1; @@ -187,10 +190,9 @@ pub(crate) async fn batching_task( }); // Generate one token for this new batch to have the attention past in cache - let new_cached_batch = - prefill(&mut client, new_batch, &mut new_entries, &generation_health) - .instrument(span) - .await; + let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries) + .instrument(span) + .await; // Reset waiting counter waiting_tokens = 1; // Extend current batch with the new batch @@ -214,7 +216,7 @@ pub(crate) async fn batching_task( entry.temp_span = Some(entry_batch_span); }); - cached_batch = decode(&mut client, batches, &mut entries, &generation_health) + cached_batch = decode(&mut client, batches, &mut entries) .instrument(next_batch_span) .await; waiting_tokens += 1; @@ -230,7 +232,6 @@ async fn prefill( client: &mut ShardedClient, batch: Batch, entries: &mut IntMap, - generation_health: &Arc, ) -> Option { let start_time = Instant::now(); let batch_id = batch.id; @@ -238,9 +239,6 @@ async fn prefill( match client.prefill(batch).await { Ok((generations, next_batch, timings)) => { - // Update health - generation_health.store(true, Ordering::SeqCst); - let start_filtering_time = Instant::now(); // Send generated tokens and filter stopped entries filter_send_generations(generations, entries); @@ -257,8 +255,6 @@ async fn prefill( } // If we have an error, we discard the whole batch Err(err) => { - // Update health - generation_health.store(false, Ordering::SeqCst); let _ = client.clear_cache(Some(batch_id)).await; send_errors(err, entries); metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill"); @@ -272,7 +268,6 @@ async fn decode( client: &mut ShardedClient, batches: Vec, entries: &mut IntMap, - generation_health: &Arc, ) -> Option { let start_time = Instant::now(); let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); @@ -280,9 +275,6 @@ async fn decode( match client.decode(batches).await { Ok((generations, next_batch, timings)) => { - // Update health - generation_health.store(true, Ordering::SeqCst); - let start_filtering_time = Instant::now(); // Send generated tokens and filter stopped entries filter_send_generations(generations, entries); @@ -302,7 +294,6 @@ async fn decode( } // If we have an error, we discard the whole batch Err(err) => { - generation_health.store(false, Ordering::SeqCst); for id in batch_ids { let _ = client.clear_cache(Some(id)).await; } @@ -378,7 +369,7 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap Result>>> { +) -> Result>>> { // Return directly if the channel is disconnected if entry.response_tx.is_closed() { metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); @@ -471,7 +462,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap) { entries.drain().for_each(|(_, entry)| { // Create and enter a span to link this function back to the entry let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. 
This is a bug."), "send_error").entered(); - let err = InferError::GenerationError(error.to_string()); + let err = SchedulerError::Generation(Box::new(error.clone())); metrics::increment_counter!("tgi_request_failure", "err" => "generation"); tracing::error!("{err}"); @@ -505,7 +496,7 @@ impl From for GeneratedText { // tests #[cfg(test)] mod tests { - use crate::infer::raise_exception; + use crate::infer::chat_template::raise_exception; use crate::{ChatTemplateInputs, TextMessage}; use minijinja::Environment; diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs new file mode 100644 index 00000000..4e0b6e60 --- /dev/null +++ b/router/src/infer/tool_grammar.rs @@ -0,0 +1,122 @@ +use crate::infer::InferError; +use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolType, Tools}; +use serde_json::{json, Map, Value}; +use std::collections::HashMap; + +pub(crate) struct ToolGrammar {} + +impl ToolGrammar { + pub fn apply( + tools: Option>, + tool_choice: Option, + ) -> Result, InferError> { + if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) { + // let tool_prompt = tool_prompt.unwrap_or_default(); + let tools_to_use = match tool_choice { + ToolType::FunctionName(name) => { + vec![req_tools + .iter() + .find(|tool| tool.function.name == *name) + .unwrap_or_else(|| panic!("Tool with name {} not found", name)) + .clone()] + } + ToolType::OneOf => req_tools.to_owned(), + }; + + // adds the error notification function for LLM feedback if required + let mut text_response_properties = Map::new(); + text_response_properties.insert( + "error".to_string(), + json!({ + "type": "string", + "description": "The error or issue to notify" + }), + ); + text_response_properties.insert( + "_name".to_string(), + json!({ + "type": "string", + "const": "notify_error" + }), + ); + + let functions: HashMap = tools_to_use + .iter() + .map(|tool| { + let func = tool.function.clone(); + + // Clone the existing parameters, which are expected to be a JSON object + let mut params = if let Value::Object(params) = &func.arguments { + params.clone() + } else { + Map::new() + }; + + // Insert the function's description at the top level, outside of properties + params.insert( + "description".to_string(), + Value::String(func.description.clone().unwrap_or_default()), + ); + + // Ensure 'properties' exists and is an object + let properties = params + .entry("properties".to_string()) + .or_insert_with(|| json!({})) + .as_object_mut() + .unwrap(); + + // Insert the constant for the function name inside 'properties' + properties.insert( + "_name".to_string(), + json!({ + "type": "string", + "const": func.name.clone(), + // "description": "The name of the function" + }), + ); + + // Check if 'required' exists, and it is an array. If not, create an empty array. 
+ let required = params + .entry("required".to_string()) + .or_insert_with(|| json!([])) + .as_array_mut() + .unwrap(); + + // Add 'name' to the 'required' array if it is not already present + if !required.iter().any(|r| r == "_name") { + required.push(json!("_name")); + } + + (func.name, Value::Object(params)) + }) + .chain([( + "notify_error".to_string(), + json!({ + "properties": text_response_properties, + "required": ["error", "_name"], + "type": "object" + }), + )]) + .collect(); + + let tools = Tools { + functions_map: FunctionsMap { functions }, + properties: Properties { + function: tools_to_use + .iter() + .map(|tool| FunctionRef { + ref_path: format!("#/$functions/{}", tool.function.name.clone()), + }) + .chain(std::iter::once(FunctionRef { + ref_path: "#/$functions/notify_error".to_string(), + })) + .collect(), + }, + }; + + return Ok(Some(tools)); + } + // Err(InferError::ToolError("No tools provided".to_string())) + Ok(None) + } +} diff --git a/router/src/infer/v3/mod.rs b/router/src/infer/v3/mod.rs deleted file mode 100644 index f9effab8..00000000 --- a/router/src/infer/v3/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod block_allocator; -mod queue; -mod scheduler; - -pub(crate) use scheduler::SchedulerV3; diff --git a/router/src/server.rs b/router/src/server.rs index aa872df9..e64db7ed 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,8 +1,7 @@ /// HTTP Server logic use crate::config::Config; -use crate::infer::v2::SchedulerV2; -use crate::infer::v3::SchedulerV3; -use crate::infer::{HealthCheck, Scheduler}; +use crate::infer::schedulers::{connect_backend, SchedulerError}; +use crate::infer::Scheduler; use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar}; #[cfg(feature = "kserve")] use crate::kserve::{ @@ -39,9 +38,8 @@ use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use serde_json::Value; use std::convert::Infallible; use std::net::SocketAddr; -use std::sync::atomic::AtomicBool; use std::sync::Arc; -use text_generation_client::{v2, v3, ClientError, ShardInfo}; +use text_generation_client::ShardInfo; use thiserror::Error; use tokenizers::Tokenizer; use tokio::select; @@ -121,12 +119,10 @@ responses( example = json ! 
({"error": "unhealthy", "error_type": "healthcheck"})), ) )] -#[instrument(skip(health))] +#[instrument(skip(infer))] /// Health check method -async fn health( - mut health: Extension, -) -> Result<(), (StatusCode, Json)> { - match health.check().await { +async fn health(infer: Extension) -> Result<(), (StatusCode, Json)> { + match infer.health().await { true => Ok(()), false => Err(( StatusCode::SERVICE_UNAVAILABLE, @@ -437,8 +433,9 @@ async fn generate_stream_internal( } else { match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await { // Keep permit as long as generate_stream lives - Ok((_permit, _input_length, mut response_stream)) => { + Ok((_permit, _input_length, response_stream)) => { let mut index = 0; + let mut response_stream = Box::pin(response_stream); // Server-Sent Event stream while let Some(response) = response_stream.next().await { index += 1; @@ -1497,137 +1494,22 @@ pub async fn run( // Create state // Open connection, get model info and warmup - let (scheduler, health_ext, shard_info, max_batch_total_tokens): ( + let (scheduler, shard_info, max_batch_total_tokens): ( Arc, - HealthCheck, ShardInfo, u32, - ) = { - // Helper function to check both v2 and v3 - let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option| { - match max_supported_batch_total_tokens { - // Older models do not support automatic max-batch-total-tokens - None => { - let max_batch_total_tokens = max_batch_total_tokens.unwrap_or( - 16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)), - ); - tracing::warn!("Model does not support automatic max batch total tokens"); - Ok(max_batch_total_tokens) - } - // Flash attention models return their max supported total tokens - Some(max_supported_batch_total_tokens) => { - // Warn if user added his own max-batch-total-tokens as we will ignore it - if max_batch_total_tokens.is_some() { - tracing::warn!( - "`--max-batch-total-tokens` is deprecated for Flash \ - Attention models." 
- ); - tracing::warn!( - "Inferred max batch total tokens: {max_supported_batch_total_tokens}" - ); - } - if max_total_tokens as u32 > max_supported_batch_total_tokens { - return Err(WebServerError::NotEnoughMemory(max_total_tokens)); - } - - Ok(max_supported_batch_total_tokens) - } - } - }; - - let generation_health = Arc::new(AtomicBool::new(false)); - - match v3::ShardedClient::connect_uds(master_shard_uds_path.clone()).await { - Ok(mut sharded_client) => { - // server is running on v3 - // Clear the cache; useful if the webserver rebooted - sharded_client - .clear_cache(None) - .await - .map_err(WebServerError::Cache)?; - // Get info from the shard - let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?; - - // Warmup model - tracing::info!("Warming up model"); - let max_batch_total_tokens = check_max_batch_total_tokens( - sharded_client - .warmup( - max_input_tokens as u32, - max_batch_prefill_tokens, - max_total_tokens as u32, - max_batch_size, - ) - .await - .map_err(WebServerError::Warmup)?, - )?; - - let health_ext = - HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone()); - let scheduler = Arc::new(SchedulerV3::new( - sharded_client, - waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, - shard_info.requires_padding, - shard_info.window_size, - shard_info.speculate, - generation_health, - )); - tracing::info!("Using scheduler V3"); - - (scheduler, health_ext, shard_info, max_batch_total_tokens) - } - Err(_) => { - let mut sharded_client = v2::ShardedClient::connect_uds(master_shard_uds_path) - .await - .map_err(WebServerError::Connection)?; - - // server is running on v2 - // Clear the cache; useful if the webserver rebooted - sharded_client - .clear_cache(None) - .await - .map_err(WebServerError::Cache)?; - // Get info from the shard - let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?; - - // Warmup model - tracing::info!("Warming up model"); - let max_batch_total_tokens = check_max_batch_total_tokens( - sharded_client - .warmup( - max_input_tokens as u32, - max_batch_prefill_tokens, - max_total_tokens as u32, - max_batch_size, - ) - .await - .map_err(WebServerError::Warmup)?, - )?; - - let health_ext = - HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone()); - let scheduler = Arc::new(SchedulerV2::new( - sharded_client, - waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, - shard_info.requires_padding, - shard_info.window_size, - shard_info.speculate, - generation_health, - )); - tracing::info!("Using scheduler V2"); - - (scheduler, health_ext, shard_info, max_batch_total_tokens) - } - } - }; + ) = connect_backend( + master_shard_uds_path, + max_input_tokens, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + ) + .await + .map_err(WebServerError::Scheduler)?; tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}"); let validation = Validation::new( @@ -1857,7 +1739,6 @@ pub async fn run( // add layers after routes app = app .layer(Extension(info)) - .layer(Extension(health_ext.clone())) .layer(Extension(compat_return_full_text)) .layer(Extension(infer)) .layer(Extension(compute_type)) @@ -1933,6 +1814,7 @@ impl From for (StatusCode, Json) { InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR, InferError::TemplateError(_) => 
StatusCode::UNPROCESSABLE_ENTITY,
             InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
+            InferError::Scheduler(_) => StatusCode::INTERNAL_SERVER_ERROR,
         };
 
         (
@@ -1958,16 +1840,8 @@ impl From<InferError> for Event {
 #[derive(Debug, Error)]
 pub enum WebServerError {
-    #[error("Unable to connect to the Python model shards: {0}")]
-    Connection(ClientError),
-    #[error("Unable to clear the Python model shards cache: {0}")]
-    Cache(ClientError),
-    #[error("Unable to get the Python model shards info: {0}")]
-    Info(ClientError),
-    #[error("Unable to warmup the Python model shards: {0}")]
-    Warmup(ClientError),
-    #[error("Not enough memory to handle `max_total_tokens={0}`")]
-    NotEnoughMemory(usize),
+    #[error("Scheduler error: {0}")]
+    Scheduler(#[from] SchedulerError),
     #[error("Axum error: {0}")]
     Axum(#[from] axum::BoxError),
 }
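
The core shape of this refactor: `Scheduler::schedule` now returns a bare stream of generation results (no semaphore permit, no input length), `Scheduler::health` replaces the deleted `HealthCheck` extension, and `Infer::generate_stream` wraps the returned stream to keep a shared `backend_health` flag current. The following is a minimal, self-contained sketch of that shape, not the crate's actual code: `MockScheduler`, `Item`, and `Error` are hypothetical stand-ins for `InferStreamResponse` and `SchedulerError`.

// Sketch only. Assumed dependencies: tokio = { version = "1", features = ["full"] },
// tokio-stream = "0.1", async-trait = "0.1".
use async_trait::async_trait;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;

type Item = String;

#[derive(Debug)]
struct Error(String);

#[async_trait]
trait Scheduler: Send + Sync {
    // No semaphore permit and no input length here: the caller now owns those
    // concerns and only receives a stream of results.
    fn schedule(
        &self,
        request: String,
    ) -> Result<UnboundedReceiverStream<Result<Item, Error>>, Error>;

    // Health is delegated to the scheduler, seeded with the last observed backend state.
    async fn health(&self, current_health: bool) -> bool;
}

struct MockScheduler;

#[async_trait]
impl Scheduler for MockScheduler {
    fn schedule(
        &self,
        request: String,
    ) -> Result<UnboundedReceiverStream<Result<Item, Error>>, Error> {
        let (response_tx, response_rx) = mpsc::unbounded_channel();
        // A real scheduler appends an Entry to its queue and notifies a batching task;
        // this mock just echoes the request back token by token.
        for token in request.split_whitespace() {
            let _ = response_tx.send(Ok(token.to_string()));
        }
        Ok(UnboundedReceiverStream::new(response_rx))
    }

    async fn health(&self, current_health: bool) -> bool {
        // A real implementation would probe the backend here.
        current_health
    }
}

#[tokio::main]
async fn main() -> Result<(), Error> {
    let scheduler: Arc<dyn Scheduler> = Arc::new(MockScheduler);
    let backend_health = Arc::new(AtomicBool::new(false));

    // Mirrors the patched `Infer::generate_stream`: update the health flag as results arrive.
    let mut stream = scheduler.schedule("hello world".to_string())?;
    while let Some(generation) = stream.next().await {
        backend_health.store(generation.is_ok(), Ordering::SeqCst);
        println!("{generation:?}");
    }

    // Mirrors the patched `Infer::health`: seed the probe with the last observed state.
    let healthy = scheduler
        .health(backend_health.load(Ordering::SeqCst))
        .await;
    println!("healthy = {healthy}");
    Ok(())
}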
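As in the patched `Infer::health`, callers seed `Scheduler::health` with the last observed generation state so the implementation can pick the cheaper probe: `SchedulerV3` checks `device_health()` when generation was healthy and falls back to `model_health()` otherwise.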