diff --git a/launcher/src/main.rs b/launcher/src/main.rs index fcac736..7450b3f 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -39,8 +39,12 @@ struct Args { max_input_length: usize, #[clap(default_value = "1512", long, env)] max_total_tokens: usize, - #[clap(default_value = "32", long, env)] - max_batch_size: usize, + #[clap(long, env)] + max_batch_size: Option, + #[clap(default_value = "1.2", long, env)] + waiting_served_ratio: f32, + #[clap(default_value = "32000", long, env)] + max_batch_total_tokens: u32, #[clap(default_value = "20", long, env)] max_waiting_tokens: usize, #[clap(default_value = "3000", long, short, env)] @@ -93,6 +97,8 @@ fn main() -> ExitCode { max_input_length, max_total_tokens, max_batch_size, + max_batch_total_tokens, + waiting_served_ratio, max_waiting_tokens, port, shard_uds_path, @@ -380,8 +386,8 @@ fn main() -> ExitCode { max_input_length.to_string(), "--max-total-tokens".to_string(), max_total_tokens.to_string(), - "--max-batch-size".to_string(), - max_batch_size.to_string(), + "--waiting-served-ratio".to_string(), + waiting_served_ratio.to_string(), "--max-waiting-tokens".to_string(), max_waiting_tokens.to_string(), "--port".to_string(), @@ -392,6 +398,15 @@ fn main() -> ExitCode { model_id, ]; + // Deprecate max_batch_size + if let Some(max_batch_size) = max_batch_size { + argv.push("--max-batch-size".to_string()); + argv.push(max_batch_size.to_string()) + } else { + argv.push("--max-batch-total-tokens".to_string()); + argv.push(max_batch_total_tokens.to_string()) + } + // Model optional revision if let Some(ref revision) = revision { argv.push("--revision".to_string()); diff --git a/proto/generate.proto b/proto/generate.proto index 2bf7385..ad47409 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -9,6 +9,8 @@ service TextGenerationService { rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {} /// Empties batch cache rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse); + /// Remove requests from a cached batch + rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse); /// Prefill batch and decode first token rpc Prefill (PrefillRequest) returns (PrefillResponse); /// Decode token for a list of prefilled batches @@ -89,6 +91,8 @@ message Batch { repeated Request requests = 2; /// Batch size (==len(requests)) uint32 size = 3; + /// Maximum number of tokens this batch will grow to + uint32 max_tokens = 4; } enum FinishReason { @@ -134,6 +138,19 @@ message Generation { GeneratedText generated_text = 7; } +message FilterBatchRequest { + /// Batch ID + uint64 batch_id = 1; + /// Requests to keep + repeated Request keep_requests = 2; +} + +message FilterBatchResponse { + /// Filtered Batch (cached) + Batch batch = 1; +} + + message PrefillRequest { /// Batch Batch batch = 1; diff --git a/router/client/src/client.rs b/router/client/src/client.rs index cccd500..7cadf43 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -70,6 +70,22 @@ impl Client { Ok(()) } + /// Filter a cached batch + #[instrument(skip(self))] + pub async fn filter_batch( + &mut self, + batch_id: u64, + keep_requests: Vec, + ) -> Result> { + let request = tonic::Request::new(FilterBatchRequest { + batch_id, + keep_requests, + }) + .inject_context(); + let filtered_batch = self.stub.filter_batch(request).await?.into_inner(); + Ok(filtered_batch.batch) + } + /// Generate one token for each request in the given batch /// /// Returns Generation for each request in batch diff --git 
a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index 903c7a6..469d75f 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -1,6 +1,6 @@ /// Multi shard Client use crate::Result; -use crate::{Batch, Client, Generation, ShardInfo}; +use crate::{Batch, Client, Generation, Request, ShardInfo}; use futures::future::join_all; use tonic::transport::Uri; use tracing::instrument; @@ -59,6 +59,22 @@ impl ShardedClient { join_all(futures).await.into_iter().collect() } + /// Filter a cached batch + #[instrument(skip(self))] + pub async fn filter_batch( + &mut self, + batch_id: u64, + keep_requests: Vec, + ) -> Result> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.filter_batch(batch_id, keep_requests.clone()))) + .collect(); + // all shards return the same message + join_all(futures).await.pop().unwrap() + } + /// Generate one token for each request in the given batch /// /// Returns Generation for each request in batch diff --git a/router/src/infer.rs b/router/src/infer.rs index 484720a..8b44ec8 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -39,12 +39,14 @@ impl Infer { pub(crate) fn new( client: ShardedClient, validation: Validation, - max_batch_size: usize, + waiting_served_ratio: f32, + max_batch_total_tokens: u32, max_waiting_tokens: usize, max_concurrent_requests: usize, + requires_padding: bool, ) -> Self { // Infer shared state - let queue = Queue::new(); + let queue = Queue::new(requires_padding); let shared = Arc::new(Shared { batching_task: Notify::new(), }); @@ -52,7 +54,8 @@ impl Infer { // Spawn batching background task that contains all the inference logic tokio::spawn(batching_task( client, - max_batch_size, + waiting_served_ratio, + max_batch_total_tokens, max_waiting_tokens, queue.clone(), shared.clone(), @@ -232,18 +235,12 @@ impl Infer { /// Batches requests and sends them to the inference server async fn batching_task( mut client: ShardedClient, - max_batch_size: usize, + waiting_served_ratio: f32, + max_batch_total_tokens: u32, max_waiting_tokens: usize, queue: Queue, shared: Arc, ) { - // Minimum batch size after which we try to add more requests - let limit_min_batch_size = if max_batch_size > 1 { - (max_batch_size / 2) as u32 - } else { - 0 - }; - // Infinite loop loop { // Wait for a notification from the Infer struct @@ -252,7 +249,9 @@ async fn batching_task( // Get the next batch from the queue // This batch might be smaller than the maximum batch size if there are not enough requests // waiting in the queue - while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await { + while let Some((mut entries, batch, span)) = + queue.next_batch(None, max_batch_total_tokens).await + { let mut cached_batch = prefill(&mut client, batch, &mut entries) .instrument(span) .await; @@ -263,48 +262,57 @@ async fn batching_task( while let Some(batch) = cached_batch { // Get current batch info let batch_size = batch.size; + let batch_max_tokens = batch.max_tokens; let mut batches = vec![batch]; metrics::gauge!("tgi_batch_current_size", batch_size as f64); + metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64); - // If the current batch is too small, we try to add more requests to it - if batch_size <= limit_min_batch_size { - let min_size = match waiting_tokens { - // If we didn't onboard any new requests since >= max_waiting_tokens, we try - // to add a new batch even though its size might be small - _ if 
waiting_tokens >= max_waiting_tokens => None, - // Minimum size criteria - _ => Some(limit_min_batch_size as usize), - }; + let min_size = if waiting_tokens >= max_waiting_tokens { + // If we didn't onboard any new requests since >= max_waiting_tokens, we try + // to add a new batch even though its size might be small + None + } else { + // Minimum batch size + Some((batch_size as f32 * waiting_served_ratio).floor() as usize) + }; - // Try to get a new batch - if let Some((mut new_entries, new_batch, span)) = queue - .next_batch(min_size, max_batch_size - batch_size as usize) - .await - { - entries.iter_mut().for_each(|(_, entry)| { - // Create a new span to add the info that this entry is waiting - // because a new batch is being computed - let entry_waiting_span = info_span!(parent: &entry.span, "waiting"); - // Add relationships - span.follows_from(&entry_waiting_span); - entry_waiting_span.follows_from(&span); - // Update entry - entry.temp_span = Some(entry_waiting_span); - }); + let token_budget = max_batch_total_tokens - batch_max_tokens; - // Generate one token for this new batch to have the attention past in cache - let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries) - .instrument(span) - .await; - // Reset waiting counter - waiting_tokens = 1; - // Extend current batch with the new batch - if let Some(new_cached_batch) = new_cached_batch { - entries.extend(new_entries); - batches.push(new_cached_batch); - } + // Try to get a new batch + if let Some((mut new_entries, new_batch, span)) = + queue.next_batch(min_size, token_budget).await + { + // Tracking metrics + if min_size.is_some() { + metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure"); + } else { + metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded"); + } + + entries.iter_mut().for_each(|(_, entry)| { + // Create a new span to add the info that this entry is waiting + // because a new batch is being computed + let entry_waiting_span = info_span!(parent: &entry.span, "waiting"); + // Add relationships + span.follows_from(&entry_waiting_span); + entry_waiting_span.follows_from(&span); + // Update entry + entry.temp_span = Some(entry_waiting_span); + }); + + // Generate one token for this new batch to have the attention past in cache + let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries) + .instrument(span) + .await; + // Reset waiting counter + waiting_tokens = 1; + // Extend current batch with the new batch + if let Some(new_cached_batch) = new_cached_batch { + entries.extend(new_entries); + batches.push(new_cached_batch); } } + // Create span for this batch to add context to inference calls let next_batch_size = entries.len(); let next_batch_span = @@ -325,6 +333,7 @@ async fn batching_task( waiting_tokens += 1; } metrics::gauge!("tgi_batch_current_size", 0.0); + metrics::gauge!("tgi_batch_current_max_tokens", 0.0); } } } @@ -341,22 +350,11 @@ async fn prefill( match client.prefill(batch).await { Ok((generations, next_batch)) => { + // Send generated tokens and filter stopped entries filter_send_generations(generations, entries); // Filter next batch and remove requests that were stopped - let next_batch = match next_batch { - None => None, - Some(batch) => { - let id = batch.id; - let next_batch = filter_batch(batch, entries); - // Next batch is now empty - // Clear it from the Python shards cache - if next_batch.is_none() { - let _ = client.clear_cache(Some(id)).await; - } - next_batch - } - }; + let next_batch = filter_batch(client, 
next_batch, entries).await; metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill"); metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill"); @@ -384,22 +382,11 @@ async fn decode( match client.decode(batches).await { Ok((generations, next_batch)) => { + // Send generated tokens and filter stopped entries filter_send_generations(generations, entries); // Filter next batch and remove requests that were stopped - let next_batch = match next_batch { - None => None, - Some(batch) => { - let id = batch.id; - let next_batch = filter_batch(batch, entries); - // Next batch is now empty - // Clear it from the Python shards cache - if next_batch.is_none() { - let _ = client.clear_cache(Some(id)).await; - } - next_batch - } - }; + let next_batch = filter_batch(client, next_batch, entries).await; metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode"); metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode"); @@ -419,14 +406,35 @@ async fn decode( /// Filter a `batch` and remove all requests not present in `entries` #[instrument(skip_all)] -fn filter_batch(mut batch: Batch, entries: &IntMap) -> Option { - batch.requests.retain(|r| entries.contains_key(&r.id)); - let size = batch.requests.len(); - if size == 0 { - return None; +async fn filter_batch( + client: &mut ShardedClient, + next_batch: Option, + entries: &IntMap, +) -> Option { + let mut batch = next_batch?; + + // No need to filter + if batch.size as usize == entries.len() { + return Some(batch); + } + + let id = batch.id; + + // Retain only requests that are still in entries + batch.requests.retain(|r| entries.contains_key(&r.id)); + + if batch.requests.is_empty() { + // All requests have been filtered out + // Next batch is now empty + // Clear it from the Python shards cache + // We unwrap here as we need to panic since we cannot recover if this method fails + client.clear_cache(Some(id)).await.unwrap(); + None + } else { + // Filter Python shard cache + // We unwrap here as we need to panic since we cannot recover if this method fails + client.filter_batch(id, batch.requests).await.unwrap() } - batch.size = size as u32; - Some(batch) } /// Send one or multiple `InferStreamResponse` to Infer for all `entries` diff --git a/router/src/main.rs b/router/src/main.rs index 712071b..28db607 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -31,8 +31,12 @@ struct Args { max_input_length: usize, #[clap(default_value = "1512", long, env)] max_total_tokens: usize, - #[clap(default_value = "32", long, env)] - max_batch_size: usize, + #[clap(long, env)] + max_batch_size: Option, + #[clap(default_value = "1.2", long, env)] + waiting_served_ratio: f32, + #[clap(default_value = "32000", long, env)] + max_batch_total_tokens: u32, #[clap(default_value = "20", long, env)] max_waiting_tokens: usize, #[clap(default_value = "3000", long, short, env)] @@ -64,6 +68,8 @@ fn main() -> Result<(), std::io::Error> { max_input_length, max_total_tokens, max_batch_size, + waiting_served_ratio, + mut max_batch_total_tokens, max_waiting_tokens, port, master_shard_uds_path, @@ -119,6 +125,12 @@ fn main() -> Result<(), std::io::Error> { .block_on(async { init_logging(otlp_endpoint, json_output); + if let Some(max_batch_size) = max_batch_size{ + tracing::warn!("`max-batch-size` is deprecated. 
Use `max-batch-total-tokens` instead"); + max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32; + tracing::warn!("Overriding `max-batch-total-tokens` value with `max-batch-size` * `max-total-tokens` = {max_batch_total_tokens}"); + } + if tokenizer.is_none() { tracing::warn!( "Could not find a fast tokenizer implementation for {tokenizer_name}" @@ -174,7 +186,8 @@ fn main() -> Result<(), std::io::Error> { max_stop_sequences, max_input_length, max_total_tokens, - max_batch_size, + waiting_served_ratio, + max_batch_total_tokens, max_waiting_tokens, sharded_client, tokenizer, diff --git a/router/src/queue.rs b/router/src/queue.rs index 43651ff..d970ebf 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -2,7 +2,6 @@ use crate::infer::InferError; use crate::infer::InferStreamResponse; use crate::validation::ValidGenerateRequest; use nohash_hasher::{BuildNoHashHasher, IntMap}; -use std::cmp::min; use std::collections::VecDeque; use text_generation_client::{Batch, Request}; use tokio::sync::oneshot; @@ -34,12 +33,12 @@ pub(crate) struct Queue { } impl Queue { - pub(crate) fn new() -> Self { + pub(crate) fn new(requires_padding: bool) -> Self { // Create channel let (queue_sender, queue_receiver) = flume::unbounded(); // Launch background queue task - tokio::spawn(queue_task(queue_receiver)); + tokio::spawn(queue_task(requires_padding, queue_receiver)); Self { queue_sender } } @@ -59,7 +58,7 @@ impl Queue { pub(crate) async fn next_batch( &self, min_size: Option, - max_size: usize, + token_budget: u32, ) -> Option { // Create response channel let (response_sender, response_receiver) = oneshot::channel(); @@ -68,7 +67,7 @@ impl Queue { self.queue_sender .send(QueueCommand::NextBatch { min_size, - max_size, + token_budget, response_sender, span: Span::current(), }) @@ -80,20 +79,24 @@ impl Queue { } // Background task responsible of the queue state -async fn queue_task(receiver: flume::Receiver) { - let mut state = State::new(); +async fn queue_task(requires_padding: bool, receiver: flume::Receiver) { + let mut state = State::new(requires_padding); while let Ok(cmd) = receiver.recv_async().await { match cmd { - QueueCommand::Append(entry, span) => span.in_scope(|| state.append(entry)), + QueueCommand::Append(entry, span) => { + span.in_scope(|| state.append(entry)); + metrics::increment_gauge!("tgi_queue_size", 1.0); + } QueueCommand::NextBatch { min_size, - max_size, + token_budget, response_sender, span, } => span.in_scope(|| { - let next_batch = state.next_batch(min_size, max_size); + let next_batch = state.next_batch(min_size, token_budget); response_sender.send(next_batch).unwrap_or(()); + metrics::gauge!("tgi_queue_size", state.entries.len() as f64); }), } } @@ -110,14 +113,18 @@ struct State { /// Id of the next batch next_batch_id: u64, + + /// Whether the model is using padding + requires_padding: bool, } impl State { - fn new() -> Self { + fn new(requires_padding: bool) -> Self { Self { entries: VecDeque::with_capacity(128), next_id: 0, next_batch_id: 0, + requires_padding, } } @@ -130,11 +137,10 @@ impl State { // Push entry in the queue self.entries.push_back((self.next_id, entry)); self.next_id += 1; - metrics::increment_gauge!("tgi_queue_size", 1.0); } // Get the next batch - fn next_batch(&mut self, min_size: Option, max_size: usize) -> Option { + fn next_batch(&mut self, min_size: Option, token_budget: u32) -> Option { if self.entries.is_empty() { return None; } @@ -146,17 +152,19 @@ impl State { } } - let max_batch_size = min(self.entries.len(), max_size); - 
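Before the queue changes that follow, it helps to see the decision the batching task in router/src/infer.rs now makes before pulling extra work. This is a minimal Rust sketch, not the PR's code: the free-standing function and plain integer parameters are illustrative, while the formulas mirror the diff above.

    /// Sketch of the admission decision in `batching_task` (illustrative only).
    ///
    /// Returns `(min_size, token_budget)`:
    /// * `min_size`     - smallest queued batch worth paying a prefill for, or `None`
    ///                    once the running batch has gone `max_waiting_tokens` decode
    ///                    steps without onboarding new requests.
    /// * `token_budget` - prefill + decode tokens a new batch may add on top of
    ///                    what the running batch has already reserved.
    fn next_batch_params(
        batch_size: u32,             // size of the batch currently decoding
        batch_max_tokens: u32,       // `Batch::max_tokens` of that batch
        waiting_tokens: usize,       // decode steps since the last concatenation
        max_waiting_tokens: usize,   // --max-waiting-tokens
        waiting_served_ratio: f32,   // --waiting-served-ratio (default 1.2)
        max_batch_total_tokens: u32, // --max-batch-total-tokens (default 32000)
    ) -> (Option<usize>, u32) {
        let min_size = if waiting_tokens >= max_waiting_tokens {
            // The queue has waited long enough: accept a batch of any size.
            None
        } else {
            // Only interrupt decoding for at least `waiting_served_ratio` times
            // the number of requests already being served.
            Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
        };
        let token_budget = max_batch_total_tokens - batch_max_tokens;
        (min_size, token_budget)
    }

With the defaults, a running batch of 10 requests that reserved 8000 tokens yields min_size = Some(12) and a 24000-token budget for the batch pulled from the queue.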
// Create span for this batch to add context to inference calls let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); next_batch_span.follows_from(&Span::current()); - let mut batch_requests = Vec::with_capacity(max_batch_size); + let mut batch_requests = Vec::with_capacity(self.entries.len()); let mut batch_entries = - IntMap::with_capacity_and_hasher(max_batch_size, BuildNoHashHasher::default()); + IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default()); - // Iterate on buffer + let mut max_input_length = 0; + let mut prefill_tokens: u32 = 0; + let mut decode_tokens: u32 = 0; + + // Pop entries starting from the front of the queue while let Some((id, mut entry)) = self.entries.pop_front() { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) @@ -165,6 +173,24 @@ impl State { continue; } + if self.requires_padding { + // We pad to max input length in the Python shards + // We need to take these padding tokens into the equation + max_input_length = max_input_length.max(entry.request.input_length); + prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length + } else { + prefill_tokens += entry.request.input_length; + } + + decode_tokens += entry.request.stopping_parameters.max_new_tokens; + + if (prefill_tokens + decode_tokens) > token_budget { + // Entry is over budget + // Add it back to the front + self.entries.push_front((id, entry)); + break; + } + // Create a new span to link the batch back to this entry let entry_batch_span = info_span!(parent: &entry.span, "infer"); // Add relationships @@ -184,21 +210,29 @@ impl State { entry.batch_time = Some(Instant::now()); // Insert in batch_entries IntMap batch_entries.insert(id, entry); - - if batch_requests.len() == max_batch_size { - // We have enough requests in the batch - break; - } } - metrics::gauge!("tgi_queue_size", self.entries.len() as f64); - - // Maybe all entries were dropped because their channel were closed + // Empty batch if batch_requests.is_empty() { return None; } - // Final batch size once we dropped entries + // Check if our batch is big enough + if let Some(min_size) = min_size { + // Batch is too small + if batch_requests.len() < min_size { + // Add back entries to the queue in the correct order + for r in batch_requests.into_iter().rev() { + let id = r.id; + let entry = batch_entries.remove(&id).unwrap(); + self.entries.push_front((id, entry)); + } + + return None; + } + } + + // Final batch size let size = batch_requests.len() as u32; next_batch_span.record("batch_size", size); @@ -206,11 +240,13 @@ impl State { id: self.next_batch_id, requests: batch_requests, size, + max_tokens: (prefill_tokens + decode_tokens), }; // Increment batch id self.next_batch_id += 1; metrics::histogram!("tgi_batch_next_size", batch.size as f64); + Some((batch_entries, batch, next_batch_span)) } } @@ -222,7 +258,7 @@ enum QueueCommand { Append(Entry, Span), NextBatch { min_size: Option, - max_size: usize, + token_budget: u32, response_sender: oneshot::Sender>, span: Span, }, @@ -243,6 +279,7 @@ mod tests { let entry = Entry { request: ValidGenerateRequest { inputs: "".to_string(), + input_length: 0, truncate: 0, parameters: NextTokenChooserParameters { temperature: 0.0, @@ -256,7 +293,7 @@ mod tests { }, stopping_parameters: StoppingCriteriaParameters { ignore_eos_token: false, - max_new_tokens: 0, + max_new_tokens: 1, stop_sequences: vec![], }, }, @@ -271,7 +308,7 @@ mod tests { #[test] fn 
test_append() { - let mut state = State::new(); + let mut state = State::new(false); let (entry, _guard) = default_entry(); assert_eq!(state.next_id, 0); @@ -287,7 +324,7 @@ mod tests { #[test] fn test_next_batch_empty() { - let mut state = State::new(); + let mut state = State::new(false); assert!(state.next_batch(None, 1).is_none()); assert!(state.next_batch(Some(1), 1).is_none()); @@ -295,7 +332,7 @@ mod tests { #[test] fn test_next_batch_min_size() { - let mut state = State::new(); + let mut state = State::new(false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -326,8 +363,8 @@ mod tests { } #[test] - fn test_next_batch_max_size() { - let mut state = State::new(); + fn test_next_batch_token_budget() { + let mut state = State::new(false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -360,14 +397,14 @@ mod tests { #[tokio::test] async fn test_queue_append() { - let queue = Queue::new(); + let queue = Queue::new(false); let (entry, _guard) = default_entry(); queue.append(entry); } #[tokio::test] async fn test_queue_next_batch_empty() { - let queue = Queue::new(); + let queue = Queue::new(false); assert!(queue.next_batch(None, 1).await.is_none()); assert!(queue.next_batch(Some(1), 1).await.is_none()); @@ -375,7 +412,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_min_size() { - let queue = Queue::new(); + let queue = Queue::new(false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -397,8 +434,8 @@ mod tests { } #[tokio::test] - async fn test_queue_next_batch_max_size() { - let queue = Queue::new(); + async fn test_queue_next_batch_token_budget() { + let queue = Queue::new(false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -423,7 +460,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_dropped_receiver() { - let queue = Queue::new(); + let queue = Queue::new(false); let (entry, _) = default_entry(); queue.append(entry); diff --git a/router/src/server.rs b/router/src/server.rs index 8891443..d1f7ae1 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -511,7 +511,8 @@ pub async fn run( max_stop_sequences: usize, max_input_length: usize, max_total_tokens: usize, - max_batch_size: usize, + waiting_served_ratio: f32, + max_batch_total_tokens: u32, max_waiting_tokens: usize, client: ShardedClient, tokenizer: Option, @@ -571,9 +572,11 @@ pub async fn run( let infer = Infer::new( client, validation, - max_batch_size, + waiting_served_ratio, + max_batch_total_tokens, max_waiting_tokens, max_concurrent_requests, + shard_info.requires_padding, ); // Duration buckets @@ -604,7 +607,7 @@ pub async fn run( .collect(); // Batch size buckets let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size")); - let batch_size_buckets: Vec = (0..max_batch_size).map(|x| (x + 1) as f64).collect(); + let batch_size_buckets: Vec = (0..1024).map(|x| (x + 1) as f64).collect(); // Prometheus handler let builder = PrometheusBuilder::new() diff --git a/router/src/validation.rs b/router/src/validation.rs index 5f1b89b..7f2f76e 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -69,7 +69,7 @@ impl Validation { inputs: String, truncate: Option, max_new_tokens: u32, - ) -> Result { + ) -> Result<(String, usize), ValidationError> { // If we have a fast tokenizer if let Some(sender) = &self.sender { // Create 
response channel @@ -105,25 +105,24 @@ impl Validation { } metrics::histogram!("tgi_request_input_length", input_length as f64); - Ok(inputs) + Ok((inputs, input_length)) } // Return inputs without validation else { // In this case, we don't know the real length in tokens of the inputs // However, the inputs will be truncated by the python servers // We make sure that truncate + max_new_tokens <= self.max_total_tokens + let input_length = truncate.unwrap_or(self.max_input_length); // Validate MaxNewTokens - if (truncate.unwrap_or(self.max_input_length) as u32 + max_new_tokens) - > self.max_total_tokens as u32 - { + if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 { return Err(ValidationError::MaxNewTokens( self.max_total_tokens - self.max_input_length, max_new_tokens, )); } - Ok(inputs) + Ok((inputs, input_length)) } } @@ -238,7 +237,7 @@ impl Validation { .unwrap_or(Ok(None))?; // Validate inputs - let inputs = self + let (inputs, input_length) = self .validate_input(request.inputs, truncate, max_new_tokens) .await?; @@ -262,6 +261,7 @@ impl Validation { Ok(ValidGenerateRequest { inputs, + input_length: input_length as u32, truncate: truncate.unwrap_or(self.max_input_length) as u32, parameters, stopping_parameters, @@ -333,6 +333,7 @@ type TokenizerRequest = ( #[derive(Debug)] pub(crate) struct ValidGenerateRequest { pub inputs: String, + pub input_length: u32, pub truncate: u32, pub parameters: NextTokenChooserParameters, pub stopping_parameters: StoppingCriteriaParameters, diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 47d701e..f0adab9 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -181,9 +181,7 @@ def test_causal_lm_generate_token_completion_multi( next_batch = next_batch.filter([next_batch.requests[0]]) for _ in range( - stopping_criterias[0].max_new_tokens - - stopping_criterias[1].max_new_tokens - - 1 + stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1 ): generations, next_batch = default_bloom.generate_token(next_batch) assert len(generations) == len(next_batch) diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 03d3ef9..f1f13e4 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -174,14 +174,14 @@ def test_causal_lm_generate_token_completion_multi( == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens ) # Copy stopping_criterias before filtering - stopping_criterias = default_multi_requests_causal_lm_batch.stopping_criterias.copy() + stopping_criterias = ( + default_multi_requests_causal_lm_batch.stopping_criterias.copy() + ) next_batch = next_batch.filter([next_batch.requests[0]]) for _ in range( - stopping_criterias[0].max_new_tokens - - stopping_criterias[1].max_new_tokens - - 1 + stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1 ): generations, next_batch = default_causal_lm.generate_token(next_batch) assert len(generations) == len(next_batch) diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 1db5abc..336c982 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -46,6 +46,9 @@ class CausalLMBatch(Batch): max_input_length: int padding_right_offset: int + # Maximum number of tokens this batch will grow to + max_tokens: int + # Past metadata keys_head_dim_last: 
bool = True @@ -54,6 +57,7 @@ class CausalLMBatch(Batch): id=self.batch_id, requests=self.requests, size=len(self), + max_tokens=self.max_tokens, ) @classmethod @@ -73,6 +77,7 @@ class CausalLMBatch(Batch): # Parse batch max_truncation = 0 padding_right_offset = 0 + max_decode_tokens = 0 for i, r in enumerate(pb.requests): requests_idx_mapping[r.id] = i inputs.append(r.inputs) @@ -84,6 +89,7 @@ class CausalLMBatch(Batch): ) stopping_criterias.append(stopping_criteria) max_truncation = max(max_truncation, r.truncate) + max_decode_tokens += stopping_criteria.max_new_tokens padding_right_offset = max( padding_right_offset, stopping_criteria.max_new_tokens ) @@ -112,6 +118,8 @@ class CausalLMBatch(Batch): position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) + max_tokens = len(inputs) * max_input_length + max_decode_tokens + return cls( batch_id=pb.id, requests=pb.requests, @@ -128,6 +136,7 @@ class CausalLMBatch(Batch): stopping_criterias=stopping_criterias, max_input_length=max_input_length.item(), padding_right_offset=padding_right_offset, + max_tokens=max_tokens, ) @tracer.start_as_current_span("filter") @@ -150,6 +159,7 @@ class CausalLMBatch(Batch): next_token_choosers = [] stopping_criterias = [] + total_remaining_decode_tokens = 0 new_padding_right_offset = 0 for i, r in enumerate(requests): @@ -168,19 +178,23 @@ class CausalLMBatch(Batch): next_token_choosers.append(self.next_token_choosers[idx]) stopping_criteria = self.stopping_criterias[idx] stopping_criterias.append(stopping_criteria) - - new_padding_right_offset = max( - new_padding_right_offset, + remaining_decode_tokens = ( stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) + total_remaining_decode_tokens += remaining_decode_tokens + new_padding_right_offset = max( + new_padding_right_offset, remaining_decode_tokens + ) # Apply indices to input_ids, attention mask, past key values and other items that need to be cached input_ids = self.input_ids[keep_indices] position_ids = self.position_ids[keep_indices] self.attention_mask = self.attention_mask[ keep_indices, - -(self.padding_right_offset + max_input_length): - (self.attention_mask.shape[1] - self.padding_right_offset) + new_padding_right_offset, + -(self.padding_right_offset + max_input_length) : ( + self.attention_mask.shape[1] - self.padding_right_offset + ) + + new_padding_right_offset, ] # Ensure that past_key_values tensors can be updated in-place @@ -203,6 +217,8 @@ class CausalLMBatch(Batch): layer[1] = past_values[keep_indices, :, -past_kv_length:, :] del past_values + max_tokens = len(requests) * max_input_length + total_remaining_decode_tokens + self.requests = requests self.requests_idx_mapping = requests_idx_mapping self.input_ids = input_ids @@ -215,6 +231,7 @@ class CausalLMBatch(Batch): self.stopping_criterias = stopping_criterias self.max_input_length = max_input_length self.padding_right_offset = new_padding_right_offset + self.max_tokens = max_tokens return self @@ -239,6 +256,7 @@ class CausalLMBatch(Batch): all_input_ids = [] next_token_choosers = [] stopping_criterias = [] + max_tokens = 0 # Batch tensors input_ids = None @@ -314,7 +332,8 @@ class CausalLMBatch(Batch): # And ensure that we can update tensors in-place if type(batch.past_key_values[0]) == tuple: batch.past_key_values = [ - [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values + [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] + for layer in 
batch.past_key_values ] elif batch.past_key_values[0][0].shape == 3: for layer in batch.past_key_values: @@ -322,6 +341,10 @@ class CausalLMBatch(Batch): layer[k] = t.view(len(batch), -1, *t.shape[-2:]) start_index = end_index + # Add eventual padding tokens that were added while concatenating + max_tokens += batch.max_tokens + ( + max_input_length - batch.max_input_length + ) * len(batch) first_past_kvs = batches[0].past_key_values _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape @@ -371,7 +394,9 @@ class CausalLMBatch(Batch): start_index = end_index - padded_past_values = first_past_kvs[j][1].new_zeros(padded_past_values_shape) + padded_past_values = first_past_kvs[j][1].new_zeros( + padded_past_values_shape + ) start_index = 0 for batch in batches: past_values = batch.past_key_values[j][1] @@ -387,6 +412,7 @@ class CausalLMBatch(Batch): ] = past_values[:, :, -past_seq_len:, :] del past_values + # Update values start_index = end_index past_key_values.append([padded_past_keys, padded_past_values]) @@ -408,6 +434,7 @@ class CausalLMBatch(Batch): max_input_length=max_input_length, padding_right_offset=padding_right_offset, keys_head_dim_last=batches[0].keys_head_dim_last, + max_tokens=max_tokens, ) def __len__(self): diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 9cd9ed8..61ccca8 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -56,9 +56,15 @@ class FlashCausalLMBatch(Batch): # Constant shared tensor, ref here just so that it's accessible in concatentate() past_pad: Optional[torch.Tensor] + # Maximum number of tokens this batch will grow to + max_tokens: int + def to_pb(self) -> generate_pb2.Batch: return generate_pb2.Batch( - id=self.batch_id, requests=self.requests, size=len(self) + id=self.batch_id, + requests=self.requests, + size=len(self), + max_tokens=self.max_tokens, ) @classmethod @@ -86,6 +92,8 @@ class FlashCausalLMBatch(Batch): # Cumulative length cumulative_length = 0 + max_tokens = 0 + # Parse batch for i, r in enumerate(pb.requests): # request id -> idx in list mapping @@ -115,16 +123,20 @@ class FlashCausalLMBatch(Batch): cu_seqlens.append(cumulative_length + input_length) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) + stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer ) + max_new_tokens = stopping_criteria.max_new_tokens stopping_criterias.append(stopping_criteria) + all_input_ids_tensor.append( F.pad(tokenized_input, (0, stopping_criteria.max_new_tokens)) ) # Update cumulative_length += input_length + max_tokens += input_length + max_new_tokens return cls( batch_id=pb.id, @@ -143,6 +155,7 @@ class FlashCausalLMBatch(Batch): next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, past_pad=None, + max_tokens=max_tokens, ) @tracer.start_as_current_span("filter") @@ -177,6 +190,8 @@ class FlashCausalLMBatch(Batch): next_token_choosers = [] stopping_criterias = [] + max_tokens = 0 + for i, r in enumerate(requests): idx = self.requests_idx_mapping[r.id] requests_idx_mapping[r.id] = i @@ -203,9 +218,14 @@ class FlashCausalLMBatch(Batch): token_offsets.append(self.token_offsets[idx]) next_token_choosers.append(self.next_token_choosers[idx]) - stopping_criterias.append(self.stopping_criterias[idx]) + + stopping_criteria = self.stopping_criterias[idx] + stopping_criterias.append(stopping_criteria) 
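The max_tokens bookkeeping added to the Python batches mirrors the estimate the router makes when it builds a batch. Below is a hedged Rust sketch of that shared accounting, following the requires_padding branch added to router/src/queue.rs; QueuedRequest is an invented stand-in for the validated request stored in each queue entry.

    struct QueuedRequest {
        input_length: u32,   // prompt tokens after truncation
        max_new_tokens: u32, // decode budget requested by the client
    }

    /// Upper bound on the tokens a batch of `requests` can grow to.
    fn batch_max_tokens(requests: &[QueuedRequest], requires_padding: bool) -> u32 {
        let mut max_input_length: u32 = 0;
        let mut prefill_tokens: u32 = 0;
        let mut decode_tokens: u32 = 0;
        for (i, r) in requests.iter().enumerate() {
            if requires_padding {
                // Padded models (causal LM, seq2seq) pad every prompt to the
                // longest one seen so far, so padding counts against the budget.
                max_input_length = max_input_length.max(r.input_length);
                prefill_tokens = (i as u32 + 1) * max_input_length;
            } else {
                // Flash-attention models pack sequences without padding.
                prefill_tokens += r.input_length;
            }
            decode_tokens += r.max_new_tokens;
        }
        prefill_tokens + decode_tokens
    }

The shards report the same quantity back through Batch.max_tokens (for example len(inputs) * max_input_length + max_decode_tokens in CausalLMBatch.from_pb), which keeps the router's running total and the shard's own estimate consistent.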
cumulative_length += request_input_length + max_tokens += request_input_length + ( + stopping_criteria.max_new_tokens - stopping_criteria.current_tokens + ) if single_request: # Preallocate tensor for bs = 1 case @@ -241,6 +261,7 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor=all_input_ids_tensor, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, + max_tokens=max_tokens, ) @classmethod @@ -269,6 +290,7 @@ class FlashCausalLMBatch(Batch): # Cumulative length cumulative_batch_size = 0 cumulative_length = 0 + max_tokens = 0 for i, batch in enumerate(batches): requests.extend(batch.requests) @@ -310,6 +332,7 @@ class FlashCausalLMBatch(Batch): # Update cumulative_length += batch.cu_seqlens[-1] cumulative_batch_size += len(batch) + max_tokens += batch.max_tokens return FlashCausalLMBatch( batch_id=batches[0].batch_id, @@ -328,6 +351,7 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor=all_input_ids_tensor, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, + max_tokens=max_tokens, ) def __len__(self): diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 753e86e..c1ec156 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -101,6 +101,7 @@ class GalacticaCausalLMBatch(CausalLMBatch): # Parse batch max_truncation = 0 padding_right_offset = 0 + max_decode_tokens = 0 for i, r in enumerate(pb.requests): requests_idx_mapping[r.id] = i # Add escape_custom_split_sequence to the CausalLMBatch logic @@ -113,6 +114,7 @@ class GalacticaCausalLMBatch(CausalLMBatch): ) stopping_criterias.append(stopping_criteria) max_truncation = max(max_truncation, r.truncate) + max_decode_tokens += stopping_criteria.max_new_tokens padding_right_offset = max( padding_right_offset, stopping_criteria.max_new_tokens ) @@ -141,6 +143,8 @@ class GalacticaCausalLMBatch(CausalLMBatch): position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) + max_tokens = len(inputs) * max_input_length + max_decode_tokens + return cls( batch_id=pb.id, requests=pb.requests, @@ -157,6 +161,7 @@ class GalacticaCausalLMBatch(CausalLMBatch): stopping_criterias=stopping_criterias, max_input_length=max_input_length.item(), padding_right_offset=padding_right_offset, + max_tokens=max_tokens, ) diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 2252fcf..0cb2076 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -54,10 +54,16 @@ class Seq2SeqLMBatch(Batch): max_decoder_input_length: int padding_right_offset: int + # Maximum number of tokens this batch will grow to + max_tokens: int + def to_pb(self) -> generate_pb2.Batch: """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf""" return generate_pb2.Batch( - id=self.batch_id, requests=self.requests, size=len(self) + id=self.batch_id, + requests=self.requests, + size=len(self), + max_tokens=self.max_tokens, ) @classmethod @@ -80,6 +86,7 @@ class Seq2SeqLMBatch(Batch): # Parse batch max_truncation = 0 padding_right_offset = 0 + max_decode_tokens = 0 for i, r in enumerate(pb.requests): inputs.append(r.inputs) requests_idx_mapping[r.id] = i @@ -92,6 +99,7 @@ class Seq2SeqLMBatch(Batch): ) stopping_criterias.append(stopping_criteria) max_truncation = 
max(max_truncation, r.truncate) + max_decode_tokens += stopping_criteria.max_new_tokens padding_right_offset = max( padding_right_offset, stopping_criteria.max_new_tokens ) @@ -117,6 +125,8 @@ class Seq2SeqLMBatch(Batch): ) all_decoder_input_ids = decoder_input_ids.view(-1).split(1) + max_tokens = len(inputs) * max_input_length + max_decode_tokens + return cls( batch_id=pb.id, requests=pb.requests, @@ -137,6 +147,7 @@ class Seq2SeqLMBatch(Batch): max_input_length=max_input_length.item(), max_decoder_input_length=1, padding_right_offset=padding_right_offset, + max_tokens=max_tokens, ) @tracer.start_as_current_span("filter") @@ -166,6 +177,8 @@ class Seq2SeqLMBatch(Batch): max_decoder_input_length = 0 padding_right_offset = 0 + remaining_decode_tokens = 0 + for i, r in enumerate(requests): idx = self.requests_idx_mapping[r.id] requests_idx_mapping[r.id] = i @@ -187,27 +200,38 @@ class Seq2SeqLMBatch(Batch): ) padding_right_offset = max( padding_right_offset, - self.stopping_criterias[idx].max_new_tokens - self.stopping_criterias[idx].current_tokens + self.stopping_criterias[idx].max_new_tokens + - self.stopping_criterias[idx].current_tokens, ) next_token_choosers.append(self.next_token_choosers[idx]) - stopping_criterias.append(self.stopping_criterias[idx]) + stopping_criteria = self.stopping_criterias[idx] + stopping_criterias.append(stopping_criteria) + remaining_decode_tokens += ( + stopping_criteria.max_new_tokens - stopping_criteria.current_tokens + ) # Apply indices to input_ids, attention mask, past key values and other items that need to be cached self.decoder_input_ids = self.decoder_input_ids[keep_indices] self.attention_mask = self.attention_mask[keep_indices, -max_input_length:] if self.decoder_attention_mask is not None: self.decoder_attention_mask = self.decoder_attention_mask[ - keep_indices, - -(self.padding_right_offset + max_decoder_input_length): - (self.decoder_attention_mask.shape[1] - self.padding_right_offset) + padding_right_offset, + keep_indices, + -(self.padding_right_offset + max_decoder_input_length) : ( + self.decoder_attention_mask.shape[1] - self.padding_right_offset + ) + + padding_right_offset, ] - self.encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices, -max_input_length:] + self.encoder_last_hidden_state = self.encoder_last_hidden_state[ + keep_indices, -max_input_length: + ] # Ensure that past_key_values tensors can be updated in-place if type(self.past_key_values[0]) == tuple: - self.past_key_values = [[t for t in layer] for layer in self.past_key_values] + self.past_key_values = [ + [t for t in layer] for layer in self.past_key_values + ] decoder_past_seq_len = max_decoder_input_length - 1 for layer in self.past_key_values: @@ -216,6 +240,11 @@ class Seq2SeqLMBatch(Batch): layer[2] = layer[2][keep_indices, :, -max_input_length:] layer[3] = layer[3][keep_indices, :, -max_input_length:] + max_tokens = ( + len(requests) * (max_input_length + max_decoder_input_length) + + remaining_decode_tokens + ) + self.requests = requests self.requests_idx_mapping = requests_idx_mapping self.input_ids = None @@ -229,10 +258,10 @@ class Seq2SeqLMBatch(Batch): self.max_input_length = max_input_length self.max_decoder_input_length = max_decoder_input_length self.padding_right_offset = padding_right_offset + self.max_tokens = max_tokens return self - @classmethod @tracer.start_as_current_span("concatenate") def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch": @@ -261,6 +290,7 @@ class Seq2SeqLMBatch(Batch): token_offsets = [] 
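After filter() drops finished requests, each batch recomputes max_tokens from what it kept. Here is a sketch of that arithmetic for the seq2seq case, with stand-in types; the real code reads these values from the stopping criteria shown above.

    struct KeptRequest {
        max_new_tokens: u32,
        current_tokens: u32, // decode steps already generated
    }

    fn filtered_max_tokens(
        kept: &[KeptRequest],
        max_input_length: u32,         // longest remaining encoder input
        max_decoder_input_length: u32, // longest remaining decoder prefix
    ) -> u32 {
        let remaining_decode: u32 = kept
            .iter()
            .map(|r| r.max_new_tokens - r.current_tokens)
            .sum();
        // Every kept row stays padded to the longest encoder/decoder lengths.
        kept.len() as u32 * (max_input_length + max_decoder_input_length) + remaining_decode
    }

CausalLMBatch.filter uses the same shape with a single max_input_length term, since decoder-only models have no separate decoder prefix.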
next_token_choosers = [] stopping_criterias = [] + max_tokens = 0 # Batch tensors attention_mask = None @@ -363,9 +393,18 @@ class Seq2SeqLMBatch(Batch): # Ensure that we can update tensors in-place if type(batch.past_key_values[0]) == tuple: - batch.past_key_values = [[t for t in layer] for layer in batch.past_key_values] + batch.past_key_values = [ + [t for t in layer] for layer in batch.past_key_values + ] start_index = end_index + # Add eventual padding tokens that were added while concatenating + max_tokens += batch.max_tokens + ( + max_input_length + - batch.max_input_length + + max_decoder_input_length + - batch.max_decoder_input_length + ) * len(batch) # Determine shapes for new past kv tensors first_past_kvs = batches[0].past_key_values @@ -404,9 +443,9 @@ class Seq2SeqLMBatch(Batch): end_index = start_index + len(batch) # We slice the past keys and values to remove the padding from previous batches past_seq_len = batch.max_decoder_input_length - 1 - padded_past_values[ - start_index:end_index, :, -past_seq_len:, : - ] = t[:, :, -past_seq_len:, :] + padded_past_values[start_index:end_index, :, -past_seq_len:, :] = t[ + :, :, -past_seq_len:, : + ] del t start_index = end_index @@ -426,8 +465,8 @@ class Seq2SeqLMBatch(Batch): end_index = start_index + len(batch) # We slice the past keys and values to remove the padding from previous batches padded_past_values[ - start_index:end_index, :, -batch.max_input_length:, : - ] = t[:, :, -batch.max_input_length:, :] + start_index:end_index, :, -batch.max_input_length :, : + ] = t[:, :, -batch.max_input_length :, :] del t start_index = end_index @@ -452,6 +491,7 @@ class Seq2SeqLMBatch(Batch): max_input_length=max_input_length, max_decoder_input_length=max_decoder_input_length, padding_right_offset=padding_right_offset, + max_tokens=max_tokens, ) def __len__(self): diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 95b431c..ddb7aae 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -41,6 +41,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): torch.cuda.empty_cache() return generate_pb2.ClearCacheResponse() + async def FilterBatch(self, request, context): + batch = self.cache.pop(request.batch_id) + if batch is None: + raise ValueError(f"Batch ID {request.batch_id} not found in cache.") + filtered_batch = batch.filter(request.keep_requests) + self.cache.set(filtered_batch) + + return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) + async def Prefill(self, request, context): batch = self.model.batch_type.from_pb( request.batch, self.model.tokenizer, self.model.device @@ -63,9 +72,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): batch = self.cache.pop(batch_pb.id) if batch is None: raise ValueError(f"Batch ID {batch_pb.id} not found in cache.") - batch = batch.filter(batch_pb.requests) - if batch is not None: - batches.append(batch) + batches.append(batch) if len(batches) == 0: raise ValueError("All batches are empty")
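When cached batches are concatenated, the merged batch is padded to the longest input across all of them, so each batch's max_tokens is corrected for the padding its requests pick up. The sketch below uses invented field names; the formula follows CausalLMBatch.concatenate, and Seq2SeqLMBatch.concatenate applies the same correction to the decoder length as well.

    struct CachedBatch {
        max_tokens: u32,       // bound reported by the shard for this batch
        max_input_length: u32, // padding length this batch was built with
        len: u32,              // number of requests in the batch
    }

    fn concatenated_max_tokens(batches: &[CachedBatch]) -> u32 {
        // Everything is re-padded to the longest input across all batches.
        let max_input_length = batches.iter().map(|b| b.max_input_length).max().unwrap_or(0);
        batches
            .iter()
            .map(|b| b.max_tokens + (max_input_length - b.max_input_length) * b.len)
            .sum()
    }

For two cached batches of 4 requests each, one padded to 100 tokens and one to 60, the shorter batch contributes 4 * (100 - 60) = 160 extra padding tokens on top of its own max_tokens.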
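Finally, the cache maintenance that used to live partly in the Decode path now happens explicitly through the new FilterBatch RPC. The sketch below follows the control flow of the filter_batch helper in router/src/infer.rs; it assumes the Batch, Request and ShardedClient types from the router's client crate, tracks live requests in a plain HashSet instead of the entries IntMap, and swallows errors, whereas the real code treats a failed call as fatal and panics.

    use std::collections::HashSet;

    async fn maintain_shard_cache(
        client: &mut ShardedClient,
        mut batch: Batch,
        live_request_ids: &HashSet<u64>,
    ) -> Option<Batch> {
        if batch.size as usize == live_request_ids.len() {
            // Nothing finished this step: the cached batch is already correct.
            return Some(batch);
        }
        let id = batch.id;
        batch.requests.retain(|r| live_request_ids.contains(&r.id));
        if batch.requests.is_empty() {
            // Every request finished: drop the whole batch from the shard cache.
            client.clear_cache(Some(id)).await.ok();
            None
        } else {
            // Some requests finished: ask the shards to filter their cached
            // tensors so the freed slots stop counting against the token budget.
            client.filter_batch(id, batch.requests).await.ok().flatten()
        }
    }

On the Python side, FilterBatch pops the cached batch, calls batch.filter(request.keep_requests) and re-caches the result, which is why Decode no longer filters batches itself and simply raises if a batch id is missing from the cache.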