This commit is contained in:
OlivierDehaene 2024-06-26 13:13:32 +02:00
parent 504754861f
commit b562680be4
9 changed files with 687 additions and 224 deletions

1
Cargo.lock generated
View File

@ -3602,6 +3602,7 @@ name = "text-generation-router"
version = "2.0.5-dev0"
dependencies = [
"async-stream",
"async-trait",
"axum 0.7.5",
"axum-tracing-opentelemetry",
"base64 0.22.1",

View File

@ -15,6 +15,7 @@ name = "text-generation-router"
path = "src/main.rs"
[dependencies]
async-trait = "0.1.74"
async-stream = "0.3.5"
axum = { version = "0.7", features = ["json"] }
axum-tracing-opentelemetry = "0.16"

View File

@ -1,24 +1,18 @@
mod health;
pub(crate) mod v2;
// pub(crate) mod v2;
pub(crate) mod v3;
mod chat_template;
mod tool_grammar;
pub(crate) use health::HealthCheck;
pub mod tool_grammar;
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
use crate::{
ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
HubTokenizerConfig, Message, MessageChunk, PrefillToken, Text, TextMessage, Token,
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
HubTokenizerConfig, Message, PrefillToken, Token,
};
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
use crate::{GrammarType};
use futures::future::try_join_all;
use minijinja::{Environment, ErrorKind, Template};
use minijinja_contrib::pycompat;
use serde_json::{json, Map, Value};
use std::collections::HashMap;
use minijinja::{ErrorKind};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use thiserror::Error;
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
use tokio::time::Instant;
@ -26,13 +20,16 @@ use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
use tracing::instrument;
use chat_template::ChatTemplate;
use async_trait::async_trait;
pub(crate) trait Scheduler {
#[async_trait]
pub(crate) trait Backend {
fn schedule(
&self,
request: ValidGenerateRequest,
permit: OwnedSemaphorePermit,
) -> Result<GenerateStreamResponse, InferError>;
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError>;
async fn health(&self, current_health: bool) -> bool;
}
/// Inference struct
@ -90,15 +87,7 @@ impl Infer {
#[instrument(skip_all)]
pub(crate) async fn generate_stream<'a>(
&'a self,
request: GenerateRequest,
) -> Result<
(
OwnedSemaphorePermit,
u32, // input_length
impl Stream<Item=Result<InferStreamResponse, InferError>> + 'a,
),
InferError,
> {
request: GenerateRequest) -> Result<GenerateStreamResponse, InferError> {
// Limit concurrent requests by acquiring a permit from the semaphore
let permit = self
.clone()
@ -118,35 +107,11 @@ impl Infer {
})?;
let input_length = valid_request.input_length;
let mut generation_stream = self
let generation_stream = self
.backend
.schedule(valid_request)
.map_err(InferError::Backend)?;
.schedule(valid_request)?;
let stream = stream! {
while let Some(generation) = generation_stream.next().await {
self.backend_health.store(generation.is_ok(), Ordering::SeqCst);
yield generation.map(|generation| match generation {
types::TokenStreamResponse::Prefill(prefill_tokens) => InferStreamResponse::Prefill(
prefill_tokens.into_iter().map(PrefillToken::from).collect()
),
types::TokenStreamResponse::Intermediate { token, top_tokens } => InferStreamResponse::Intermediate {
token: Token::from(token),
top_tokens: top_tokens.into_iter().map(Token::from).collect(),
},
types::TokenStreamResponse::End { token, top_tokens, generated_text, start, queued } => InferStreamResponse::End {
token: Token::from(token),
top_tokens: top_tokens.into_iter().map(Token::from).collect(),
generated_text,
start,
queued,
}
}).map_err(InferError::GenerationError)
}
};
Ok((permit, input_length, stream))
Ok((permit, input_length, generation_stream))
}
/// Tokenizer the input
@ -363,10 +328,8 @@ pub(crate) struct InferResponse {
#[derive(Debug, Error)]
pub enum InferError {
#[error("Request failed during scheduling: {0}")]
Backend(BackendError),
#[error("Request failed during generation: {0}")]
GenerationError(BackendError),
GenerationError(String),
#[error("Model is overloaded")]
Overloaded(#[from] TryAcquireError),
#[error("Input validation error: {0}")]
@ -382,7 +345,6 @@ pub enum InferError {
impl InferError {
pub(crate) fn error_type(&self) -> &str {
match self {
InferError::Backend(_) => "backend",
InferError::GenerationError(_) => "generation",
InferError::Overloaded(_) => "overloaded",
InferError::ValidationError(_) => "validation",

View File

@ -1,4 +1,4 @@
mod queue;
mod scheduler;
pub(crate) use scheduler::SchedulerV2;
pub(crate) use scheduler::BackendV2;

View File

@ -1,7 +1,7 @@
/// Batching and inference logic
use crate::infer::v2::queue::{Entry, Queue};
use crate::infer::{
GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler,
GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Backend,
};
use crate::validation::ValidGenerateRequest;
use crate::{FinishReason, PrefillToken, Token};
@ -18,14 +18,14 @@ use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{info_span, instrument, Instrument, Span};
pub(crate) struct SchedulerV2 {
pub(crate) struct BackendV2 {
/// Request queue
queue: Queue,
/// Notify batcher on queue appends
batching_task_notifier: Arc<Notify>,
}
impl SchedulerV2 {
impl BackendV2 {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
client: ShardedClient,
@ -62,7 +62,7 @@ impl SchedulerV2 {
}
}
impl Scheduler for SchedulerV2 {
impl Backend for BackendV2 {
#[instrument(skip_all)]
fn schedule(
&self,

View File

@ -0,0 +1,500 @@
/// Batching and inference logic
use crate::infer::v3::queue::{Entry, Queue};
use crate::infer::{
GeneratedText, InferError, InferStreamResponse, Backend,
};
use crate::validation::ValidGenerateRequest;
use crate::{FinishReason, PrefillToken, Token};
use nohash_hasher::IntMap;
use std::sync::{
Arc,
};
use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient};
use text_generation_client::{ClientError, Health};
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{info_span, instrument, Instrument, Span};
use async_trait::async_trait;
pub(crate) struct BackendV3 {
/// Request queue
queue: Queue,
/// Notify batcher on queue appends
batching_task_notifier: Arc<Notify>,
/// Client clone, used for health checks to skip the queue
client: ShardedClient,
}
impl BackendV3 {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
client: ShardedClient,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
requires_padding: bool,
window_size: Option<u32>,
speculate: u32,
) -> Self {
let queue = Queue::new(
requires_padding,
16,
window_size,
speculate,
max_batch_total_tokens,
);
let batching_task_notifier = Arc::new(Notify::new());
// Spawn batching background task that contains all the inference logic
tokio::spawn(batching_task(
client.clone(),
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
queue.clone(),
batching_task_notifier.clone(),
));
Self {
queue,
batching_task_notifier,
client,
}
}
}
#[async_trait]
impl Backend for BackendV3 {
#[instrument(skip_all)]
fn schedule(
&self,
request: ValidGenerateRequest,
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
// MPSC channel to communicate with the background batching task
let (response_tx, response_rx) = mpsc::unbounded_channel();
let input_length = request.input_length;
// Append the request to the queue
self.queue.append(Entry {
request,
response_tx,
span: Span::current(),
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
block_allocation: None,
});
// Notify the background task that we have a new entry in the queue that needs
// to be batched
self.batching_task_notifier.notify_one();
// Return stream
Ok(
UnboundedReceiverStream::new(response_rx),
)
}
async fn health(&self, current_health: bool) -> bool {
if current_health {
// Generation is healthy, we only check that the shards can allocate on device
self.client.device_health().await
} else {
self.client.model_health().await
}
.is_ok()
}
}
/// Batching logic
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
#[allow(clippy::too_many_arguments)]
pub(crate) async fn batching_task(
mut client: ShardedClient,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
queue: Queue,
notifier: Arc<Notify>,
) {
// Infinite loop
loop {
// Wait for a notification from the Infer struct
notifier.notified().await;
// Get the next batch from the queue
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue
while let Some((mut entries, batch, span)) = queue
.next_batch(
None,
max_batch_size,
max_batch_prefill_tokens,
max_batch_total_tokens,
)
.await
{
let mut cached_batch = prefill(&mut client, batch, &mut entries)
.instrument(span)
.await;
let mut waiting_tokens = 1;
// We loop until we do not receive any cached batch from the inference server (== until
// all requests have met their stopping criteria)
while let Some(batch) = cached_batch {
// Get current batch info
let batch_size = batch.size;
let batch_max_tokens = batch.max_tokens;
let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64);
let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
// to add a new batch even though its size might be small
None
} else {
// Minimum batch size
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
};
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
// Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
.await
{
// Tracking metrics
if min_size.is_some() {
metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
} else {
metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
}
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to add the info that this entry is waiting
// because a new batch is being computed
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
// Add relationships
span.follows_from(&entry_waiting_span);
entry_waiting_span.follows_from(&span);
// Update entry
entry.temp_span = Some(entry_waiting_span);
});
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch =
prefill(&mut client, new_batch, &mut new_entries)
.instrument(span)
.await;
// Reset waiting counter
waiting_tokens = 1;
// Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch {
entries.extend(new_entries);
batches.push(new_cached_batch);
}
}
// Create span for this batch to add context to inference calls
let next_batch_size = entries.len();
let next_batch_span =
info_span!(parent: None, "batch", batch_size = next_batch_size);
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer");
// Add relationships
next_batch_span.follows_from(&entry_batch_span);
entry_batch_span.follows_from(&next_batch_span);
// Update entry
entry.temp_span = Some(entry_batch_span);
});
cached_batch = decode(&mut client, batches, &mut entries)
.instrument(next_batch_span)
.await;
waiting_tokens += 1;
}
metrics::gauge!("tgi_batch_current_size", 0.0);
metrics::gauge!("tgi_batch_current_max_tokens", 0.0);
}
}
}
#[instrument(skip_all)]
async fn prefill(
client: &mut ShardedClient,
batch: Batch,
entries: &mut IntMap<u64, Entry>,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_id = batch.id;
metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
match client.prefill(batch).await {
Ok((generations, next_batch, timings)) => {
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
let _ = client.clear_cache(Some(batch_id)).await;
send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
None
}
}
}
#[instrument(skip_all)]
async fn decode(
client: &mut ShardedClient,
batches: Vec<CachedBatch>,
entries: &mut IntMap<u64, Entry>,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
match client.decode(batches).await {
Ok((generations, next_batch, timings)) => {
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
}
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
for id in batch_ids {
let _ = client.clear_cache(Some(id)).await;
}
send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
None
}
}
}
/// Filter a `batch` and remove all requests not present in `entries`
#[instrument(skip_all)]
async fn filter_batch(
client: &mut ShardedClient,
next_batch: Option<CachedBatch>,
entries: &IntMap<u64, Entry>,
) -> Option<CachedBatch> {
let mut batch = next_batch?;
// No need to filter
if batch.size as usize == entries.len() {
return Some(batch);
}
let id = batch.id;
// Retain only requests that are still in entries
batch.request_ids.retain(|id| entries.contains_key(id));
if batch.request_ids.is_empty() {
// All requests have been filtered out
// Next batch is now empty
// Clear it from the Python shards cache
// We unwrap here as we need to panic since we cannot recover if this method fails
client.clear_cache(Some(id)).await.unwrap();
None
} else {
// Filter Python shard cache
// We unwrap here as we need to panic since we cannot recover if this method fails
client.filter_batch(id, batch.request_ids).await.unwrap()
}
}
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
/// and filter entries
#[instrument(skip_all)]
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
generations.into_iter().for_each(|generation| {
let id = generation.request_id;
// Get entry
// We can `expect` here as the request id should always be in the entries
let entry = entries
.get(&id)
.expect("ID not found in entries. This is a bug.");
// Create and enter a span to link this function back to the entry
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
// Send generation responses back to the infer task
// If the receive an error from the Flume channel, it means that the client dropped the
// request and we need to stop generating hence why we unwrap_or(true)
let stopped = send_responses(generation, entry).map_err(|err| {
tracing::error!("Entry response channel error.");
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
err
}).unwrap_or(true);
if stopped {
entries.remove(&id).expect("ID not found in entries. This is a bug.");
}
});
}
/// Send responses through the `entry` response channel
fn send_responses(
generation: Generation,
entry: &Entry,
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
// Return directly if the channel is disconnected
if entry.response_tx.is_closed() {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
return Ok(true);
}
let mut stopped = false;
if let Some(prefill_tokens) = generation.prefill_tokens {
// Create Token objects
// We do that here instead of in the Python code as Rust for loops are faster
let prefill_tokens = prefill_tokens
.ids
.into_iter()
.zip(prefill_tokens.logprobs)
.zip(prefill_tokens.texts)
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
.collect();
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
}
// Create last Token
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len();
metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
let mut iterator = tokens_
.ids
.into_iter()
.zip(tokens_.logprobs)
.zip(tokens_.texts)
.zip(tokens_.is_special)
.enumerate()
.peekable();
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
let token = Token {
id,
text,
logprob,
special,
};
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
top_tokens_
.ids
.iter()
.zip(top_tokens_.logprobs.iter())
.zip(top_tokens_.texts.iter())
.zip(top_tokens_.is_special.iter())
.map(|(((&id, &logprob), text), &special)| Token {
id,
text: text.to_string(),
logprob,
special,
})
.collect()
} else {
vec![]
};
match (&generation.generated_text, iterator.peek()) {
(Some(generated_text), None) => {
// Generation has ended
stopped = true;
// Send message
entry.response_tx.send(Ok(InferStreamResponse::End {
token,
top_tokens,
generated_text: GeneratedText::from(generated_text.clone()),
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
}))?;
}
_ => {
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
}
}
}
Ok(stopped)
}
/// Send errors to Infer for all `entries`
#[instrument(skip_all)]
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
entries.drain().for_each(|(_, entry)| {
// Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string());
metrics::increment_counter!("tgi_request_failure", "err" => "generation");
tracing::error!("{err}");
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Err(err))
.unwrap_or(());
});
}
impl From<text_generation_client::v3::GeneratedText> for GeneratedText {
fn from(value: text_generation_client::v3::GeneratedText) -> Self {
let v3_finish_reason =
text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap();
let finish_reason = match v3_finish_reason {
text_generation_client::v3::FinishReason::Length => FinishReason::Length,
text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence,
};
Self {
text: value.text,
generated_tokens: value.generated_tokens,
finish_reason,
seed: value.seed,
}
}
}

View File

@ -1,5 +1,141 @@
mod block_allocator;
mod queue;
mod scheduler;
mod backend;
pub(crate) use scheduler::SchedulerV3;
use futures_util::TryFutureExt;
use serde::Serialize;
use thiserror::Error;
use utoipa::ToSchema;
pub(crate) use backend::BackendV3;
use text_generation_client::ClientError;
use text_generation_client::v3::ShardedClient;
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct BackendInfo {
/// Mandatory
#[schema(example = "cuda")]
pub model_device_type: String,
#[schema(example = "torch.float16")]
pub model_dtype: String,
/// Backend parameters
#[schema(example = "1")]
pub speculate: usize,
#[schema(example = "1.2")]
pub waiting_served_ratio: f32,
#[schema(example = "32000")]
pub max_batch_total_tokens: u32,
#[schema(example = "20")]
pub max_waiting_tokens: usize,
#[schema(nullable = true, example = "null")]
pub max_batch_size: Option<usize>,
}
pub async fn connect_backend(
max_input_tokens: usize, max_total_tokens: usize,
master_shard_uds_path: String,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
) -> Result<(BackendV3, BackendInfo), V3Error> {
// Helper function
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
match max_supported_batch_total_tokens {
// Older models do not support automatic max-batch-total-tokens
None => {
let max_batch_total_tokens = max_batch_total_tokens
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
tracing::warn!("Model does not support automatic max batch total tokens");
Ok(max_batch_total_tokens)
}
// Flash attention models return their max supported total tokens
Some(max_supported_batch_total_tokens) => {
// Warn if user added his own max-batch-total-tokens as we will ignore it
if max_batch_total_tokens.is_some() {
tracing::warn!(
"`--max-batch-total-tokens` is deprecated for Flash \
Attention models."
);
tracing::warn!(
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
);
}
if max_total_tokens as u32 > max_supported_batch_total_tokens {
return Err(V3Error::NotEnoughMemory(max_total_tokens));
}
Ok(max_supported_batch_total_tokens)
}
}
};
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await
.map_err(V3Error::Connection)?;
// server is running on v3
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.map_err(V3Error::Cache)?;
// Get info from the shard
let shard_info = sharded_client.info().await.map_err(V3Error::Info)?;
// Warmup model
tracing::info!("Warming up model");
let max_batch_total_tokens = check_max_batch_total_tokens(
sharded_client
.warmup(
max_input_tokens as u32,
max_batch_prefill_tokens,
max_total_tokens as u32,
max_batch_size,
)
.await
.map_err(V3Error::Warmup)?,
)?;
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
let backend_info = BackendInfo {
waiting_served_ratio,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
model_device_type: shard_info.device_type.clone(),
model_dtype: shard_info.dtype.clone(),
speculate: shard_info.speculate as usize,
};
let backend = BackendV3::new(
sharded_client,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
shard_info.requires_padding,
shard_info.window_size,
shard_info.speculate,
);
tracing::info!("Using backend V3");
Ok((backend, backend_info))
}
#[derive(Debug, Error)]
pub(crate) enum V3Error {
#[error("Unable to clear the Python model shards cache: {0}")]
Cache(ClientError),
#[error("Unable to connect to the Python model shards: {0}")]
Connection(ClientError),
#[error("Unable to get the Python model shards info: {0}")]
Info(ClientError),
#[error("Unable to warmup the Python model shards: {0}")]
Warmup(ClientError),
#[error("Not enough memory to handle `max_total_tokens={0}`")]
NotEnoughMemory(usize),
}

