wip
This commit is contained in:
parent
504754861f
commit
b562680be4
|
@ -3602,6 +3602,7 @@ name = "text-generation-router"
|
||||||
version = "2.0.5-dev0"
|
version = "2.0.5-dev0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-stream",
|
"async-stream",
|
||||||
|
"async-trait",
|
||||||
"axum 0.7.5",
|
"axum 0.7.5",
|
||||||
"axum-tracing-opentelemetry",
|
"axum-tracing-opentelemetry",
|
||||||
"base64 0.22.1",
|
"base64 0.22.1",
|
||||||
|
|
|
@ -15,6 +15,7 @@ name = "text-generation-router"
|
||||||
path = "src/main.rs"
|
path = "src/main.rs"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
async-trait = "0.1.74"
|
||||||
async-stream = "0.3.5"
|
async-stream = "0.3.5"
|
||||||
axum = { version = "0.7", features = ["json"] }
|
axum = { version = "0.7", features = ["json"] }
|
||||||
axum-tracing-opentelemetry = "0.16"
|
axum-tracing-opentelemetry = "0.16"
|
||||||
|
|
|
@ -1,24 +1,18 @@
|
||||||
mod health;
|
// pub(crate) mod v2;
|
||||||
pub(crate) mod v2;
|
|
||||||
pub(crate) mod v3;
|
pub(crate) mod v3;
|
||||||
mod chat_template;
|
mod chat_template;
|
||||||
mod tool_grammar;
|
pub mod tool_grammar;
|
||||||
|
|
||||||
pub(crate) use health::HealthCheck;
|
|
||||||
|
|
||||||
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
|
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
|
||||||
use crate::{
|
use crate::{
|
||||||
ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
|
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
|
||||||
HubTokenizerConfig, Message, MessageChunk, PrefillToken, Text, TextMessage, Token,
|
HubTokenizerConfig, Message, PrefillToken, Token,
|
||||||
};
|
};
|
||||||
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
|
use crate::{GrammarType};
|
||||||
use futures::future::try_join_all;
|
use futures::future::try_join_all;
|
||||||
use minijinja::{Environment, ErrorKind, Template};
|
use minijinja::{ErrorKind};
|
||||||
use minijinja_contrib::pycompat;
|
|
||||||
|
|
||||||
use serde_json::{json, Map, Value};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
|
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
|
@ -26,13 +20,16 @@ use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
use tokio_stream::StreamExt;
|
use tokio_stream::StreamExt;
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
use chat_template::ChatTemplate;
|
use chat_template::ChatTemplate;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
pub(crate) trait Scheduler {
|
#[async_trait]
|
||||||
|
pub(crate) trait Backend {
|
||||||
fn schedule(
|
fn schedule(
|
||||||
&self,
|
&self,
|
||||||
request: ValidGenerateRequest,
|
request: ValidGenerateRequest,
|
||||||
permit: OwnedSemaphorePermit,
|
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError>;
|
||||||
) -> Result<GenerateStreamResponse, InferError>;
|
|
||||||
|
async fn health(&self, current_health: bool) -> bool;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Inference struct
|
/// Inference struct
|
||||||
|
@ -90,15 +87,7 @@ impl Infer {
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
pub(crate) async fn generate_stream<'a>(
|
pub(crate) async fn generate_stream<'a>(
|
||||||
&'a self,
|
&'a self,
|
||||||
request: GenerateRequest,
|
request: GenerateRequest) -> Result<GenerateStreamResponse, InferError> {
|
||||||
) -> Result<
|
|
||||||
(
|
|
||||||
OwnedSemaphorePermit,
|
|
||||||
u32, // input_length
|
|
||||||
impl Stream<Item=Result<InferStreamResponse, InferError>> + 'a,
|
|
||||||
),
|
|
||||||
InferError,
|
|
||||||
> {
|
|
||||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
// Limit concurrent requests by acquiring a permit from the semaphore
|
||||||
let permit = self
|
let permit = self
|
||||||
.clone()
|
.clone()
|
||||||
|
@ -118,35 +107,11 @@ impl Infer {
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let input_length = valid_request.input_length;
|
let input_length = valid_request.input_length;
|
||||||
let mut generation_stream = self
|
let generation_stream = self
|
||||||
.backend
|
.backend
|
||||||
.schedule(valid_request)
|
.schedule(valid_request)?;
|
||||||
.map_err(InferError::Backend)?;
|
|
||||||
|
|
||||||
let stream = stream! {
|
Ok((permit, input_length, generation_stream))
|
||||||
while let Some(generation) = generation_stream.next().await {
|
|
||||||
self.backend_health.store(generation.is_ok(), Ordering::SeqCst);
|
|
||||||
|
|
||||||
yield generation.map(|generation| match generation {
|
|
||||||
types::TokenStreamResponse::Prefill(prefill_tokens) => InferStreamResponse::Prefill(
|
|
||||||
prefill_tokens.into_iter().map(PrefillToken::from).collect()
|
|
||||||
),
|
|
||||||
types::TokenStreamResponse::Intermediate { token, top_tokens } => InferStreamResponse::Intermediate {
|
|
||||||
token: Token::from(token),
|
|
||||||
top_tokens: top_tokens.into_iter().map(Token::from).collect(),
|
|
||||||
},
|
|
||||||
types::TokenStreamResponse::End { token, top_tokens, generated_text, start, queued } => InferStreamResponse::End {
|
|
||||||
token: Token::from(token),
|
|
||||||
top_tokens: top_tokens.into_iter().map(Token::from).collect(),
|
|
||||||
generated_text,
|
|
||||||
start,
|
|
||||||
queued,
|
|
||||||
}
|
|
||||||
}).map_err(InferError::GenerationError)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok((permit, input_length, stream))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Tokenizer the input
|
/// Tokenizer the input
|
||||||
|
@ -363,10 +328,8 @@ pub(crate) struct InferResponse {
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum InferError {
|
pub enum InferError {
|
||||||
#[error("Request failed during scheduling: {0}")]
|
|
||||||
Backend(BackendError),
|
|
||||||
#[error("Request failed during generation: {0}")]
|
#[error("Request failed during generation: {0}")]
|
||||||
GenerationError(BackendError),
|
GenerationError(String),
|
||||||
#[error("Model is overloaded")]
|
#[error("Model is overloaded")]
|
||||||
Overloaded(#[from] TryAcquireError),
|
Overloaded(#[from] TryAcquireError),
|
||||||
#[error("Input validation error: {0}")]
|
#[error("Input validation error: {0}")]
|
||||||
|
@ -382,7 +345,6 @@ pub enum InferError {
|
||||||
impl InferError {
|
impl InferError {
|
||||||
pub(crate) fn error_type(&self) -> &str {
|
pub(crate) fn error_type(&self) -> &str {
|
||||||
match self {
|
match self {
|
||||||
InferError::Backend(_) => "backend",
|
|
||||||
InferError::GenerationError(_) => "generation",
|
InferError::GenerationError(_) => "generation",
|
||||||
InferError::Overloaded(_) => "overloaded",
|
InferError::Overloaded(_) => "overloaded",
|
||||||
InferError::ValidationError(_) => "validation",
|
InferError::ValidationError(_) => "validation",
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
mod queue;
|
mod queue;
|
||||||
mod scheduler;
|
mod scheduler;
|
||||||
|
|
||||||
pub(crate) use scheduler::SchedulerV2;
|
pub(crate) use scheduler::BackendV2;
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
/// Batching and inference logic
|
/// Batching and inference logic
|
||||||
use crate::infer::v2::queue::{Entry, Queue};
|
use crate::infer::v2::queue::{Entry, Queue};
|
||||||
use crate::infer::{
|
use crate::infer::{
|
||||||
GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler,
|
GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Backend,
|
||||||
};
|
};
|
||||||
use crate::validation::ValidGenerateRequest;
|
use crate::validation::ValidGenerateRequest;
|
||||||
use crate::{FinishReason, PrefillToken, Token};
|
use crate::{FinishReason, PrefillToken, Token};
|
||||||
|
@ -18,14 +18,14 @@ use tokio::time::Instant;
|
||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
use tracing::{info_span, instrument, Instrument, Span};
|
use tracing::{info_span, instrument, Instrument, Span};
|
||||||
|
|
||||||
pub(crate) struct SchedulerV2 {
|
pub(crate) struct BackendV2 {
|
||||||
/// Request queue
|
/// Request queue
|
||||||
queue: Queue,
|
queue: Queue,
|
||||||
/// Notify batcher on queue appends
|
/// Notify batcher on queue appends
|
||||||
batching_task_notifier: Arc<Notify>,
|
batching_task_notifier: Arc<Notify>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SchedulerV2 {
|
impl BackendV2 {
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub(crate) fn new(
|
pub(crate) fn new(
|
||||||
client: ShardedClient,
|
client: ShardedClient,
|
||||||
|
@ -62,7 +62,7 @@ impl SchedulerV2 {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Scheduler for SchedulerV2 {
|
impl Backend for BackendV2 {
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
fn schedule(
|
fn schedule(
|
||||||
&self,
|
&self,
|
||||||
|
|
|
@ -0,0 +1,500 @@
|
||||||
|
/// Batching and inference logic
|
||||||
|
use crate::infer::v3::queue::{Entry, Queue};
|
||||||
|
use crate::infer::{
|
||||||
|
GeneratedText, InferError, InferStreamResponse, Backend,
|
||||||
|
};
|
||||||
|
use crate::validation::ValidGenerateRequest;
|
||||||
|
use crate::{FinishReason, PrefillToken, Token};
|
||||||
|
use nohash_hasher::IntMap;
|
||||||
|
use std::sync::{
|
||||||
|
Arc,
|
||||||
|
};
|
||||||
|
use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient};
|
||||||
|
use text_generation_client::{ClientError, Health};
|
||||||
|
use tokio::sync::mpsc::error::SendError;
|
||||||
|
use tokio::sync::{mpsc, Notify};
|
||||||
|
use tokio::time::Instant;
|
||||||
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
|
use tracing::{info_span, instrument, Instrument, Span};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
pub(crate) struct BackendV3 {
|
||||||
|
/// Request queue
|
||||||
|
queue: Queue,
|
||||||
|
/// Notify batcher on queue appends
|
||||||
|
batching_task_notifier: Arc<Notify>,
|
||||||
|
/// Client clone, used for health checks to skip the queue
|
||||||
|
client: ShardedClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BackendV3 {
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub(crate) fn new(
|
||||||
|
client: ShardedClient,
|
||||||
|
waiting_served_ratio: f32,
|
||||||
|
max_batch_prefill_tokens: u32,
|
||||||
|
max_batch_total_tokens: u32,
|
||||||
|
max_waiting_tokens: usize,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
requires_padding: bool,
|
||||||
|
window_size: Option<u32>,
|
||||||
|
speculate: u32,
|
||||||
|
) -> Self {
|
||||||
|
let queue = Queue::new(
|
||||||
|
requires_padding,
|
||||||
|
16,
|
||||||
|
window_size,
|
||||||
|
speculate,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
);
|
||||||
|
let batching_task_notifier = Arc::new(Notify::new());
|
||||||
|
|
||||||
|
// Spawn batching background task that contains all the inference logic
|
||||||
|
tokio::spawn(batching_task(
|
||||||
|
client.clone(),
|
||||||
|
waiting_served_ratio,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
queue.clone(),
|
||||||
|
batching_task_notifier.clone(),
|
||||||
|
));
|
||||||
|
|
||||||
|
Self {
|
||||||
|
queue,
|
||||||
|
batching_task_notifier,
|
||||||
|
client,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Backend for BackendV3 {
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
fn schedule(
|
||||||
|
&self,
|
||||||
|
request: ValidGenerateRequest,
|
||||||
|
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
|
||||||
|
// MPSC channel to communicate with the background batching task
|
||||||
|
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||||
|
let input_length = request.input_length;
|
||||||
|
|
||||||
|
// Append the request to the queue
|
||||||
|
self.queue.append(Entry {
|
||||||
|
request,
|
||||||
|
response_tx,
|
||||||
|
span: Span::current(),
|
||||||
|
temp_span: None,
|
||||||
|
queue_time: Instant::now(),
|
||||||
|
batch_time: None,
|
||||||
|
block_allocation: None,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Notify the background task that we have a new entry in the queue that needs
|
||||||
|
// to be batched
|
||||||
|
self.batching_task_notifier.notify_one();
|
||||||
|
|
||||||
|
// Return stream
|
||||||
|
Ok(
|
||||||
|
UnboundedReceiverStream::new(response_rx),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn health(&self, current_health: bool) -> bool {
|
||||||
|
if current_health {
|
||||||
|
// Generation is healthy, we only check that the shards can allocate on device
|
||||||
|
self.client.device_health().await
|
||||||
|
} else {
|
||||||
|
self.client.model_health().await
|
||||||
|
}
|
||||||
|
.is_ok()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Batching logic
|
||||||
|
/// Will be launched in a background Tokio task
|
||||||
|
///
|
||||||
|
/// Batches requests and sends them to the inference server
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub(crate) async fn batching_task(
|
||||||
|
mut client: ShardedClient,
|
||||||
|
waiting_served_ratio: f32,
|
||||||
|
max_batch_prefill_tokens: u32,
|
||||||
|
max_batch_total_tokens: u32,
|
||||||
|
max_waiting_tokens: usize,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
queue: Queue,
|
||||||
|
notifier: Arc<Notify>,
|
||||||
|
) {
|
||||||
|
// Infinite loop
|
||||||
|
loop {
|
||||||
|
// Wait for a notification from the Infer struct
|
||||||
|
notifier.notified().await;
|
||||||
|
|
||||||
|
// Get the next batch from the queue
|
||||||
|
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||||
|
// waiting in the queue
|
||||||
|
while let Some((mut entries, batch, span)) = queue
|
||||||
|
.next_batch(
|
||||||
|
None,
|
||||||
|
max_batch_size,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
let mut cached_batch = prefill(&mut client, batch, &mut entries)
|
||||||
|
.instrument(span)
|
||||||
|
.await;
|
||||||
|
let mut waiting_tokens = 1;
|
||||||
|
|
||||||
|
// We loop until we do not receive any cached batch from the inference server (== until
|
||||||
|
// all requests have met their stopping criteria)
|
||||||
|
while let Some(batch) = cached_batch {
|
||||||
|
// Get current batch info
|
||||||
|
let batch_size = batch.size;
|
||||||
|
let batch_max_tokens = batch.max_tokens;
|
||||||
|
let mut batches = vec![batch];
|
||||||
|
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
|
||||||
|
metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64);
|
||||||
|
|
||||||
|
let min_size = if waiting_tokens >= max_waiting_tokens {
|
||||||
|
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
||||||
|
// to add a new batch even though its size might be small
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
// Minimum batch size
|
||||||
|
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
||||||
|
};
|
||||||
|
|
||||||
|
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||||
|
let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
|
||||||
|
|
||||||
|
// Try to get a new batch
|
||||||
|
if let Some((mut new_entries, new_batch, span)) = queue
|
||||||
|
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
// Tracking metrics
|
||||||
|
if min_size.is_some() {
|
||||||
|
metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
|
||||||
|
} else {
|
||||||
|
metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
|
||||||
|
}
|
||||||
|
|
||||||
|
entries.iter_mut().for_each(|(_, entry)| {
|
||||||
|
// Create a new span to add the info that this entry is waiting
|
||||||
|
// because a new batch is being computed
|
||||||
|
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
|
||||||
|
// Add relationships
|
||||||
|
span.follows_from(&entry_waiting_span);
|
||||||
|
entry_waiting_span.follows_from(&span);
|
||||||
|
// Update entry
|
||||||
|
entry.temp_span = Some(entry_waiting_span);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Generate one token for this new batch to have the attention past in cache
|
||||||
|
let new_cached_batch =
|
||||||
|
prefill(&mut client, new_batch, &mut new_entries)
|
||||||
|
.instrument(span)
|
||||||
|
.await;
|
||||||
|
// Reset waiting counter
|
||||||
|
waiting_tokens = 1;
|
||||||
|
// Extend current batch with the new batch
|
||||||
|
if let Some(new_cached_batch) = new_cached_batch {
|
||||||
|
entries.extend(new_entries);
|
||||||
|
batches.push(new_cached_batch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create span for this batch to add context to inference calls
|
||||||
|
let next_batch_size = entries.len();
|
||||||
|
let next_batch_span =
|
||||||
|
info_span!(parent: None, "batch", batch_size = next_batch_size);
|
||||||
|
entries.iter_mut().for_each(|(_, entry)| {
|
||||||
|
// Create a new span to link the batch back to this entry
|
||||||
|
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||||
|
// Add relationships
|
||||||
|
next_batch_span.follows_from(&entry_batch_span);
|
||||||
|
entry_batch_span.follows_from(&next_batch_span);
|
||||||
|
// Update entry
|
||||||
|
entry.temp_span = Some(entry_batch_span);
|
||||||
|
});
|
||||||
|
|
||||||
|
cached_batch = decode(&mut client, batches, &mut entries)
|
||||||
|
.instrument(next_batch_span)
|
||||||
|
.await;
|
||||||
|
waiting_tokens += 1;
|
||||||
|
}
|
||||||
|
metrics::gauge!("tgi_batch_current_size", 0.0);
|
||||||
|
metrics::gauge!("tgi_batch_current_max_tokens", 0.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn prefill(
|
||||||
|
client: &mut ShardedClient,
|
||||||
|
batch: Batch,
|
||||||
|
entries: &mut IntMap<u64, Entry>,
|
||||||
|
) -> Option<CachedBatch> {
|
||||||
|
let start_time = Instant::now();
|
||||||
|
let batch_id = batch.id;
|
||||||
|
metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
|
||||||
|
|
||||||
|
match client.prefill(batch).await {
|
||||||
|
Ok((generations, next_batch, timings)) => {
|
||||||
|
let start_filtering_time = Instant::now();
|
||||||
|
// Send generated tokens and filter stopped entries
|
||||||
|
filter_send_generations(generations, entries);
|
||||||
|
|
||||||
|
// Filter next batch and remove requests that were stopped
|
||||||
|
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||||
|
|
||||||
|
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
|
||||||
|
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
|
||||||
|
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
|
||||||
|
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
|
||||||
|
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
|
||||||
|
next_batch
|
||||||
|
}
|
||||||
|
// If we have an error, we discard the whole batch
|
||||||
|
Err(err) => {
|
||||||
|
let _ = client.clear_cache(Some(batch_id)).await;
|
||||||
|
send_errors(err, entries);
|
||||||
|
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn decode(
|
||||||
|
client: &mut ShardedClient,
|
||||||
|
batches: Vec<CachedBatch>,
|
||||||
|
entries: &mut IntMap<u64, Entry>,
|
||||||
|
) -> Option<CachedBatch> {
|
||||||
|
let start_time = Instant::now();
|
||||||
|
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
|
||||||
|
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
|
||||||
|
|
||||||
|
match client.decode(batches).await {
|
||||||
|
Ok((generations, next_batch, timings)) => {
|
||||||
|
let start_filtering_time = Instant::now();
|
||||||
|
// Send generated tokens and filter stopped entries
|
||||||
|
filter_send_generations(generations, entries);
|
||||||
|
|
||||||
|
// Filter next batch and remove requests that were stopped
|
||||||
|
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||||
|
|
||||||
|
if let Some(concat_duration) = timings.concat {
|
||||||
|
metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
|
||||||
|
}
|
||||||
|
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
|
||||||
|
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
|
||||||
|
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
|
||||||
|
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
|
||||||
|
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
|
||||||
|
next_batch
|
||||||
|
}
|
||||||
|
// If we have an error, we discard the whole batch
|
||||||
|
Err(err) => {
|
||||||
|
for id in batch_ids {
|
||||||
|
let _ = client.clear_cache(Some(id)).await;
|
||||||
|
}
|
||||||
|
send_errors(err, entries);
|
||||||
|
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Filter a `batch` and remove all requests not present in `entries`
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn filter_batch(
|
||||||
|
client: &mut ShardedClient,
|
||||||
|
next_batch: Option<CachedBatch>,
|
||||||
|
entries: &IntMap<u64, Entry>,
|
||||||
|
) -> Option<CachedBatch> {
|
||||||
|
let mut batch = next_batch?;
|
||||||
|
|
||||||
|
// No need to filter
|
||||||
|
if batch.size as usize == entries.len() {
|
||||||
|
return Some(batch);
|
||||||
|
}
|
||||||
|
|
||||||
|
let id = batch.id;
|
||||||
|
|
||||||
|
// Retain only requests that are still in entries
|
||||||
|
batch.request_ids.retain(|id| entries.contains_key(id));
|
||||||
|
|
||||||
|
if batch.request_ids.is_empty() {
|
||||||
|
// All requests have been filtered out
|
||||||
|
// Next batch is now empty
|
||||||
|
// Clear it from the Python shards cache
|
||||||
|
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||||
|
client.clear_cache(Some(id)).await.unwrap();
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
// Filter Python shard cache
|
||||||
|
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||||
|
client.filter_batch(id, batch.request_ids).await.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
|
||||||
|
/// and filter entries
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
|
||||||
|
generations.into_iter().for_each(|generation| {
|
||||||
|
let id = generation.request_id;
|
||||||
|
// Get entry
|
||||||
|
// We can `expect` here as the request id should always be in the entries
|
||||||
|
let entry = entries
|
||||||
|
.get(&id)
|
||||||
|
.expect("ID not found in entries. This is a bug.");
|
||||||
|
|
||||||
|
// Create and enter a span to link this function back to the entry
|
||||||
|
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
|
||||||
|
// Send generation responses back to the infer task
|
||||||
|
// If the receive an error from the Flume channel, it means that the client dropped the
|
||||||
|
// request and we need to stop generating hence why we unwrap_or(true)
|
||||||
|
let stopped = send_responses(generation, entry).map_err(|err| {
|
||||||
|
tracing::error!("Entry response channel error.");
|
||||||
|
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
||||||
|
err
|
||||||
|
}).unwrap_or(true);
|
||||||
|
if stopped {
|
||||||
|
entries.remove(&id).expect("ID not found in entries. This is a bug.");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send responses through the `entry` response channel
|
||||||
|
fn send_responses(
|
||||||
|
generation: Generation,
|
||||||
|
entry: &Entry,
|
||||||
|
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
|
||||||
|
// Return directly if the channel is disconnected
|
||||||
|
if entry.response_tx.is_closed() {
|
||||||
|
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
||||||
|
return Ok(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut stopped = false;
|
||||||
|
|
||||||
|
if let Some(prefill_tokens) = generation.prefill_tokens {
|
||||||
|
// Create Token objects
|
||||||
|
// We do that here instead of in the Python code as Rust for loops are faster
|
||||||
|
let prefill_tokens = prefill_tokens
|
||||||
|
.ids
|
||||||
|
.into_iter()
|
||||||
|
.zip(prefill_tokens.logprobs)
|
||||||
|
.zip(prefill_tokens.texts)
|
||||||
|
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Send message
|
||||||
|
entry
|
||||||
|
.response_tx
|
||||||
|
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create last Token
|
||||||
|
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
|
||||||
|
let n = tokens_.ids.len();
|
||||||
|
metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
|
||||||
|
let mut iterator = tokens_
|
||||||
|
.ids
|
||||||
|
.into_iter()
|
||||||
|
.zip(tokens_.logprobs)
|
||||||
|
.zip(tokens_.texts)
|
||||||
|
.zip(tokens_.is_special)
|
||||||
|
.enumerate()
|
||||||
|
.peekable();
|
||||||
|
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
|
||||||
|
let token = Token {
|
||||||
|
id,
|
||||||
|
text,
|
||||||
|
logprob,
|
||||||
|
special,
|
||||||
|
};
|
||||||
|
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
|
||||||
|
top_tokens_
|
||||||
|
.ids
|
||||||
|
.iter()
|
||||||
|
.zip(top_tokens_.logprobs.iter())
|
||||||
|
.zip(top_tokens_.texts.iter())
|
||||||
|
.zip(top_tokens_.is_special.iter())
|
||||||
|
.map(|(((&id, &logprob), text), &special)| Token {
|
||||||
|
id,
|
||||||
|
text: text.to_string(),
|
||||||
|
logprob,
|
||||||
|
special,
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
} else {
|
||||||
|
vec![]
|
||||||
|
};
|
||||||
|
match (&generation.generated_text, iterator.peek()) {
|
||||||
|
(Some(generated_text), None) => {
|
||||||
|
// Generation has ended
|
||||||
|
stopped = true;
|
||||||
|
// Send message
|
||||||
|
entry.response_tx.send(Ok(InferStreamResponse::End {
|
||||||
|
token,
|
||||||
|
top_tokens,
|
||||||
|
generated_text: GeneratedText::from(generated_text.clone()),
|
||||||
|
queued: entry.queue_time,
|
||||||
|
start: entry.batch_time.unwrap(),
|
||||||
|
}))?;
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// Send message
|
||||||
|
entry
|
||||||
|
.response_tx
|
||||||
|
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(stopped)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send errors to Infer for all `entries`
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
||||||
|
entries.drain().for_each(|(_, entry)| {
|
||||||
|
// Create and enter a span to link this function back to the entry
|
||||||
|
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
||||||
|
let err = InferError::GenerationError(error.to_string());
|
||||||
|
metrics::increment_counter!("tgi_request_failure", "err" => "generation");
|
||||||
|
tracing::error!("{err}");
|
||||||
|
|
||||||
|
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||||
|
entry
|
||||||
|
.response_tx
|
||||||
|
.send(Err(err))
|
||||||
|
.unwrap_or(());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<text_generation_client::v3::GeneratedText> for GeneratedText {
|
||||||
|
fn from(value: text_generation_client::v3::GeneratedText) -> Self {
|
||||||
|
let v3_finish_reason =
|
||||||
|
text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap();
|
||||||
|
let finish_reason = match v3_finish_reason {
|
||||||
|
text_generation_client::v3::FinishReason::Length => FinishReason::Length,
|
||||||
|
text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
||||||
|
text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence,
|
||||||
|
};
|
||||||
|
|
||||||
|
Self {
|
||||||
|
text: value.text,
|
||||||
|
generated_tokens: value.generated_tokens,
|
||||||
|
finish_reason,
|
||||||
|
seed: value.seed,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,5 +1,141 @@
|
||||||
mod block_allocator;
|
mod block_allocator;
|
||||||
mod queue;
|
mod queue;
|
||||||
mod scheduler;
|
mod backend;
|
||||||
|
|
||||||
pub(crate) use scheduler::SchedulerV3;
|
use futures_util::TryFutureExt;
|
||||||
|
use serde::Serialize;
|
||||||
|
use thiserror::Error;
|
||||||
|
use utoipa::ToSchema;
|
||||||
|
pub(crate) use backend::BackendV3;
|
||||||
|
use text_generation_client::ClientError;
|
||||||
|
use text_generation_client::v3::ShardedClient;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||||
|
pub struct BackendInfo {
|
||||||
|
/// Mandatory
|
||||||
|
#[schema(example = "cuda")]
|
||||||
|
pub model_device_type: String,
|
||||||
|
#[schema(example = "torch.float16")]
|
||||||
|
pub model_dtype: String,
|
||||||
|
|
||||||
|
/// Backend parameters
|
||||||
|
#[schema(example = "1")]
|
||||||
|
pub speculate: usize,
|
||||||
|
#[schema(example = "1.2")]
|
||||||
|
pub waiting_served_ratio: f32,
|
||||||
|
#[schema(example = "32000")]
|
||||||
|
pub max_batch_total_tokens: u32,
|
||||||
|
#[schema(example = "20")]
|
||||||
|
pub max_waiting_tokens: usize,
|
||||||
|
#[schema(nullable = true, example = "null")]
|
||||||
|
pub max_batch_size: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn connect_backend(
|
||||||
|
max_input_tokens: usize, max_total_tokens: usize,
|
||||||
|
master_shard_uds_path: String,
|
||||||
|
waiting_served_ratio: f32,
|
||||||
|
max_batch_prefill_tokens: u32,
|
||||||
|
max_batch_total_tokens: Option<u32>,
|
||||||
|
max_waiting_tokens: usize,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
) -> Result<(BackendV3, BackendInfo), V3Error> {
|
||||||
|
// Helper function
|
||||||
|
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
|
||||||
|
match max_supported_batch_total_tokens {
|
||||||
|
// Older models do not support automatic max-batch-total-tokens
|
||||||
|
None => {
|
||||||
|
let max_batch_total_tokens = max_batch_total_tokens
|
||||||
|
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
|
||||||
|
tracing::warn!("Model does not support automatic max batch total tokens");
|
||||||
|
Ok(max_batch_total_tokens)
|
||||||
|
}
|
||||||
|
// Flash attention models return their max supported total tokens
|
||||||
|
Some(max_supported_batch_total_tokens) => {
|
||||||
|
// Warn if user added his own max-batch-total-tokens as we will ignore it
|
||||||
|
if max_batch_total_tokens.is_some() {
|
||||||
|
tracing::warn!(
|
||||||
|
"`--max-batch-total-tokens` is deprecated for Flash \
|
||||||
|
Attention models."
|
||||||
|
);
|
||||||
|
tracing::warn!(
|
||||||
|
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if max_total_tokens as u32 > max_supported_batch_total_tokens {
|
||||||
|
return Err(V3Error::NotEnoughMemory(max_total_tokens));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(max_supported_batch_total_tokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
||||||
|
.await
|
||||||
|
.map_err(V3Error::Connection)?;
|
||||||
|
|
||||||
|
// server is running on v3
|
||||||
|
// Clear the cache; useful if the webserver rebooted
|
||||||
|
sharded_client
|
||||||
|
.clear_cache(None)
|
||||||
|
.await
|
||||||
|
.map_err(V3Error::Cache)?;
|
||||||
|
// Get info from the shard
|
||||||
|
let shard_info = sharded_client.info().await.map_err(V3Error::Info)?;
|
||||||
|
|
||||||
|
// Warmup model
|
||||||
|
tracing::info!("Warming up model");
|
||||||
|
let max_batch_total_tokens = check_max_batch_total_tokens(
|
||||||
|
sharded_client
|
||||||
|
.warmup(
|
||||||
|
max_input_tokens as u32,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_total_tokens as u32,
|
||||||
|
max_batch_size,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(V3Error::Warmup)?,
|
||||||
|
)?;
|
||||||
|
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
||||||
|
|
||||||
|
let backend_info = BackendInfo {
|
||||||
|
waiting_served_ratio,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
model_device_type: shard_info.device_type.clone(),
|
||||||
|
model_dtype: shard_info.dtype.clone(),
|
||||||
|
speculate: shard_info.speculate as usize,
|
||||||
|
};
|
||||||
|
|
||||||
|
let backend = BackendV3::new(
|
||||||
|
sharded_client,
|
||||||
|
waiting_served_ratio,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
shard_info.requires_padding,
|
||||||
|
shard_info.window_size,
|
||||||
|
shard_info.speculate,
|
||||||
|
);
|
||||||
|
|
||||||
|
tracing::info!("Using backend V3");
|
||||||
|
|
||||||
|
Ok((backend, backend_info))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub(crate) enum V3Error {
|
||||||
|
#[error("Unable to clear the Python model shards cache: {0}")]
|
||||||
|
Cache(ClientError),
|
||||||
|
#[error("Unable to connect to the Python model shards: {0}")]
|
||||||
|
Connection(ClientError),
|
||||||
|
#[error("Unable to get the Python model shards info: {0}")]
|
||||||
|
Info(ClientError),
|
||||||
|
#[error("Unable to warmup the Python model shards: {0}")]
|
||||||
|
Warmup(ClientError),
|
||||||
|
#[error("Not enough memory to handle `max_total_tokens={0}`")]
|
||||||
|
NotEnoughMemory(usize),
|
||||||
|
}
|
|
@ -135,12 +135,13 @@ pub struct Info {
|
||||||
pub model_id: String,
|
pub model_id: String,
|
||||||
#[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
|
#[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
|
||||||
pub model_sha: Option<String>,
|
pub model_sha: Option<String>,
|
||||||
#[schema(example = "torch.float16")]
|
// #[schema(example = "torch.float16")]
|
||||||
pub model_dtype: String,
|
// pub model_dtype: String,
|
||||||
#[schema(example = "cuda")]
|
// #[schema(example = "cuda")]
|
||||||
pub model_device_type: String,
|
// pub model_device_type: String,
|
||||||
#[schema(nullable = true, example = "text-generation")]
|
#[schema(nullable = true, example = "text-generation")]
|
||||||
pub model_pipeline_tag: Option<String>,
|
pub model_pipeline_tag: Option<String>,
|
||||||
|
|
||||||
/// Router Parameters
|
/// Router Parameters
|
||||||
#[schema(example = "128")]
|
#[schema(example = "128")]
|
||||||
pub max_concurrent_requests: usize,
|
pub max_concurrent_requests: usize,
|
||||||
|
@ -152,18 +153,12 @@ pub struct Info {
|
||||||
pub max_input_tokens: usize,
|
pub max_input_tokens: usize,
|
||||||
#[schema(example = "2048")]
|
#[schema(example = "2048")]
|
||||||
pub max_total_tokens: usize,
|
pub max_total_tokens: usize,
|
||||||
#[schema(example = "1.2")]
|
|
||||||
pub waiting_served_ratio: f32,
|
|
||||||
#[schema(example = "32000")]
|
|
||||||
pub max_batch_total_tokens: u32,
|
|
||||||
#[schema(example = "20")]
|
|
||||||
pub max_waiting_tokens: usize,
|
|
||||||
#[schema(nullable = true, example = "null")]
|
|
||||||
pub max_batch_size: Option<usize>,
|
|
||||||
#[schema(example = "2")]
|
#[schema(example = "2")]
|
||||||
pub validation_workers: usize,
|
pub validation_workers: usize,
|
||||||
#[schema(example = "32")]
|
#[schema(example = "32")]
|
||||||
pub max_client_batch_size: usize,
|
pub max_client_batch_size: usize,
|
||||||
|
|
||||||
|
|
||||||
/// Router Info
|
/// Router Info
|
||||||
#[schema(example = "text-generation-router")]
|
#[schema(example = "text-generation-router")]
|
||||||
pub router: &'static str,
|
pub router: &'static str,
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
/// HTTP Server logic
|
/// HTTP Server logic
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
use crate::infer::v2::SchedulerV2;
|
use crate::infer::v3::{connect_backend, V3Error};
|
||||||
use crate::infer::v3::SchedulerV3;
|
use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, Backend};
|
||||||
use crate::infer::{HealthCheck, Scheduler};
|
use crate::infer::tool_grammar::ToolGrammar;
|
||||||
use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar};
|
|
||||||
#[cfg(feature = "kserve")]
|
#[cfg(feature = "kserve")]
|
||||||
use crate::kserve::{
|
use crate::kserve::{
|
||||||
kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
|
kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
|
||||||
|
@ -1498,138 +1497,8 @@ pub async fn run(
|
||||||
// Create state
|
// Create state
|
||||||
|
|
||||||
// Open connection, get model info and warmup
|
// Open connection, get model info and warmup
|
||||||
let (scheduler, health_ext, shard_info, max_batch_total_tokens): (
|
let (backend, backend_info) = connect_backend(max_input_tokens, max_total_tokens, master_shard_uds_path, waiting_served_ratio, max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, max_batch_size).await?;
|
||||||
Arc<dyn Scheduler + Send + Sync>,
|
// tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
||||||
HealthCheck,
|
|
||||||
ShardInfo,
|
|
||||||
u32,
|
|
||||||
) = {
|
|
||||||
// Helper function to check both v2 and v3
|
|
||||||
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
|
|
||||||
match max_supported_batch_total_tokens {
|
|
||||||
// Older models do not support automatic max-batch-total-tokens
|
|
||||||
None => {
|
|
||||||
let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
|
|
||||||
16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
|
|
||||||
);
|
|
||||||
tracing::warn!("Model does not support automatic max batch total tokens");
|
|
||||||
Ok(max_batch_total_tokens)
|
|
||||||
}
|
|
||||||
// Flash attention models return their max supported total tokens
|
|
||||||
Some(max_supported_batch_total_tokens) => {
|
|
||||||
// Warn if user added his own max-batch-total-tokens as we will ignore it
|
|
||||||
if max_batch_total_tokens.is_some() {
|
|
||||||
tracing::warn!(
|
|
||||||
"`--max-batch-total-tokens` is deprecated for Flash \
|
|
||||||
Attention models."
|
|
||||||
);
|
|
||||||
tracing::warn!(
|
|
||||||
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
if max_total_tokens as u32 > max_supported_batch_total_tokens {
|
|
||||||
return Err(WebServerError::NotEnoughMemory(max_total_tokens));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(max_supported_batch_total_tokens)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let generation_health = Arc::new(AtomicBool::new(false));
|
|
||||||
|
|
||||||
match v3::ShardedClient::connect_uds(master_shard_uds_path.clone()).await {
|
|
||||||
Ok(mut sharded_client) => {
|
|
||||||
// server is running on v3
|
|
||||||
// Clear the cache; useful if the webserver rebooted
|
|
||||||
sharded_client
|
|
||||||
.clear_cache(None)
|
|
||||||
.await
|
|
||||||
.map_err(WebServerError::Cache)?;
|
|
||||||
// Get info from the shard
|
|
||||||
let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?;
|
|
||||||
|
|
||||||
// Warmup model
|
|
||||||
tracing::info!("Warming up model");
|
|
||||||
let max_batch_total_tokens = check_max_batch_total_tokens(
|
|
||||||
sharded_client
|
|
||||||
.warmup(
|
|
||||||
max_input_tokens as u32,
|
|
||||||
max_batch_prefill_tokens,
|
|
||||||
max_total_tokens as u32,
|
|
||||||
max_batch_size,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.map_err(WebServerError::Warmup)?,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let health_ext =
|
|
||||||
HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone());
|
|
||||||
let scheduler = Arc::new(SchedulerV3::new(
|
|
||||||
sharded_client,
|
|
||||||
waiting_served_ratio,
|
|
||||||
max_batch_prefill_tokens,
|
|
||||||
max_batch_total_tokens,
|
|
||||||
max_waiting_tokens,
|
|
||||||
max_batch_size,
|
|
||||||
shard_info.requires_padding,
|
|
||||||
shard_info.window_size,
|
|
||||||
shard_info.speculate,
|
|
||||||
generation_health,
|
|
||||||
));
|
|
||||||
tracing::info!("Using scheduler V3");
|
|
||||||
|
|
||||||
(scheduler, health_ext, shard_info, max_batch_total_tokens)
|
|
||||||
}
|
|
||||||
Err(_) => {
|
|
||||||
let mut sharded_client = v2::ShardedClient::connect_uds(master_shard_uds_path)
|
|
||||||
.await
|
|
||||||
.map_err(WebServerError::Connection)?;
|
|
||||||
|
|
||||||
// server is running on v2
|
|
||||||
// Clear the cache; useful if the webserver rebooted
|
|
||||||
sharded_client
|
|
||||||
.clear_cache(None)
|
|
||||||
.await
|
|
||||||
.map_err(WebServerError::Cache)?;
|
|
||||||
// Get info from the shard
|
|
||||||
let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?;
|
|
||||||
|
|
||||||
// Warmup model
|
|
||||||
tracing::info!("Warming up model");
|
|
||||||
let max_batch_total_tokens = check_max_batch_total_tokens(
|
|
||||||
sharded_client
|
|
||||||
.warmup(
|
|
||||||
max_input_tokens as u32,
|
|
||||||
max_batch_prefill_tokens,
|
|
||||||
max_total_tokens as u32,
|
|
||||||
max_batch_size,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.map_err(WebServerError::Warmup)?,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let health_ext =
|
|
||||||
HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone());
|
|
||||||
let scheduler = Arc::new(SchedulerV2::new(
|
|
||||||
sharded_client,
|
|
||||||
waiting_served_ratio,
|
|
||||||
max_batch_prefill_tokens,
|
|
||||||
max_batch_total_tokens,
|
|
||||||
max_waiting_tokens,
|
|
||||||
max_batch_size,
|
|
||||||
shard_info.requires_padding,
|
|
||||||
shard_info.window_size,
|
|
||||||
shard_info.speculate,
|
|
||||||
generation_health,
|
|
||||||
));
|
|
||||||
tracing::info!("Using scheduler V2");
|
|
||||||
|
|
||||||
(scheduler, health_ext, shard_info, max_batch_total_tokens)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
|
||||||
|
|
||||||
let validation = Validation::new(
|
let validation = Validation::new(
|
||||||
validation_workers,
|
validation_workers,
|
||||||
|
@ -1644,7 +1513,7 @@ pub async fn run(
|
||||||
);
|
);
|
||||||
|
|
||||||
let infer = Infer::new(
|
let infer = Infer::new(
|
||||||
scheduler,
|
backend,
|
||||||
validation,
|
validation,
|
||||||
max_concurrent_requests,
|
max_concurrent_requests,
|
||||||
tokenizer_config,
|
tokenizer_config,
|
||||||
|
@ -1681,8 +1550,8 @@ pub async fn run(
|
||||||
let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size"));
|
let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size"));
|
||||||
let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect();
|
let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect();
|
||||||
// Speculated tokens buckets
|
// Speculated tokens buckets
|
||||||
let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens"));
|
// let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens"));
|
||||||
let skipped_buckets: Vec<f64> = (0..shard_info.speculate + 1).map(|x| x as f64).collect();
|
// let skipped_buckets: Vec<f64> = (0..shard_info.speculate + 1).map(|x| x as f64).collect();
|
||||||
|
|
||||||
// Prometheus handler
|
// Prometheus handler
|
||||||
let builder = PrometheusBuilder::new()
|
let builder = PrometheusBuilder::new()
|
||||||
|
@ -1695,9 +1564,9 @@ pub async fn run(
|
||||||
.set_buckets_for_metric(max_new_tokens_matcher, &max_new_tokens_buckets)
|
.set_buckets_for_metric(max_new_tokens_matcher, &max_new_tokens_buckets)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.set_buckets_for_metric(batch_size_matcher, &batch_size_buckets)
|
.set_buckets_for_metric(batch_size_matcher, &batch_size_buckets)
|
||||||
.unwrap()
|
|
||||||
.set_buckets_for_metric(skipped_matcher, &skipped_buckets)
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
// .set_buckets_for_metric(skipped_matcher, &skipped_buckets)
|
||||||
|
// .unwrap();
|
||||||
let prom_handle = builder
|
let prom_handle = builder
|
||||||
.install_recorder()
|
.install_recorder()
|
||||||
.expect("failed to install metrics recorder");
|
.expect("failed to install metrics recorder");
|
||||||
|
@ -1713,18 +1582,18 @@ pub async fn run(
|
||||||
let info = Info {
|
let info = Info {
|
||||||
model_id: model_info.model_id,
|
model_id: model_info.model_id,
|
||||||
model_sha: model_info.sha,
|
model_sha: model_info.sha,
|
||||||
model_dtype: shard_info.dtype,
|
// model_dtype: shard_info.dtype,
|
||||||
model_device_type: shard_info.device_type,
|
// model_device_type: shard_info.device_type,
|
||||||
model_pipeline_tag: model_info.pipeline_tag,
|
model_pipeline_tag: model_info.pipeline_tag,
|
||||||
max_concurrent_requests,
|
max_concurrent_requests,
|
||||||
max_best_of,
|
max_best_of,
|
||||||
max_stop_sequences,
|
max_stop_sequences,
|
||||||
max_input_tokens,
|
max_input_tokens,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
waiting_served_ratio,
|
// waiting_served_ratio,
|
||||||
max_batch_total_tokens,
|
// max_batch_total_tokens,
|
||||||
max_waiting_tokens,
|
// max_waiting_tokens,
|
||||||
max_batch_size,
|
// max_batch_size,
|
||||||
validation_workers,
|
validation_workers,
|
||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
router: env!("CARGO_PKG_NAME"),
|
router: env!("CARGO_PKG_NAME"),
|
||||||
|
@ -1858,7 +1727,6 @@ pub async fn run(
|
||||||
// add layers after routes
|
// add layers after routes
|
||||||
app = app
|
app = app
|
||||||
.layer(Extension(info))
|
.layer(Extension(info))
|
||||||
.layer(Extension(health_ext.clone()))
|
|
||||||
.layer(Extension(compat_return_full_text))
|
.layer(Extension(compat_return_full_text))
|
||||||
.layer(Extension(infer))
|
.layer(Extension(infer))
|
||||||
.layer(Extension(compute_type))
|
.layer(Extension(compute_type))
|
||||||
|
@ -1960,7 +1828,7 @@ impl From<InferError> for Event {
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum WebServerError {
|
pub enum WebServerError {
|
||||||
#[error("Backend error: {0}")]
|
#[error("Backend error: {0}")]
|
||||||
Backend(#[from] BackendError),
|
Backend(#[from] V3Error),
|
||||||
#[error("Axum error: {0}")]
|
#[error("Axum error: {0}")]
|
||||||
Axum(#[from] axum::BoxError),
|
Axum(#[from] axum::BoxError),
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue