refactor schedulers

This commit is contained in:
OlivierDehaene 2024-06-20 12:40:36 +02:00
parent bcb3faa1c2
commit abf56b75a4
16 changed files with 485 additions and 454 deletions

1
Cargo.lock generated
View File

@ -3602,6 +3602,7 @@ name = "text-generation-router"
version = "2.0.5-dev0" version = "2.0.5-dev0"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"async-trait",
"axum 0.7.5", "axum 0.7.5",
"axum-tracing-opentelemetry", "axum-tracing-opentelemetry",
"base64 0.22.1", "base64 0.22.1",

View File

@ -15,6 +15,7 @@ name = "text-generation-router"
path = "src/main.rs" path = "src/main.rs"
[dependencies] [dependencies]
async-trait = "^0.1"
async-stream = "0.3.5" async-stream = "0.3.5"
axum = { version = "0.7", features = ["json"] } axum = { version = "0.7", features = ["json"] }
axum-tracing-opentelemetry = "0.16" axum-tracing-opentelemetry = "0.16"

View File

@ -0,0 +1,75 @@
use crate::infer::InferError;
use crate::{ChatTemplateInputs, GrammarType, Message, MessageChunk, Text, TextMessage};
use minijinja::{Environment, ErrorKind, Template};
use minijinja_contrib::pycompat;
/// Raise a exception (custom function) used in the chat templates
pub(crate) fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
}
#[derive(Clone)]
pub(crate) struct ChatTemplate {
template: Template<'static, 'static>,
bos_token: Option<String>,
eos_token: Option<String>,
use_default_tool_template: bool,
}
impl ChatTemplate {
pub(crate) fn new(
template: String,
bos_token: Option<String>,
eos_token: Option<String>,
) -> Self {
let mut env = Box::new(Environment::new());
// enable things like .strip() or .capitalize()
env.set_unknown_method_callback(pycompat::unknown_method_callback);
let template_str = template.into_boxed_str();
env.add_function("raise_exception", raise_exception);
// check if contains the tools variable within the template
let use_default_tool_template =
!template_str.as_ref().replace(' ', "").contains("{{tools}}");
// leaking env and template_str as read-only, static resources for performance.
let template = Box::leak(env)
.template_from_str(Box::leak(template_str))
.unwrap();
Self {
template,
bos_token,
eos_token,
use_default_tool_template,
}
}
pub(crate) fn apply(
&self,
mut messages: Vec<Message>,
grammar_with_prompt: Option<(GrammarType, String)>,
) -> Result<String, InferError> {
if self.use_default_tool_template {
if let Some(last_message) = messages.last_mut() {
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
last_message.content.push(MessageChunk::Text(Text {
text: format!("\n---\n{}\n{}", tool_prompt, tools),
}));
}
}
}
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
self.template
.render(ChatTemplateInputs {
messages,
bos_token: self.bos_token.as_deref(),
eos_token: self.eos_token.as_deref(),
add_generation_prompt: true,
tools: None,
tools_prompt: None,
})
.map_err(InferError::TemplateError)
}
}

View File

@ -1,34 +0,0 @@
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use text_generation_client::Health;
#[derive(Clone)]
pub(crate) struct HealthCheck {
client: Arc<dyn Health + Send + Sync>,
generation_health: Arc<AtomicBool>,
}
impl HealthCheck {
pub(crate) fn new(
client: Arc<dyn Health + Send + Sync>,
generation_health: Arc<AtomicBool>,
) -> Self {
Self {
client,
generation_health,
}
}
pub(crate) async fn check(&mut self) -> bool {
let value = if self.generation_health.load(Ordering::SeqCst) {
// Generation is healthy, we only check that the shards can allocate on device
self.client.device_health().await
} else {
self.client.model_health().await
}
.is_ok();
// Update generation health
self.generation_health.store(value, Ordering::SeqCst);
value
}
}

View File

@ -1,37 +1,31 @@
mod health; mod chat_template;
pub(crate) mod v2; pub(crate) mod schedulers;
pub(crate) mod v3; mod tool_grammar;
pub(crate) use health::HealthCheck; pub(crate) use tool_grammar::ToolGrammar;
use crate::validation::{ValidGenerateRequest, Validation, ValidationError}; use crate::infer::chat_template::ChatTemplate;
use crate::validation::{Validation, ValidationError};
use crate::GrammarType;
use crate::{ use crate::{
ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
HubTokenizerConfig, Message, MessageChunk, PrefillToken, Text, TextMessage, Token, Message, PrefillToken, Token,
}; };
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
use futures::future::try_join_all; use futures::future::try_join_all;
use minijinja::{Environment, ErrorKind, Template}; use minijinja::ErrorKind;
use minijinja_contrib::pycompat; pub(crate) use schedulers::Scheduler;
use serde_json::{json, Map, Value}; use crate::infer::schedulers::SchedulerError;
use std::collections::HashMap; use async_stream::stream;
use futures::Stream;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc; use std::sync::Arc;
use thiserror::Error; use thiserror::Error;
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError}; use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
use tokio::time::Instant; use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use tracing::instrument; use tracing::instrument;
pub(crate) trait Scheduler {
fn schedule(
&self,
request: ValidGenerateRequest,
permit: OwnedSemaphorePermit,
) -> Result<GenerateStreamResponse, InferError>;
}
/// Inference struct /// Inference struct
#[derive(Clone)] #[derive(Clone)]
pub struct Infer { pub struct Infer {
@ -43,6 +37,8 @@ pub struct Infer {
chat_template: Option<ChatTemplate>, chat_template: Option<ChatTemplate>,
/// Inference limit /// Inference limit
limit_concurrent_requests: Arc<Semaphore>, limit_concurrent_requests: Arc<Semaphore>,
/// Backend health
backend_health: Arc<AtomicBool>,
} }
impl Infer { impl Infer {
@ -69,20 +65,31 @@ impl Infer {
// Inference limit with a semaphore // Inference limit with a semaphore
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
// Backend health
let backend_health = Arc::new(AtomicBool::new(false));
Self { Self {
validation, validation,
scheduler, scheduler,
chat_template, chat_template,
limit_concurrent_requests: semaphore, limit_concurrent_requests: semaphore,
backend_health,
} }
} }
/// Add a new request to the queue and return a stream of InferStreamResponse /// Add a new request to the queue and return a stream of InferStreamResponse
#[instrument(skip_all)] #[instrument(skip_all)]
pub(crate) async fn generate_stream( pub(crate) async fn generate_stream<'a>(
&self, &'a self,
request: GenerateRequest, request: GenerateRequest,
) -> Result<GenerateStreamResponse, InferError> { ) -> Result<
(
OwnedSemaphorePermit,
u32, // input_length
impl Stream<Item = Result<InferStreamResponse, InferError>> + 'a,
),
InferError,
> {
// Limit concurrent requests by acquiring a permit from the semaphore // Limit concurrent requests by acquiring a permit from the semaphore
let permit = self let permit = self
.clone() .clone()
@ -101,7 +108,20 @@ impl Infer {
err err
})?; })?;
self.scheduler.schedule(valid_request, permit) let input_length = valid_request.input_length;
let mut generation_stream = self
.scheduler
.schedule(valid_request)
.map_err(InferError::Scheduler)?;
let stream = stream! {
while let Some(generation) = generation_stream.next().await {
self.backend_health.store(generation.is_ok(), Ordering::SeqCst);
yield generation.map_err(InferError::GenerationError)
}
};
Ok((permit, input_length, stream))
} }
/// Tokenizer the input /// Tokenizer the input
@ -153,7 +173,7 @@ impl Infer {
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0); let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
// Create stream and keep semaphore permit as long as generate lives // Create stream and keep semaphore permit as long as generate lives
let (_permit, _input_length, mut stream) = self.generate_stream(request).await?; let (_permit, _input_length, stream) = self.generate_stream(request).await?;
// Return values // Return values
let mut result_prefill = Vec::new(); let mut result_prefill = Vec::new();
@ -163,6 +183,8 @@ impl Infer {
let mut result_start = None; let mut result_start = None;
let mut result_queued = None; let mut result_queued = None;
let mut stream = Box::pin(stream);
// Iterate on stream // Iterate on stream
while let Some(response) = stream.next().await { while let Some(response) = stream.next().await {
match response? { match response? {
@ -254,200 +276,18 @@ impl Infer {
let best_response = infer_responses.remove(max_index); let best_response = infer_responses.remove(max_index);
Ok((best_response, infer_responses)) Ok((best_response, infer_responses))
} }
}
/// Raise a exception (custom function) used in the chat templates #[instrument(skip(self))]
fn raise_exception(err_text: String) -> Result<String, minijinja::Error> { pub(crate) async fn health(&self) -> bool {
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text)) let health = self
} .scheduler
.health(self.backend_health.load(Ordering::SeqCst))
#[derive(Clone)] .await;
struct ChatTemplate { self.backend_health.store(health, Ordering::SeqCst);
template: Template<'static, 'static>, health
bos_token: Option<String>,
eos_token: Option<String>,
use_default_tool_template: bool,
}
impl ChatTemplate {
fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
let mut env = Box::new(Environment::new());
// enable things like .strip() or .capitalize()
env.set_unknown_method_callback(pycompat::unknown_method_callback);
let template_str = template.into_boxed_str();
env.add_function("raise_exception", raise_exception);
// check if contains the tools variable within the template
let use_default_tool_template =
!template_str.as_ref().replace(' ', "").contains("{{tools}}");
// leaking env and template_str as read-only, static resources for performance.
let template = Box::leak(env)
.template_from_str(Box::leak(template_str))
.unwrap();
Self {
template,
bos_token,
eos_token,
use_default_tool_template,
}
}
fn apply(
&self,
mut messages: Vec<Message>,
grammar_with_prompt: Option<(GrammarType, String)>,
) -> Result<String, InferError> {
if self.use_default_tool_template {
if let Some(last_message) = messages.last_mut() {
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
last_message.content.push(MessageChunk::Text(Text {
text: format!("\n---\n{}\n{}", tool_prompt, tools),
}));
}
}
}
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
self.template
.render(ChatTemplateInputs {
messages,
bos_token: self.bos_token.as_deref(),
eos_token: self.eos_token.as_deref(),
add_generation_prompt: true,
tools: None,
tools_prompt: None,
})
.map_err(InferError::TemplateError)
} }
} }
pub struct ToolGrammar {}
impl ToolGrammar {
pub fn apply(
tools: Option<Vec<Tool>>,
tool_choice: Option<ToolType>,
) -> Result<Option<Tools>, InferError> {
if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) {
// let tool_prompt = tool_prompt.unwrap_or_default();
let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => {
vec![req_tools
.iter()
.find(|tool| tool.function.name == *name)
.unwrap_or_else(|| panic!("Tool with name {} not found", name))
.clone()]
}
ToolType::OneOf => req_tools.to_owned(),
};
// adds the error notification function for LLM feedback if required
let mut text_response_properties = Map::new();
text_response_properties.insert(
"error".to_string(),
serde_json::json!({
"type": "string",
"description": "The error or issue to notify"
}),
);
text_response_properties.insert(
"_name".to_string(),
serde_json::json!({
"type": "string",
"const": "notify_error"
}),
);
let functions: HashMap<String, serde_json::Value> = tools_to_use
.iter()
.map(|tool| {
let func = tool.function.clone();
// Clone the existing parameters, which are expected to be a JSON object
let mut params = if let Value::Object(params) = &func.arguments {
params.clone()
} else {
Map::new()
};
// Insert the function's description at the top level, outside of properties
params.insert(
"description".to_string(),
Value::String(func.description.clone().unwrap_or_default()),
);
// Ensure 'properties' exists and is an object
let properties = params
.entry("properties".to_string())
.or_insert_with(|| json!({}))
.as_object_mut()
.unwrap();
// Insert the constant for the function name inside 'properties'
properties.insert(
"_name".to_string(),
json!({
"type": "string",
"const": func.name.clone(),
// "description": "The name of the function"
}),
);
// Check if 'required' exists, and it is an array. If not, create an empty array.
let required = params
.entry("required".to_string())
.or_insert_with(|| json!([]))
.as_array_mut()
.unwrap();
// Add 'name' to the 'required' array if it is not already present
if !required.iter().any(|r| r == "_name") {
required.push(json!("_name"));
}
(func.name, Value::Object(params))
})
.chain([(
"notify_error".to_string(),
serde_json::json!({
"properties": text_response_properties,
"required": ["error", "_name"],
"type": "object"
}),
)])
.collect();
let tools = Tools {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.chain(std::iter::once(FunctionRef {
ref_path: "#/$functions/notify_error".to_string(),
}))
.collect(),
},
};
return Ok(Some(tools));
}
// Err(InferError::ToolError("No tools provided".to_string()))
Ok(None)
}
}
/// Type alias for generation responses
pub(crate) type GenerateStreamResponse = (
OwnedSemaphorePermit,
u32, // input_length
UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
);
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct GeneratedText { pub(crate) struct GeneratedText {
pub(crate) text: String, pub(crate) text: String,
@ -491,8 +331,10 @@ pub(crate) struct InferResponse {
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum InferError { pub enum InferError {
#[error("Request failed during scheduling: {0}")]
Scheduler(SchedulerError),
#[error("Request failed during generation: {0}")] #[error("Request failed during generation: {0}")]
GenerationError(String), GenerationError(SchedulerError),
#[error("Model is overloaded")] #[error("Model is overloaded")]
Overloaded(#[from] TryAcquireError), Overloaded(#[from] TryAcquireError),
#[error("Input validation error: {0}")] #[error("Input validation error: {0}")]
@ -508,6 +350,7 @@ pub enum InferError {
impl InferError { impl InferError {
pub(crate) fn error_type(&self) -> &str { pub(crate) fn error_type(&self) -> &str {
match self { match self {
InferError::Scheduler(_) => "scheduler",
InferError::GenerationError(_) => "generation", InferError::GenerationError(_) => "generation",
InferError::Overloaded(_) => "overloaded", InferError::Overloaded(_) => "overloaded",
InferError::ValidationError(_) => "validation", InferError::ValidationError(_) => "validation",

View File

@ -0,0 +1,54 @@
mod v3;
use crate::infer::InferStreamResponse;
use crate::validation::ValidGenerateRequest;
use async_trait::async_trait;
use std::sync::Arc;
use text_generation_client::ShardInfo;
use thiserror::Error;
use tokio_stream::wrappers::UnboundedReceiverStream;
#[async_trait]
pub(crate) trait Scheduler {
fn schedule(
&self,
request: ValidGenerateRequest,
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, SchedulerError>>, SchedulerError>;
async fn health(&self, current_health: bool) -> bool;
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn connect_backend(
master_shard_uds_path: String,
max_input_tokens: usize,
max_total_tokens: usize,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
) -> Result<(Arc<dyn Scheduler + Send + Sync>, ShardInfo, u32), SchedulerError> {
v3::connect_backend(
master_shard_uds_path,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
)
.await
.map_err(|err| SchedulerError::Startup(Box::new(err)))
}
#[derive(Debug, Error)]
pub enum SchedulerError {
#[error("Startup error: {0}")]
Startup(Box<dyn std::error::Error + Send + Sync>),
#[error("Request failed during generation: {0}")]
Generation(Box<dyn std::error::Error + Send + Sync>),
#[error("Backend error: {0}")]
Backend(Box<dyn std::error::Error + Send + Sync>),
}

View File

@ -1,5 +1,5 @@
/// Batching and inference logic /// Batching and inference logic
use crate::infer::v2::queue::{Entry, Queue}; use crate::infer::schedulers::v2::queue::{Entry, Queue};
use crate::infer::{ use crate::infer::{
GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler,
}; };
@ -498,7 +498,7 @@ impl From<text_generation_client::v2::GeneratedText> for GeneratedText {
// tests // tests
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::infer::raise_exception; use crate::infer::chat_template::raise_exception;
use crate::{ChatTemplateInputs, TextMessage}; use crate::{ChatTemplateInputs, TextMessage};
use minijinja::Environment; use minijinja::Environment;

View File

@ -0,0 +1,109 @@
mod block_allocator;
mod queue;
mod scheduler;
use crate::infer::schedulers::v3::scheduler::SchedulerV3;
use crate::infer::schedulers::Scheduler;
use std::sync::Arc;
use text_generation_client::v3::ShardedClient;
use text_generation_client::{ClientError, ShardInfo};
use thiserror::Error;
#[allow(clippy::too_many_arguments)]
pub(crate) async fn connect_backend(
master_shard_uds_path: String,
max_input_tokens: usize,
max_total_tokens: usize,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
) -> Result<(Arc<dyn Scheduler + Send + Sync>, ShardInfo, u32), V3Error> {
// Helper function
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
match max_supported_batch_total_tokens {
// Older models do not support automatic max-batch-total-tokens
None => {
let max_batch_total_tokens = max_batch_total_tokens
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
tracing::warn!("Model does not support automatic max batch total tokens");
Ok(max_batch_total_tokens)
}
// Flash attention models return their max supported total tokens
Some(max_supported_batch_total_tokens) => {
// Warn if user added his own max-batch-total-tokens as we will ignore it
if max_batch_total_tokens.is_some() {
tracing::warn!(
"`--max-batch-total-tokens` is deprecated for Flash \
Attention models."
);
tracing::warn!(
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
);
}
if max_total_tokens as u32 > max_supported_batch_total_tokens {
return Err(V3Error::NotEnoughMemory(max_total_tokens));
}
Ok(max_supported_batch_total_tokens)
}
}
};
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await
.map_err(V3Error::Connection)?;
// server is running on v3
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.map_err(V3Error::Cache)?;
// Get info from the shard
let shard_info = sharded_client.info().await.map_err(V3Error::Info)?;
// Warmup model
tracing::info!("Warming up model");
let max_batch_total_tokens = check_max_batch_total_tokens(
sharded_client
.warmup(
max_input_tokens as u32,
max_batch_prefill_tokens,
max_total_tokens as u32,
max_batch_size,
)
.await
.map_err(V3Error::Warmup)?,
)?;
let scheduler = Arc::new(SchedulerV3::new(
sharded_client,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
shard_info.requires_padding,
shard_info.window_size,
shard_info.speculate,
));
tracing::info!("Using scheduler V3");
Ok((scheduler, shard_info, max_batch_total_tokens))
}
#[derive(Debug, Error)]
pub(crate) enum V3Error {
#[error("Unable to clear the Python model shards cache: {0}")]
Cache(ClientError),
#[error("Unable to connect to the Python model shards: {0}")]
Connection(ClientError),
#[error("Unable to get the Python model shards info: {0}")]
Info(ClientError),
#[error("Unable to warmup the Python model shards: {0}")]
Warmup(ClientError),
#[error("Not enough memory to handle `max_total_tokens={0}`")]
NotEnoughMemory(usize),
}

View File

@ -1,5 +1,5 @@
use crate::infer::v3::block_allocator::{BlockAllocation, BlockAllocator}; use crate::infer::schedulers::v3::block_allocator::{BlockAllocation, BlockAllocator};
use crate::infer::InferError; use crate::infer::schedulers::SchedulerError;
use crate::infer::InferStreamResponse; use crate::infer::InferStreamResponse;
use crate::validation::{ use crate::validation::{
ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters, ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
@ -22,7 +22,7 @@ pub(crate) struct Entry {
/// Request /// Request
pub request: ValidGenerateRequest, pub request: ValidGenerateRequest,
/// Response sender to communicate between the Infer struct and the batching_task /// Response sender to communicate between the Infer struct and the batching_task
pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>, pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, SchedulerError>>,
/// Span that will live as long as entry /// Span that will live as long as entry
pub span: Span, pub span: Span,
/// Temporary span used as a guard when logging inference, wait times... /// Temporary span used as a guard when logging inference, wait times...
@ -463,7 +463,7 @@ mod tests {
fn default_entry() -> ( fn default_entry() -> (
Entry, Entry,
mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>, mpsc::UnboundedReceiver<Result<InferStreamResponse, SchedulerError>>,
) { ) {
let (response_tx, receiver_tx) = mpsc::unbounded_channel(); let (response_tx, receiver_tx) = mpsc::unbounded_channel();

View File

@ -1,19 +1,16 @@
/// Batching and inference logic /// Batching and inference logic
use crate::infer::v3::queue::{Entry, Queue}; use crate::infer::schedulers::v3::queue::{Entry, Queue};
use crate::infer::{ use crate::infer::schedulers::SchedulerError;
GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler, use crate::infer::{GeneratedText, InferStreamResponse, Scheduler};
};
use crate::validation::ValidGenerateRequest; use crate::validation::ValidGenerateRequest;
use crate::{FinishReason, PrefillToken, Token}; use crate::{FinishReason, PrefillToken, Token};
use async_trait::async_trait;
use nohash_hasher::IntMap; use nohash_hasher::IntMap;
use std::sync::{ use std::sync::Arc;
atomic::{AtomicBool, Ordering},
Arc,
};
use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient}; use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient};
use text_generation_client::ClientError; use text_generation_client::{ClientError, Health};
use tokio::sync::mpsc::error::SendError; use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit}; use tokio::sync::{mpsc, Notify};
use tokio::time::Instant; use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{info_span, instrument, Instrument, Span}; use tracing::{info_span, instrument, Instrument, Span};
@ -23,6 +20,8 @@ pub(crate) struct SchedulerV3 {
queue: Queue, queue: Queue,
/// Notify batcher on queue appends /// Notify batcher on queue appends
batching_task_notifier: Arc<Notify>, batching_task_notifier: Arc<Notify>,
/// Client, used for health checks to skip the queue
client: ShardedClient,
} }
impl SchedulerV3 { impl SchedulerV3 {
@ -37,7 +36,6 @@ impl SchedulerV3 {
requires_padding: bool, requires_padding: bool,
window_size: Option<u32>, window_size: Option<u32>,
speculate: u32, speculate: u32,
generation_health: Arc<AtomicBool>,
) -> Self { ) -> Self {
let queue = Queue::new( let queue = Queue::new(
requires_padding, requires_padding,
@ -50,7 +48,7 @@ impl SchedulerV3 {
// Spawn batching background task that contains all the inference logic // Spawn batching background task that contains all the inference logic
tokio::spawn(batching_task( tokio::spawn(batching_task(
client, client.clone(),
waiting_served_ratio, waiting_served_ratio,
max_batch_prefill_tokens, max_batch_prefill_tokens,
max_batch_total_tokens, max_batch_total_tokens,
@ -58,26 +56,26 @@ impl SchedulerV3 {
max_batch_size, max_batch_size,
queue.clone(), queue.clone(),
batching_task_notifier.clone(), batching_task_notifier.clone(),
generation_health,
)); ));
Self { Self {
queue, queue,
batching_task_notifier, batching_task_notifier,
client,
} }
} }
} }
#[async_trait]
impl Scheduler for SchedulerV3 { impl Scheduler for SchedulerV3 {
#[instrument(skip_all)] #[instrument(skip_all)]
fn schedule( fn schedule(
&self, &self,
request: ValidGenerateRequest, request: ValidGenerateRequest,
permit: OwnedSemaphorePermit, ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, SchedulerError>>, SchedulerError>
) -> Result<GenerateStreamResponse, InferError> { {
// MPSC channel to communicate with the background batching task // MPSC channel to communicate with the background batching task
let (response_tx, response_rx) = mpsc::unbounded_channel(); let (response_tx, response_rx) = mpsc::unbounded_channel();
let input_length = request.input_length;
// Append the request to the queue // Append the request to the queue
self.queue.append(Entry { self.queue.append(Entry {
@ -95,11 +93,17 @@ impl Scheduler for SchedulerV3 {
self.batching_task_notifier.notify_one(); self.batching_task_notifier.notify_one();
// Return stream // Return stream
Ok(( Ok(UnboundedReceiverStream::new(response_rx))
permit, }
input_length,
UnboundedReceiverStream::new(response_rx), async fn health(&self, current_health: bool) -> bool {
)) if current_health {
// Generation is healthy, we only check that the shards can allocate on device
self.client.device_health().await
} else {
self.client.model_health().await
}
.is_ok()
} }
} }
@ -117,7 +121,6 @@ pub(crate) async fn batching_task(
max_batch_size: Option<usize>, max_batch_size: Option<usize>,
queue: Queue, queue: Queue,
notifier: Arc<Notify>, notifier: Arc<Notify>,
generation_health: Arc<AtomicBool>,
) { ) {
// Infinite loop // Infinite loop
loop { loop {
@ -136,7 +139,7 @@ pub(crate) async fn batching_task(
) )
.await .await
{ {
let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health) let mut cached_batch = prefill(&mut client, batch, &mut entries)
.instrument(span) .instrument(span)
.await; .await;
let mut waiting_tokens = 1; let mut waiting_tokens = 1;
@ -187,8 +190,7 @@ pub(crate) async fn batching_task(
}); });
// Generate one token for this new batch to have the attention past in cache // Generate one token for this new batch to have the attention past in cache
let new_cached_batch = let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
prefill(&mut client, new_batch, &mut new_entries, &generation_health)
.instrument(span) .instrument(span)
.await; .await;
// Reset waiting counter // Reset waiting counter
@ -214,7 +216,7 @@ pub(crate) async fn batching_task(
entry.temp_span = Some(entry_batch_span); entry.temp_span = Some(entry_batch_span);
}); });
cached_batch = decode(&mut client, batches, &mut entries, &generation_health) cached_batch = decode(&mut client, batches, &mut entries)
.instrument(next_batch_span) .instrument(next_batch_span)
.await; .await;
waiting_tokens += 1; waiting_tokens += 1;
@ -230,7 +232,6 @@ async fn prefill(
client: &mut ShardedClient, client: &mut ShardedClient,
batch: Batch, batch: Batch,
entries: &mut IntMap<u64, Entry>, entries: &mut IntMap<u64, Entry>,
generation_health: &Arc<AtomicBool>,
) -> Option<CachedBatch> { ) -> Option<CachedBatch> {
let start_time = Instant::now(); let start_time = Instant::now();
let batch_id = batch.id; let batch_id = batch.id;
@ -238,9 +239,6 @@ async fn prefill(
match client.prefill(batch).await { match client.prefill(batch).await {
Ok((generations, next_batch, timings)) => { Ok((generations, next_batch, timings)) => {
// Update health
generation_health.store(true, Ordering::SeqCst);
let start_filtering_time = Instant::now(); let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries // Send generated tokens and filter stopped entries
filter_send_generations(generations, entries); filter_send_generations(generations, entries);
@ -257,8 +255,6 @@ async fn prefill(
} }
// If we have an error, we discard the whole batch // If we have an error, we discard the whole batch
Err(err) => { Err(err) => {
// Update health
generation_health.store(false, Ordering::SeqCst);
let _ = client.clear_cache(Some(batch_id)).await; let _ = client.clear_cache(Some(batch_id)).await;
send_errors(err, entries); send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill"); metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
@ -272,7 +268,6 @@ async fn decode(
client: &mut ShardedClient, client: &mut ShardedClient,
batches: Vec<CachedBatch>, batches: Vec<CachedBatch>,
entries: &mut IntMap<u64, Entry>, entries: &mut IntMap<u64, Entry>,
generation_health: &Arc<AtomicBool>,
) -> Option<CachedBatch> { ) -> Option<CachedBatch> {
let start_time = Instant::now(); let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect(); let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
@ -280,9 +275,6 @@ async fn decode(
match client.decode(batches).await { match client.decode(batches).await {
Ok((generations, next_batch, timings)) => { Ok((generations, next_batch, timings)) => {
// Update health
generation_health.store(true, Ordering::SeqCst);
let start_filtering_time = Instant::now(); let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries // Send generated tokens and filter stopped entries
filter_send_generations(generations, entries); filter_send_generations(generations, entries);
@ -302,7 +294,6 @@ async fn decode(
} }
// If we have an error, we discard the whole batch // If we have an error, we discard the whole batch
Err(err) => { Err(err) => {
generation_health.store(false, Ordering::SeqCst);
for id in batch_ids { for id in batch_ids {
let _ = client.clear_cache(Some(id)).await; let _ = client.clear_cache(Some(id)).await;
} }
@ -378,7 +369,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
fn send_responses( fn send_responses(
generation: Generation, generation: Generation,
entry: &Entry, entry: &Entry,
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> { ) -> Result<bool, Box<SendError<Result<InferStreamResponse, SchedulerError>>>> {
// Return directly if the channel is disconnected // Return directly if the channel is disconnected
if entry.response_tx.is_closed() { if entry.response_tx.is_closed() {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
@ -471,7 +462,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
entries.drain().for_each(|(_, entry)| { entries.drain().for_each(|(_, entry)| {
// Create and enter a span to link this function back to the entry // Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string()); let err = SchedulerError::Generation(Box::new(error.clone()));
metrics::increment_counter!("tgi_request_failure", "err" => "generation"); metrics::increment_counter!("tgi_request_failure", "err" => "generation");
tracing::error!("{err}"); tracing::error!("{err}");
@ -505,7 +496,7 @@ impl From<text_generation_client::v3::GeneratedText> for GeneratedText {
// tests // tests
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::infer::raise_exception; use crate::infer::chat_template::raise_exception;
use crate::{ChatTemplateInputs, TextMessage}; use crate::{ChatTemplateInputs, TextMessage};
use minijinja::Environment; use minijinja::Environment;

View File

@ -0,0 +1,122 @@
use crate::infer::InferError;
use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolType, Tools};
use serde_json::{json, Map, Value};
use std::collections::HashMap;
pub(crate) struct ToolGrammar {}
impl ToolGrammar {
pub fn apply(
tools: Option<Vec<Tool>>,
tool_choice: Option<ToolType>,
) -> Result<Option<Tools>, InferError> {
if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) {
// let tool_prompt = tool_prompt.unwrap_or_default();
let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => {
vec![req_tools
.iter()
.find(|tool| tool.function.name == *name)
.unwrap_or_else(|| panic!("Tool with name {} not found", name))
.clone()]
}
ToolType::OneOf => req_tools.to_owned(),
};
// adds the error notification function for LLM feedback if required
let mut text_response_properties = Map::new();
text_response_properties.insert(
"error".to_string(),
json!({
"type": "string",
"description": "The error or issue to notify"
}),
);
text_response_properties.insert(
"_name".to_string(),
json!({
"type": "string",
"const": "notify_error"
}),
);
let functions: HashMap<String, Value> = tools_to_use
.iter()
.map(|tool| {
let func = tool.function.clone();
// Clone the existing parameters, which are expected to be a JSON object
let mut params = if let Value::Object(params) = &func.arguments {
params.clone()
} else {
Map::new()
};
// Insert the function's description at the top level, outside of properties
params.insert(
"description".to_string(),
Value::String(func.description.clone().unwrap_or_default()),
);
// Ensure 'properties' exists and is an object
let properties = params
.entry("properties".to_string())
.or_insert_with(|| json!({}))
.as_object_mut()
.unwrap();
// Insert the constant for the function name inside 'properties'
properties.insert(
"_name".to_string(),
json!({
"type": "string",
"const": func.name.clone(),
// "description": "The name of the function"
}),
);
// Check if 'required' exists, and it is an array. If not, create an empty array.
let required = params
.entry("required".to_string())
.or_insert_with(|| json!([]))
.as_array_mut()
.unwrap();
// Add 'name' to the 'required' array if it is not already present
if !required.iter().any(|r| r == "_name") {
required.push(json!("_name"));
}
(func.name, Value::Object(params))
})
.chain([(
"notify_error".to_string(),
json!({
"properties": text_response_properties,
"required": ["error", "_name"],
"type": "object"
}),
)])
.collect();
let tools = Tools {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.chain(std::iter::once(FunctionRef {
ref_path: "#/$functions/notify_error".to_string(),
}))
.collect(),
},
};
return Ok(Some(tools));
}
// Err(InferError::ToolError("No tools provided".to_string()))
Ok(None)
}
}

View File

@ -1,5 +0,0 @@
mod block_allocator;
mod queue;
mod scheduler;
pub(crate) use scheduler::SchedulerV3;

View File

@ -1,8 +1,7 @@
/// HTTP Server logic /// HTTP Server logic
use crate::config::Config; use crate::config::Config;
use crate::infer::v2::SchedulerV2; use crate::infer::schedulers::{connect_backend, SchedulerError};
use crate::infer::v3::SchedulerV3; use crate::infer::Scheduler;
use crate::infer::{HealthCheck, Scheduler};
use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar}; use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar};
#[cfg(feature = "kserve")] #[cfg(feature = "kserve")]
use crate::kserve::{ use crate::kserve::{
@ -39,9 +38,8 @@ use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use serde_json::Value; use serde_json::Value;
use std::convert::Infallible; use std::convert::Infallible;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::atomic::AtomicBool;
use std::sync::Arc; use std::sync::Arc;
use text_generation_client::{v2, v3, ClientError, ShardInfo}; use text_generation_client::ShardInfo;
use thiserror::Error; use thiserror::Error;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use tokio::select; use tokio::select;
@ -121,12 +119,10 @@ responses(
example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})), example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
) )
)] )]
#[instrument(skip(health))] #[instrument(skip(infer))]
/// Health check method /// Health check method
async fn health( async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
mut health: Extension<HealthCheck>, match infer.health().await {
) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
match health.check().await {
true => Ok(()), true => Ok(()),
false => Err(( false => Err((
StatusCode::SERVICE_UNAVAILABLE, StatusCode::SERVICE_UNAVAILABLE,
@ -437,8 +433,9 @@ async fn generate_stream_internal(
} else { } else {
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await { match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
// Keep permit as long as generate_stream lives // Keep permit as long as generate_stream lives
Ok((_permit, _input_length, mut response_stream)) => { Ok((_permit, _input_length, response_stream)) => {
let mut index = 0; let mut index = 0;
let mut response_stream = Box::pin(response_stream);
// Server-Sent Event stream // Server-Sent Event stream
while let Some(response) = response_stream.next().await { while let Some(response) = response_stream.next().await {
index += 1; index += 1;
@ -1497,137 +1494,22 @@ pub async fn run(
// Create state // Create state
// Open connection, get model info and warmup // Open connection, get model info and warmup
let (scheduler, health_ext, shard_info, max_batch_total_tokens): ( let (scheduler, shard_info, max_batch_total_tokens): (
Arc<dyn Scheduler + Send + Sync>, Arc<dyn Scheduler + Send + Sync>,
HealthCheck,
ShardInfo, ShardInfo,
u32, u32,
) = { ) = connect_backend(
// Helper function to check both v2 and v3 master_shard_uds_path,
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| { max_input_tokens,
match max_supported_batch_total_tokens { max_total_tokens,
// Older models do not support automatic max-batch-total-tokens
None => {
let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
);
tracing::warn!("Model does not support automatic max batch total tokens");
Ok(max_batch_total_tokens)
}
// Flash attention models return their max supported total tokens
Some(max_supported_batch_total_tokens) => {
// Warn if user added his own max-batch-total-tokens as we will ignore it
if max_batch_total_tokens.is_some() {
tracing::warn!(
"`--max-batch-total-tokens` is deprecated for Flash \
Attention models."
);
tracing::warn!(
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
);
}
if max_total_tokens as u32 > max_supported_batch_total_tokens {
return Err(WebServerError::NotEnoughMemory(max_total_tokens));
}
Ok(max_supported_batch_total_tokens)
}
}
};
let generation_health = Arc::new(AtomicBool::new(false));
match v3::ShardedClient::connect_uds(master_shard_uds_path.clone()).await {
Ok(mut sharded_client) => {
// server is running on v3
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.map_err(WebServerError::Cache)?;
// Get info from the shard
let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?;
// Warmup model
tracing::info!("Warming up model");
let max_batch_total_tokens = check_max_batch_total_tokens(
sharded_client
.warmup(
max_input_tokens as u32,
max_batch_prefill_tokens,
max_total_tokens as u32,
max_batch_size,
)
.await
.map_err(WebServerError::Warmup)?,
)?;
let health_ext =
HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone());
let scheduler = Arc::new(SchedulerV3::new(
sharded_client,
waiting_served_ratio, waiting_served_ratio,
max_batch_prefill_tokens, max_batch_prefill_tokens,
max_batch_total_tokens, max_batch_total_tokens,
max_waiting_tokens, max_waiting_tokens,
max_batch_size, max_batch_size,
shard_info.requires_padding,
shard_info.window_size,
shard_info.speculate,
generation_health,
));
tracing::info!("Using scheduler V3");
(scheduler, health_ext, shard_info, max_batch_total_tokens)
}
Err(_) => {
let mut sharded_client = v2::ShardedClient::connect_uds(master_shard_uds_path)
.await
.map_err(WebServerError::Connection)?;
// server is running on v2
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.map_err(WebServerError::Cache)?;
// Get info from the shard
let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?;
// Warmup model
tracing::info!("Warming up model");
let max_batch_total_tokens = check_max_batch_total_tokens(
sharded_client
.warmup(
max_input_tokens as u32,
max_batch_prefill_tokens,
max_total_tokens as u32,
max_batch_size,
) )
.await .await
.map_err(WebServerError::Warmup)?, .map_err(WebServerError::Scheduler)?;
)?;
let health_ext =
HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone());
let scheduler = Arc::new(SchedulerV2::new(
sharded_client,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
shard_info.requires_padding,
shard_info.window_size,
shard_info.speculate,
generation_health,
));
tracing::info!("Using scheduler V2");
(scheduler, health_ext, shard_info, max_batch_total_tokens)
}
}
};
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}"); tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
let validation = Validation::new( let validation = Validation::new(
@ -1857,7 +1739,6 @@ pub async fn run(
// add layers after routes // add layers after routes
app = app app = app
.layer(Extension(info)) .layer(Extension(info))
.layer(Extension(health_ext.clone()))
.layer(Extension(compat_return_full_text)) .layer(Extension(compat_return_full_text))
.layer(Extension(infer)) .layer(Extension(infer))
.layer(Extension(compute_type)) .layer(Extension(compute_type))
@ -1933,6 +1814,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR, InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::Scheduler(_) => StatusCode::INTERNAL_SERVER_ERROR,
}; };
( (
@ -1958,16 +1840,8 @@ impl From<InferError> for Event {
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum WebServerError { pub enum WebServerError {
#[error("Unable to connect to the Python model shards: {0}")] #[error("Scheduler error: {0}")]
Connection(ClientError), Scheduler(#[from] SchedulerError),
#[error("Unable to clear the Python model shards cache: {0}")]
Cache(ClientError),
#[error("Unable to get the Python model shards info: {0}")]
Info(ClientError),
#[error("Unable to warmup the Python model shards: {0}")]
Warmup(ClientError),
#[error("Not enough memory to handle `max_total_tokens={0}`")]
NotEnoughMemory(usize),
#[error("Axum error: {0}")] #[error("Axum error: {0}")]
Axum(#[from] axum::BoxError), Axum(#[from] axum::BoxError),
} }