View File

@ -135,12 +135,13 @@ pub struct Info {
pub model_id: String,
#[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
pub model_sha: Option<String>,
#[schema(example = "torch.float16")]
pub model_dtype: String,
#[schema(example = "cuda")]
pub model_device_type: String,
// #[schema(example = "torch.float16")]
// pub model_dtype: String,
// #[schema(example = "cuda")]
// pub model_device_type: String,
#[schema(nullable = true, example = "text-generation")]
pub model_pipeline_tag: Option<String>,
/// Router Parameters
#[schema(example = "128")]
pub max_concurrent_requests: usize,
@ -152,18 +153,12 @@ pub struct Info {
pub max_input_tokens: usize,
#[schema(example = "2048")]
pub max_total_tokens: usize,
#[schema(example = "1.2")]
pub waiting_served_ratio: f32,
#[schema(example = "32000")]
pub max_batch_total_tokens: u32,
#[schema(example = "20")]
pub max_waiting_tokens: usize,
#[schema(nullable = true, example = "null")]
pub max_batch_size: Option<usize>,
#[schema(example = "2")]
pub validation_workers: usize,
#[schema(example = "32")]
pub max_client_batch_size: usize,
/// Router Info
#[schema(example = "text-generation-router")]
pub router: &'static str,

View File

@ -1,9 +1,8 @@
/// HTTP Server logic
use crate::config::Config;
use crate::infer::v2::SchedulerV2;
use crate::infer::v3::SchedulerV3;
use crate::infer::{HealthCheck, Scheduler};
use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar};
use crate::infer::v3::{connect_backend, V3Error};
use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, Backend};
use crate::infer::tool_grammar::ToolGrammar;
#[cfg(feature = "kserve")]
use crate::kserve::{
kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
@ -1498,138 +1497,8 @@ pub async fn run(
// Create state
// Open connection, get model info and warmup
let (scheduler, health_ext, shard_info, max_batch_total_tokens): (
Arc<dyn Scheduler + Send + Sync>,
HealthCheck,
ShardInfo,
u32,
) = {
// Helper function to check both v2 and v3
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
match max_supported_batch_total_tokens {
// Older models do not support automatic max-batch-total-tokens
None => {
let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
);
tracing::warn!("Model does not support automatic max batch total tokens");
Ok(max_batch_total_tokens)
}
// Flash attention models return their max supported total tokens
Some(max_supported_batch_total_tokens) => {
// Warn if user added his own max-batch-total-tokens as we will ignore it
if max_batch_total_tokens.is_some() {
tracing::warn!(
"`--max-batch-total-tokens` is deprecated for Flash \
Attention models."
);
tracing::warn!(
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
);
}
if max_total_tokens as u32 > max_supported_batch_total_tokens {
return Err(WebServerError::NotEnoughMemory(max_total_tokens));
}
Ok(max_supported_batch_total_tokens)
}
}
};
let generation_health = Arc::new(AtomicBool::new(false));
match v3::ShardedClient::connect_uds(master_shard_uds_path.clone()).await {
Ok(mut sharded_client) => {
// server is running on v3
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.map_err(WebServerError::Cache)?;
// Get info from the shard
let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?;
// Warmup model
tracing::info!("Warming up model");
let max_batch_total_tokens = check_max_batch_total_tokens(
sharded_client
.warmup(
max_input_tokens as u32,
max_batch_prefill_tokens,
max_total_tokens as u32,
max_batch_size,
)
.await
.map_err(WebServerError::Warmup)?,
)?;
let health_ext =
HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone());
let scheduler = Arc::new(SchedulerV3::new(
sharded_client,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
shard_info.requires_padding,
shard_info.window_size,
shard_info.speculate,
generation_health,
));
tracing::info!("Using scheduler V3");
(scheduler, health_ext, shard_info, max_batch_total_tokens)
}
Err(_) => {
let mut sharded_client = v2::ShardedClient::connect_uds(master_shard_uds_path)
.await
.map_err(WebServerError::Connection)?;
// server is running on v2
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.map_err(WebServerError::Cache)?;
// Get info from the shard
let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?;
// Warmup model
tracing::info!("Warming up model");
let max_batch_total_tokens = check_max_batch_total_tokens(
sharded_client
.warmup(
max_input_tokens as u32,
max_batch_prefill_tokens,
max_total_tokens as u32,
max_batch_size,
)
.await
.map_err(WebServerError::Warmup)?,
)?;
let health_ext =
HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone());
let scheduler = Arc::new(SchedulerV2::new(
sharded_client,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
shard_info.requires_padding,
shard_info.window_size,
shard_info.speculate,
generation_health,
));
tracing::info!("Using scheduler V2");
(scheduler, health_ext, shard_info, max_batch_total_tokens)
}
}
};
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
let (backend, backend_info) = connect_backend(max_input_tokens, max_total_tokens, master_shard_uds_path, waiting_served_ratio, max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, max_batch_size).await?;
// tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
let validation = Validation::new(
validation_workers,
@ -1644,7 +1513,7 @@ pub async fn run(
);
let infer = Infer::new(
scheduler,
backend,
validation,
max_concurrent_requests,
tokenizer_config,
@ -1681,8 +1550,8 @@ pub async fn run(
let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size"));
let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect();
// Speculated tokens buckets
let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens"));
let skipped_buckets: Vec<f64> = (0..shard_info.speculate + 1).map(|x| x as f64).collect();
// let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens"));
// let skipped_buckets: Vec<f64> = (0..shard_info.speculate + 1).map(|x| x as f64).collect();
// Prometheus handler
let builder = PrometheusBuilder::new()
@ -1695,9 +1564,9 @@ pub async fn run(
.set_buckets_for_metric(max_new_tokens_matcher, &max_new_tokens_buckets)
.unwrap()
.set_buckets_for_metric(batch_size_matcher, &batch_size_buckets)
.unwrap()
.set_buckets_for_metric(skipped_matcher, &skipped_buckets)
.unwrap();
// .set_buckets_for_metric(skipped_matcher, &skipped_buckets)
// .unwrap();
let prom_handle = builder
.install_recorder()
.expect("failed to install metrics recorder");
@ -1713,18 +1582,18 @@ pub async fn run(
let info = Info {
model_id: model_info.model_id,
model_sha: model_info.sha,
model_dtype: shard_info.dtype,
model_device_type: shard_info.device_type,
// model_dtype: shard_info.dtype,
// model_device_type: shard_info.device_type,
model_pipeline_tag: model_info.pipeline_tag,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
// waiting_served_ratio,
// max_batch_total_tokens,
// max_waiting_tokens,
// max_batch_size,
validation_workers,
max_client_batch_size,
router: env!("CARGO_PKG_NAME"),
@ -1858,7 +1727,6 @@ pub async fn run(
// add layers after routes
app = app
.layer(Extension(info))
.layer(Extension(health_ext.clone()))
.layer(Extension(compat_return_full_text))
.layer(Extension(infer))
.layer(Extension(compute_type))
@ -1960,7 +1828,7 @@ impl From<InferError> for Event {
#[derive(Debug, Error)]
pub enum WebServerError {
#[error("Backend error: {0}")]
Backend(#[from] BackendError),
Backend(#[from] V3Error),
#[error("Axum error: {0}")]
Axum(#[from] axum::BoxError),
}