diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 0821a456..0aedf563 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -28,8 +28,8 @@ struct Args {
     max_input_length: usize,
     #[clap(default_value = "32", long, env)]
     max_batch_size: usize,
-    #[clap(default_value = "5", long, env)]
-    max_waiting_time: u64,
+    #[clap(default_value = "20", long, env)]
+    max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
     #[clap(default_value = "/tmp/text-generation-server", long, env)]
@@ -41,7 +41,7 @@ struct Args {
 }
 
 fn main() -> ExitCode {
-    tracing_subscriber::fmt::init();
+    tracing_subscriber::fmt().compact().with_ansi(false).init();
 
     // Pattern match configuration
     let Args {
@@ -51,7 +51,7 @@ fn main() -> ExitCode {
         max_concurrent_requests,
         max_input_length,
         max_batch_size,
-        max_waiting_time,
+        max_waiting_tokens,
         port,
         shard_uds_path,
         master_addr,
@@ -148,8 +148,8 @@ fn main() -> ExitCode {
         &max_input_length.to_string(),
         "--max-batch-size",
         &max_batch_size.to_string(),
-        "--max-waiting-time",
-        &max_waiting_time.to_string(),
+        "--max-waiting-tokens",
+        &max_waiting_tokens.to_string(),
         "--port",
         &port.to_string(),
         "--master-shard-uds-path",
diff --git a/router/src/batcher.rs b/router/src/batcher.rs
index 4523dbff..052716a4 100644
--- a/router/src/batcher.rs
+++ b/router/src/batcher.rs
@@ -5,7 +5,6 @@ use axum::http::StatusCode;
 use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
 use std::future::Future;
 use std::sync::Arc;
-use std::time::Duration;
 use thiserror::Error;
 use tokio::sync::{oneshot, Notify};
 use tokio::time::Instant;
@@ -30,7 +29,7 @@ impl Batcher {
     pub(crate) fn new(
         client: ShardedClient,
         max_batch_size: usize,
-        max_waiting_time: Duration,
+        max_waiting_tokens: usize,
     ) -> Self {
         // Batcher shared state
         let db = Db::new();
@@ -41,7 +40,7 @@ impl Batcher {
         // Spawn batching background task that contains all the inference logic
         tokio::spawn(batching_task(
             max_batch_size,
-            max_waiting_time,
+            max_waiting_tokens,
             client,
             db.clone(),
             shared.clone(),
@@ -55,7 +54,7 @@ impl Batcher {
         &self,
         input_length: usize,
         request: GenerateRequest,
-    ) -> Result<String, InferError> {
+    ) -> Result<InferResponse, InferError> {
         // One shot channel to communicate with the background batching task
         let (response_tx, response_rx) = oneshot::channel();
 
@@ -65,6 +64,7 @@ impl Batcher {
             response_tx,
             input_length,
             time: Instant::now(),
+            batch_time: None,
         });
 
         // Notify the background task that we have a new entry in the database that needs
@@ -87,7 +87,7 @@ impl Batcher {
 #[instrument(skip(client, db, shared))]
 async fn batching_task(
     max_batch_size: usize,
-    max_waiting_time: Duration,
+    max_waiting_tokens: usize,
     client: ShardedClient,
     db: Db,
     shared: Arc<Shared>,
@@ -103,8 +103,10 @@ async fn batching_task(
         // Get the next batch from the DB
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the DB
-        if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size, None) {
+        let mut waiting_tokens = 0;
+        if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
             let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
+            waiting_tokens += 1;
 
             // We loop until we do not receive any cached batch from the inference server (== until
             // all requests have met their stopping criteria)
@@ -116,10 +118,20 @@ async fn batching_task(
 
                 // If the current batch is too small, we try to add more requests to it
                 if batch_size <= limit_min_batch_size {
-                    // Get the next batch from the DB that meet our minimum size criteria
+                    let min_size = match waiting_tokens {
+                        // If we didn't onboard any new requests since >= max_waiting_tokens, we try
+                        // to add a new batch even though its size might be small
+                        _ if waiting_tokens >= max_waiting_tokens => None,
+                        // Minimum size criteria
+                        _ => Some(limit_min_batch_size as usize),
+                    };
+
+                    // Try to get a new batch
                     if let Some((new_request_ids, new_batch)) =
-                        db.next_batch(Some(limit_min_batch_size as usize), max_batch_size, None)
+                        db.next_batch(min_size, max_batch_size)
                     {
+                        // Reset waiting counter
+                        waiting_tokens = 0;
                         // Generate one token for this new batch to have the attention past in cache
                         let new_cached_batch =
                             wrap_future(client.generate(new_batch), new_request_ids, &db).await;
@@ -129,24 +141,11 @@ async fn batching_task(
                             batches.push(new_cached_batch);
                         }
                     }
-                    // If we don't have enough requests to meet the minimum size criteria, we
-                    // try to get the next batch from the DB that have been waiting over
-                    // the max_waiting_time
-                    else if let Some((new_request_ids, new_batch)) =
-                        db.next_batch(None, max_batch_size, Some(max_waiting_time))
-                    {
-                        let new_cached_batch =
-                            wrap_future(client.generate(new_batch), new_request_ids, &db).await;
-                        // Extend current batch with the new batch
-                        if let Some(new_cached_batch) = new_cached_batch {
-                            request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
-                            batches.push(new_cached_batch);
-                        }
-                    }
                 }
 
                 cached_batch =
                     wrap_future(client.generate_with_cache(batches), request_ids, &db).await;
+                waiting_tokens += 1;
             }
         }
     }
@@ -188,11 +187,25 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
         let entry = db
             .remove(&output.request.unwrap().id)
             .expect("ID not found in db. This is a bug.");
+        let response = InferResponse {
+            output: output.output,
+            queued: entry.time,
+            start: entry.batch_time.unwrap(), // unwrap is always valid
+            end: Instant::now(),
+        };
         // unwrap_or is valid here as we don't care if the receiver is gone.
-        entry.response_tx.send(Ok(output.output)).unwrap_or(());
+        entry.response_tx.send(Ok(response)).unwrap_or(());
     });
 }
 
+#[derive(Debug)]
+pub(crate) struct InferResponse {
+    pub(crate) output: String,
+    pub(crate) queued: Instant,
+    pub(crate) start: Instant,
+    pub(crate) end: Instant,
+}
+
 #[derive(Debug, Error)]
 pub enum InferError {
     #[error("Request failed during generation: {0}")]
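
Note (not part of the patch): the scheduling rule above can be read in isolation. `waiting_tokens` counts decode iterations since requests were last onboarded; until it reaches `max_waiting_tokens`, the running batch is only extended when the queue holds at least `limit_min_batch_size` requests, and after that whatever is waiting is taken so queued requests are not starved. A minimal sketch of that decision, using a hypothetical helper `min_batch_size` and assuming `limit_min_batch_size` is a `u32` as the `as usize` cast suggests:

    /// Sketch of the batch-extension rule from `batching_task` (not the actual router code).
    /// `None` means "extend with whatever is queued", `Some(n)` means "only extend if at
    /// least `n` requests are waiting".
    fn min_batch_size(
        waiting_tokens: usize,
        max_waiting_tokens: usize,
        limit_min_batch_size: u32,
    ) -> Option<usize> {
        if waiting_tokens >= max_waiting_tokens {
            // The running batch has decoded `max_waiting_tokens` steps without onboarding
            // anyone: accept even a very small batch to bound queueing latency.
            None
        } else {
            // Otherwise require enough requests to make the extra forward pass worthwhile.
            Some(limit_min_batch_size as usize)
        }
    }

    fn main() {
        assert_eq!(min_batch_size(3, 20, 16), Some(16)); // keep waiting for more requests
        assert_eq!(min_batch_size(20, 20, 16), None); // take whatever is queued
    }
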
diff --git a/router/src/db.rs b/router/src/db.rs
index 9518fa1d..76a08ae0 100644
--- a/router/src/db.rs
+++ b/router/src/db.rs
@@ -1,10 +1,10 @@
+use crate::InferResponse;
 /// This code is massively inspired by Tokio mini-redis
 use crate::{GenerateParameters, GenerateRequest};
 use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
 use parking_lot::Mutex;
 use std::collections::BTreeMap;
 use std::sync::Arc;
-use std::time::Duration;
 use tokio::sync::oneshot::Sender;
 use tokio::time::Instant;
 
@@ -14,11 +14,13 @@ pub(crate) struct Entry {
     /// Request
     pub request: GenerateRequest,
     /// Response sender to communicate between the Batcher and the batching_task
-    pub response_tx: Sender<Result<String, ClientError>>,
+    pub response_tx: Sender<Result<InferResponse, ClientError>>,
     /// Number of tokens in the input
     pub input_length: usize,
     /// Instant when this entry was created
     pub time: Instant,
+    /// Instant when this entry was added to a batch
+    pub batch_time: Option<Instant>,
 }
 
 /// Request Database
@@ -51,11 +53,7 @@ struct State {
 
 impl State {
     /// Get the next requests
-    fn next_requests(
-        &self,
-        max_size: usize,
-        min_waiting_time: Option<Duration>,
-    ) -> Option<(Vec<u64>, Vec<Request>)> {
+    fn next_requests(&self, max_size: usize) -> Option<(Vec<u64>, Vec<Request>)> {
         // Iterates for max_size over the BTreemap starting from next_batch_start_id
         let mut requests = Vec::new();
         let mut ids = Vec::new();
@@ -67,15 +65,6 @@ impl State {
             // Take max_size
             .take(max_size)
         {
-            if let Some(min_waiting_time) = min_waiting_time {
-                // Only take entries that waited for at least min_waiting_time
-                if entry.time.elapsed() < min_waiting_time {
-                    // Since entries are ordered, we already know that all following entries won't
-                    // satisfy the condition
-                    break;
-                }
-            }
-
             requests.push(Request {
                 id: *id,
                 inputs: entry.request.inputs.clone(),
@@ -134,19 +123,22 @@ impl Db {
         &self,
         min_size: Option<usize>,
         max_size: usize,
-        min_waiting_time: Option<Duration>,
     ) -> Option<(Vec<u64>, Batch)> {
         // Acquire lock
         let mut state = self.shared.state.lock();
 
         // Get requests from the database
-        if let Some((ids, requests)) = state.next_requests(max_size, min_waiting_time) {
+        if let Some((ids, requests)) = state.next_requests(max_size) {
             if let Some(min_size) = min_size {
                 // If min_size is set, only return a batch if there are enough requests
                 if requests.len() < min_size {
                     return None;
                 }
             }
+            ids.iter().for_each(|id| {
+                // Set batch_time for each request
+                state.entries.get_mut(id).unwrap().batch_time = Some(Instant::now());
+            });
 
             // Batch size
             let size = requests.len();
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 02b912a3..6604a91f 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -4,7 +4,7 @@ mod db;
 pub mod server;
 mod validation;
 
-use batcher::Batcher;
+use batcher::{Batcher, InferResponse};
 use db::{Db, Entry};
 use serde::{Deserialize, Serialize};
 use validation::Validation;
@@ -64,5 +64,3 @@ pub(crate) struct GenerateRequest {
 pub(crate) struct GeneratedText {
     pub generated_text: String,
 }
-
-pub(crate) type GenerateResponse = Vec<GeneratedText>;
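
Note (not part of the patch): `Entry::time` is set when a request is appended to the `Db`, `Entry::batch_time` when `next_batch` places it in a batch, and the two are carried into `InferResponse` as `queued` and `start`. Subtracting one `Instant` from another yields a `Duration`, which is how the router later derives the queue and inference times. A small sketch under those assumptions, with hypothetical names:

    use std::time::{Duration, Instant};

    /// Sketch mirroring the timestamps carried by `InferResponse`.
    struct Timings {
        queued: Instant, // `Entry::time`: request entered the database
        start: Instant,  // `Entry::batch_time`: request was placed in a batch
        end: Instant,    // generation finished for this request
    }

    impl Timings {
        /// Time spent waiting in the database before being batched.
        fn queue_time(&self) -> Duration {
            self.start - self.queued
        }

        /// Time spent generating once the request joined a batch.
        fn inference_time(&self) -> Duration {
            self.end - self.start
        }
    }

    fn main() {
        let queued = Instant::now();
        let start = queued + Duration::from_millis(12);
        let end = start + Duration::from_millis(230);
        let timings = Timings { queued, start, end };
        println!(
            "queue: {:?}, inference: {:?}",
            timings.queue_time(),
            timings.inference_time()
        );
    }
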
diff --git a/router/src/main.rs b/router/src/main.rs
index 49051b37..6d1a0fb9 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -2,7 +2,6 @@ use bloom_inference_client::ShardedClient;
 use clap::Parser;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
-use std::time::Duration;
 use text_generation_router::server;
 use tokenizers::Tokenizer;
 
@@ -16,8 +15,8 @@ struct Args {
     max_input_length: usize,
     #[clap(default_value = "32", long, env)]
     max_batch_size: usize,
-    #[clap(default_value = "5", long, env)]
-    max_waiting_time: u64,
+    #[clap(default_value = "20", long, env)]
+    max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
     #[clap(default_value = "/tmp/bloom-inference-0", long, env)]
@@ -36,19 +35,19 @@ fn main() -> Result<(), std::io::Error> {
         max_concurrent_requests,
         max_input_length,
         max_batch_size,
-        max_waiting_time,
+        max_waiting_tokens,
         port,
         master_shard_uds_path,
         tokenizer_name,
         validation_workers,
     } = args;
 
+    tracing_subscriber::fmt().compact().with_ansi(false).init();
+
     if validation_workers == 1 {
         panic!("validation_workers must be > 0");
     }
 
-    let max_waiting_time = Duration::from_secs(max_waiting_time);
-
     // Download and instantiate tokenizer
     // This will only be used to validate payloads
     //
@@ -61,8 +60,6 @@ fn main() -> Result<(), std::io::Error> {
         .build()
         .unwrap()
         .block_on(async {
-            tracing_subscriber::fmt::init();
-
             // Instantiate sharded client from the master unix socket
             let sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                 .await
@@ -82,7 +79,7 @@ fn main() -> Result<(), std::io::Error> {
                 max_concurrent_requests,
                 max_input_length,
                 max_batch_size,
-                max_waiting_time,
+                max_waiting_tokens,
                 sharded_client,
                 tokenizer,
                 validation_workers,
diff --git a/router/src/server.rs b/router/src/server.rs
index 42258313..02c4a497 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1,14 +1,12 @@
-use crate::{
-    Batcher, GenerateParameters, GenerateRequest, GenerateResponse, GeneratedText, Validation,
-};
+use crate::{Batcher, GenerateParameters, GenerateRequest, GeneratedText, Validation};
 use axum::extract::Extension;
-use axum::http::StatusCode;
+use axum::http::{HeaderMap, StatusCode};
+use axum::response::IntoResponse;
 use axum::routing::{get, post};
 use axum::{Json, Router};
 use bloom_inference_client::ShardedClient;
 use std::net::SocketAddr;
 use std::sync::Arc;
-use std::time::Duration;
 use tokenizers::Tokenizer;
 use tokio::signal;
 use tokio::sync::Semaphore;
@@ -59,12 +57,21 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, String
 }
 
 /// Generate method
-#[instrument(skip(state), fields(time, time_per_token))]
+#[instrument(
+    skip(state),
+    fields(
+        total_time,
+        validation_time,
+        queue_time,
+        inference_time,
+        time_per_token
+    )
+)]
 async fn generate(
     state: Extension<ServerState>,
     req: Json<GenerateRequest>,
-) -> Result<Json<GenerateResponse>, (StatusCode, String)> {
-    let start = Instant::now();
+) -> Result<impl IntoResponse, (StatusCode, String)> {
+    let start_time = Instant::now();
     // Limit concurrent requests by acquiring a permit from the semaphore
     let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
         (
@@ -84,19 +91,51 @@ async fn generate(
         .await?;
 
     // Inference
-    let generated_text = state.batcher.infer(input_length, validated_request).await?;
+    let response = state.batcher.infer(input_length, validated_request).await?;
+
+    // Timings
+    let total_time = start_time.elapsed();
+    let validation_time = response.queued - start_time;
+    let queue_time = response.start - response.queued;
+    let inference_time = response.end - response.start;
+    let time_per_token = inference_time / req.parameters.max_new_tokens;
+
+    // Headers
+    let mut headers = HeaderMap::new();
+    headers.insert(
+        "x-total-time",
+        total_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-validation-time",
+        validation_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-queue-time",
+        queue_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-inference-time",
+        inference_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-time-per-token",
+        time_per_token.as_millis().to_string().parse().unwrap(),
+    );
 
     // Tracing metadata
-    tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
-    tracing::Span::current().record(
-        "time_per_token",
-        format!("{:?}", start.elapsed() / req.parameters.max_new_tokens),
-    );
-    tracing::info!("response: {}", generated_text);
+    tracing::Span::current().record("total_time", format!("{:?}", total_time));
+    tracing::Span::current().record("validation_time", format!("{:?}", validation_time));
+    tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
+    tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
+    tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
+    tracing::info!("Output: {}", response.output);
 
     // Send response
-    let response = vec![GeneratedText { generated_text }];
-    Ok(Json(response))
+    let response = vec![GeneratedText {
+        generated_text: response.output,
+    }];
+    Ok((headers, Json(response)))
 }
 
 /// Serving method
@@ -105,14 +144,14 @@ pub async fn run(
     max_concurrent_requests: usize,
     max_input_length: usize,
     max_batch_size: usize,
-    max_waiting_time: Duration,
+    max_waiting_tokens: usize,
     client: ShardedClient,
     tokenizer: Tokenizer,
     validation_workers: usize,
     addr: SocketAddr,
 ) {
     // Create state
-    let batcher = Batcher::new(client, max_batch_size, max_waiting_time);
+    let batcher = Batcher::new(client, max_batch_size, max_waiting_tokens);
     let validation = Validation::new(validation_workers, tokenizer, max_input_length);
     let shared_state = ServerState {
         validation,
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 31b4c49f..11712d0a 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -127,7 +127,10 @@ fn validation_worker(
         if input_length > max_input_length {
             response_tx
-                .send(Err(ValidationError::InputLength(input_length, max_input_length)))
+                .send(Err(ValidationError::InputLength(
+                    input_length,
+                    max_input_length,
+                )))
                 .unwrap_or(());
             continue;
         }
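
Note (not part of the patch): a sketch of how a client could read the new x-* timing headers. The route, payload shape, and the reqwest/serde_json/tokio client-side dependencies are assumptions here, not part of this repository; each header carries a duration in milliseconds, as produced by `as_millis()` above.

    use serde_json::json;

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Assumed endpoint and payload; adjust to the deployment being tested.
        let response = reqwest::Client::new()
            .post("http://localhost:3000/generate")
            .json(&json!({
                "inputs": "Hello, I am",
                "parameters": { "max_new_tokens": 20 }
            }))
            .send()
            .await?;

        // Headers set in router/src/server.rs; values are milliseconds.
        for name in [
            "x-total-time",
            "x-validation-time",
            "x-queue-time",
            "x-inference-time",
            "x-time-per-token",
        ] {
            if let Some(value) = response.headers().get(name) {
                println!("{name}: {} ms", value.to_str()?);
            }
        }

        println!("body: {}", response.text().await?);
        Ok(())
    }
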