diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md
index 712b4fc4..ba54f058 100644
--- a/docs/source/basic_tutorials/launcher.md
+++ b/docs/source/basic_tutorials/launcher.md
@@ -197,6 +197,14 @@ Options:
           [env: MAX_WAITING_TOKENS=]
           [default: 20]
 
+```
+## MAX_BATCH_SIZE
+```shell
+      --max-batch-size <MAX_BATCH_SIZE>
+          Enforce a maximum number of requests per batch. Specific flag for hardware targets that do not support unpadded inference
+
+          [env: MAX_BATCH_SIZE=]
+
 ```
 ## HOSTNAME
 ```shell
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 054e546c..a51742e6 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -279,6 +279,11 @@ struct Args {
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
 
+    /// Enforce a maximum number of requests per batch.
+    /// Specific flag for hardware targets that do not support unpadded inference
+    #[clap(long, env)]
+    max_batch_size: Option<usize>,
+
     /// The IP address to listen on
     #[clap(default_value = "0.0.0.0", long, env)]
     hostname: String,
@@ -1046,6 +1051,12 @@ fn spawn_webserver(
         router_args.push(max_batch_total_tokens.to_string());
     }
 
+    // Router optional max batch size
+    if let Some(max_batch_size) = args.max_batch_size {
+        router_args.push("--max-batch-size".to_string());
+        router_args.push(max_batch_size.to_string());
+    }
+
     // Model optional revision
     if let Some(ref revision) = args.revision {
         router_args.push("--revision".to_string());
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index fde5c402..7b9f90fb 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -105,6 +105,7 @@ impl Client {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
+        max_batch_size: Option<usize>,
     ) -> Result<Option<u32>> {
         let mut n_tokens = 0;
         let mut requests = Vec::new();
@@ -137,6 +138,11 @@ impl Client {
                 top_n_tokens: 20,
             });
             n_tokens += max_input_length;
+
+            // Stop filling the warmup batch once max_batch_size is reached
+            if Some(requests.len()) == max_batch_size {
+                break;
+            }
         }
 
         let batch = Batch {
diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs
index f0e65ce5..e1e52d59 100644
--- a/router/client/src/sharded_client.rs
+++ b/router/client/src/sharded_client.rs
@@ -97,12 +97,18 @@ impl ShardedClient {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
+        max_batch_size: Option<usize>,
     ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| {
-                Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
+                Box::pin(client.warmup(
+                    max_input_length,
+                    max_prefill_tokens,
+                    max_total_tokens,
+                    max_batch_size,
+                ))
             })
             .collect();
         // Take the minimum value
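The warmup change above caps the number of synthetic requests built per shard. The cap check leans on `Option` equality: wrapping the running length in `Some` before comparing means an unset cap (`None`) never matches, so the loop only breaks when a cap is both configured and reached. A minimal standalone sketch of the pattern; `fill` is a hypothetical stand-in, not TGI code:

```rust
/// Hypothetical stand-in for the fill loop in `Client::warmup`.
/// `Some(requests.len()) == cap` is false whenever `cap` is `None`,
/// so the loop is unbounded unless a cap is explicitly set.
fn fill(cap: Option<usize>) -> Vec<u32> {
    let mut requests = Vec::new();
    for i in 0..100 {
        requests.push(i);
        if Some(requests.len()) == cap {
            break;
        }
    }
    requests
}

fn main() {
    assert_eq!(fill(None).len(), 100);  // no cap: loop runs to its natural end
    assert_eq!(fill(Some(8)).len(), 8); // cap set: early break at 8 requests
}
```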
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 4da0da0a..d7b9b52b 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -61,6 +61,7 @@ impl Infer {
         max_batch_prefill_tokens: u32,
         max_batch_total_tokens: u32,
         max_waiting_tokens: usize,
+        max_batch_size: Option<usize>,
         max_concurrent_requests: usize,
         requires_padding: bool,
         window_size: Option<u32>,
@@ -81,6 +82,7 @@ impl Infer {
             max_batch_prefill_tokens,
             max_batch_total_tokens,
             max_waiting_tokens,
+            max_batch_size,
             queue.clone(),
             shared.clone(),
             generation_health,
@@ -338,6 +340,7 @@ async fn batching_task(
     max_batch_prefill_tokens: u32,
     max_batch_total_tokens: u32,
     max_waiting_tokens: usize,
+    max_batch_size: Option<usize>,
     queue: Queue,
     shared: Arc<Shared>,
    generation_health: Arc<AtomicBool>,
@@ -351,7 +354,12 @@ async fn batching_task(
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the queue
         while let Some((mut entries, batch, span)) = queue
-            .next_batch(None, max_batch_prefill_tokens, max_batch_total_tokens)
+            .next_batch(
+                None,
+                max_batch_size,
+                max_batch_prefill_tokens,
+                max_batch_total_tokens,
+            )
             .await
         {
             let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
@@ -379,10 +387,11 @@ async fn batching_task(
                 };
 
                 let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
+                let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
 
                 // Try to get a new batch
                 if let Some((mut new_entries, new_batch, span)) = queue
-                    .next_batch(min_size, max_batch_prefill_tokens, token_budget)
+                    .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
                     .await
                 {
                     // Tracking metrics
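In the batching loop, once a prefill batch of `batch_size` requests is in flight, the follow-up `next_batch` call is handed the room left under the cap, computed with `Option::map` so that `None` propagates as "no limit". A sketch of that computation in isolation; the helper name is illustrative, and the subtraction assumes the running batch never exceeds the cap, which holds because the queue itself enforces it:

```rust
/// Illustrative helper mirroring the `max_size` computation in `batching_task`:
/// `None` (no cap) maps to `None`, otherwise to the slots still free.
fn remaining_slots(max_batch_size: Option<usize>, batch_size: u32) -> Option<usize> {
    max_batch_size.map(|max_size| max_size - batch_size as usize)
}

fn main() {
    assert_eq!(remaining_slots(None, 12), None);        // unbounded stays unbounded
    assert_eq!(remaining_slots(Some(16), 12), Some(4)); // 4 slots left to fill
    assert_eq!(remaining_slots(Some(16), 16), Some(0)); // full: next_batch can add nothing
}
```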
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 7c44d642..3ce9eca8 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -73,6 +73,8 @@ pub struct Info {
     pub max_batch_total_tokens: u32,
     #[schema(example = "20")]
     pub max_waiting_tokens: usize,
+    #[schema(nullable = true, example = "null")]
+    pub max_batch_size: Option<usize>,
     #[schema(example = "2")]
     pub validation_workers: usize,
     /// Router Info
diff --git a/router/src/main.rs b/router/src/main.rs
index 2a080468..a1f8d97b 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -45,6 +45,8 @@ struct Args {
     max_batch_total_tokens: Option<u32>,
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
+    #[clap(long, env)]
+    max_batch_size: Option<usize>,
     #[clap(default_value = "0.0.0.0", long, env)]
     hostname: String,
     #[clap(default_value = "3000", long, short, env)]
@@ -91,6 +93,7 @@ async fn main() -> Result<(), RouterError> {
         max_batch_prefill_tokens,
         max_batch_total_tokens,
         max_waiting_tokens,
+        max_batch_size,
         hostname,
         port,
         master_shard_uds_path,
@@ -288,6 +291,7 @@ async fn main() -> Result<(), RouterError> {
             max_input_length as u32,
             max_batch_prefill_tokens,
             max_total_tokens as u32,
+            max_batch_size,
         )
         .await
         .map_err(RouterError::Warmup)?
@@ -344,6 +348,7 @@ async fn main() -> Result<(), RouterError> {
         max_batch_prefill_tokens,
         max_supported_batch_total_tokens,
         max_waiting_tokens,
+        max_batch_size,
         sharded_client,
         tokenizer,
         validation_workers,
diff --git a/router/src/queue.rs b/router/src/queue.rs
index 73a7169b..3675e0f5 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -70,6 +70,7 @@ impl Queue {
     pub(crate) async fn next_batch(
         &self,
         min_size: Option<usize>,
+        max_size: Option<usize>,
         prefill_token_budget: u32,
         token_budget: u32,
     ) -> Option<NextBatch> {
@@ -80,6 +81,7 @@ impl Queue {
         self.queue_sender
             .send(QueueCommand::NextBatch {
                 min_size,
+                max_size,
                 prefill_token_budget,
                 token_budget,
                 response_sender,
@@ -110,12 +112,14 @@ async fn queue_task(
             }
             QueueCommand::NextBatch {
                 min_size,
+                max_size,
                 prefill_token_budget,
                 token_budget,
                 response_sender,
                 span,
             } => span.in_scope(|| {
-                let next_batch = state.next_batch(min_size, prefill_token_budget, token_budget);
+                let next_batch =
+                    state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
                 response_sender.send(next_batch).unwrap();
                 metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
             }),
@@ -181,6 +185,7 @@ impl State {
     fn next_batch(
         &mut self,
         min_size: Option<usize>,
+        max_size: Option<usize>,
         prefill_token_budget: u32,
         token_budget: u32,
     ) -> Option<NextBatch> {
@@ -274,6 +279,11 @@ impl State {
             entry.batch_time = Some(Instant::now());
             // Insert in batch_entries IntMap
             batch_entries.insert(id, entry);
+
+            // Stop if max_size is reached
+            if Some(batch_requests.len()) == max_size {
+                break;
+            }
         }
 
         // Empty batch
@@ -322,6 +332,7 @@ enum QueueCommand {
     Append(Box<Entry>, Span),
     NextBatch {
         min_size: Option<usize>,
+        max_size: Option<usize>,
         prefill_token_budget: u32,
         token_budget: u32,
         response_sender: oneshot::Sender<Option<NextBatch>>,
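The queue runs as a background task that receives commands over an mpsc channel and answers each one on a oneshot channel; the diff above threads `max_size` through that command. A simplified sketch of the same plumbing, assuming the tokio runtime the router already uses; the types are stand-ins (a plain `Vec<u64>` of request ids instead of the router's `NextBatch`), not TGI's actual definitions:

```rust
use tokio::sync::{mpsc, oneshot};

// Stand-in for the router's QueueCommand::NextBatch, reduced to the new field.
enum QueueCommand {
    NextBatch {
        max_size: Option<usize>,
        response_sender: oneshot::Sender<Vec<u64>>,
    },
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::unbounded_channel();

    // Background queue task: hand out at most `max_size` pending ids per command.
    tokio::spawn(async move {
        let mut pending: Vec<u64> = (0..10).collect();
        while let Some(QueueCommand::NextBatch { max_size, response_sender }) = rx.recv().await {
            let take = max_size.unwrap_or(pending.len()).min(pending.len());
            let _ = response_sender.send(pending.drain(..take).collect());
        }
    });

    let (response_sender, response) = oneshot::channel();
    tx.send(QueueCommand::NextBatch { max_size: Some(4), response_sender }).unwrap();
    assert_eq!(response.await.unwrap().len(), 4); // capped at 4 of the 10 queued ids
}
```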
@@ -394,8 +405,8 @@
     fn test_next_batch_empty() {
         let mut state = State::new(false, 1, None, 0);
 
-        assert!(state.next_batch(None, 1, 1).is_none());
-        assert!(state.next_batch(Some(1), 1, 1).is_none());
+        assert!(state.next_batch(None, None, 1, 1).is_none());
+        assert!(state.next_batch(Some(1), None, 1, 1).is_none());
     }
 
     #[test]
@@ -406,7 +417,7 @@
         state.append(entry1);
         state.append(entry2);
 
-        let (entries, batch, _) = state.next_batch(None, 2, 2).unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&0));
         assert!(entries.contains_key(&1));
@@ -422,7 +433,7 @@
         let (entry3, _guard3) = default_entry();
         state.append(entry3);
 
-        assert!(state.next_batch(Some(2), 2, 2).is_none());
+        assert!(state.next_batch(Some(2), None, 2, 2).is_none());
 
         assert_eq!(state.next_id, 3);
         assert_eq!(state.entries.len(), 1);
@@ -430,6 +441,26 @@
         let (id, _) = state.entries.remove(0).unwrap();
         assert_eq!(id, 2);
     }
 
+    #[test]
+    fn test_next_batch_max_size() {
+        let mut state = State::new(false, 1, None, 0);
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        state.append(entry1);
+        state.append(entry2);
+
+        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap();
+        assert_eq!(entries.len(), 1);
+        assert!(entries.contains_key(&0));
+        assert!(entries.get(&0).unwrap().batch_time.is_some());
+        assert_eq!(batch.id, 0);
+        assert_eq!(batch.size, 1);
+
+        assert_eq!(state.next_id, 2);
+        assert_eq!(state.entries.len(), 1);
+        assert_eq!(state.next_batch_id, 1);
+    }
+
     #[test]
     fn test_next_batch_token_budget() {
@@ -438,7 +469,7 @@
         state.append(entry1);
         state.append(entry2);
 
-        let (entries, batch, _) = state.next_batch(None, 1, 1).unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap();
         assert_eq!(entries.len(), 1);
         assert!(entries.contains_key(&0));
         assert_eq!(batch.id, 0);
@@ -451,7 +482,7 @@
         let (entry3, _guard3) = default_entry();
         state.append(entry3);
 
-        let (entries, batch, _) = state.next_batch(None, 3, 3).unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&1));
         assert!(entries.contains_key(&2));
@@ -474,8 +505,8 @@
     async fn test_queue_next_batch_empty() {
         let queue = Queue::new(false, 1, None, 0);
 
-        assert!(queue.next_batch(None, 1, 1).await.is_none());
-        assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
+        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
+        assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
     }
 
     #[tokio::test]
@@ -486,7 +517,7 @@
         queue.append(entry1);
         queue.append(entry2);
 
-        let (entries, batch, _) = queue.next_batch(None, 2, 2).await.unwrap();
+        let (entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&0));
         assert!(entries.contains_key(&1));
@@ -499,11 +530,11 @@
         queue.append(entry3);
 
         // Not enough requests pending
-        assert!(queue.next_batch(Some(2), 2, 2).await.is_none());
+        assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none());
         // Not enough token budget
-        assert!(queue.next_batch(Some(1), 0, 0).await.is_none());
+        assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());
         // Ok
-        let (entries2, batch2, _) = queue.next_batch(Some(1), 2, 2).await.unwrap();
+        let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap();
         assert_eq!(entries2.len(), 1);
         assert!(entries2.contains_key(&2));
         assert!(entries2.get(&2).unwrap().batch_time.is_some());
@@ -511,6 +542,22 @@
         assert_eq!(batch2.id, 1);
         assert_eq!(batch2.size, 1);
     }
 
+    #[tokio::test]
+    async fn test_queue_next_batch_max_size() {
+        let queue = Queue::new(false, 1, None, 0);
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        queue.append(entry1);
+        queue.append(entry2);
+
+        let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap();
+        assert_eq!(entries.len(), 1);
+        assert!(entries.contains_key(&0));
+        assert!(entries.get(&0).unwrap().batch_time.is_some());
+        assert_eq!(batch.id, 0);
+        assert_eq!(batch.size, 1);
+    }
+
     #[tokio::test]
     async fn test_queue_next_batch_token_budget() {
@@ -519,7 +566,7 @@
         queue.append(entry1);
         queue.append(entry2);
 
-        let (entries, batch, _) = queue.next_batch(None, 1, 1).await.unwrap();
+        let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap();
         assert_eq!(entries.len(), 1);
         assert!(entries.contains_key(&0));
         assert_eq!(batch.id, 0);
@@ -528,7 +575,7 @@
         let (entry3, _guard3) = default_entry();
         queue.append(entry3);
 
-        let (entries, batch, _) = queue.next_batch(None, 3, 3).await.unwrap();
+        let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&1));
         assert!(entries.contains_key(&2));
@@ -545,9 +592,9 @@
         queue.append(entry2);
 
         // Budget of 1 is not enough
-        assert!(queue.next_batch(None, 1, 1).await.is_none());
+        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
 
-        let (entries, batch, _) = queue.next_batch(None, 6, 6).await.unwrap();
+        let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&0));
         assert!(entries.contains_key(&1));
@@ -561,6 +608,6 @@
         let (entry, _) = default_entry();
         queue.append(entry);
 
-        assert!(queue.next_batch(None, 1, 1).await.is_none());
+        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
     }
 }
diff --git a/router/src/server.rs b/router/src/server.rs
index acfdef91..00b793e3 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -768,6 +768,7 @@ pub async fn run(
     max_batch_prefill_tokens: u32,
     max_batch_total_tokens: u32,
     max_waiting_tokens: usize,
+    max_batch_size: Option<usize>,
     client: ShardedClient,
     tokenizer: Option<Tokenizer>,
     validation_workers: usize,
@@ -849,6 +850,7 @@ pub async fn run(
         max_batch_prefill_tokens,
         max_batch_total_tokens,
         max_waiting_tokens,
+        max_batch_size,
         max_concurrent_requests,
         shard_info.requires_padding,
         shard_info.window_size,
@@ -930,6 +932,7 @@ pub async fn run(
         waiting_served_ratio,
         max_batch_total_tokens,
         max_waiting_tokens,
+        max_batch_size,
         validation_workers,
         version: env!("CARGO_PKG_VERSION"),
         sha: option_env!("VERGEN_GIT_SHA"),
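`Info` gains the field with `nullable = true` in its schema annotation because serde serializes a `None` `Option` to JSON `null` by default, so the `/info` response reports `"max_batch_size": null` when no cap is configured. A quick sketch of that behavior with a stripped-down stand-in struct, assuming `serde` with the derive feature plus `serde_json`:

```rust
use serde::Serialize;

// Stripped-down stand-in for the router's `Info` struct.
#[derive(Serialize)]
struct InfoSketch {
    max_batch_size: Option<usize>,
}

fn main() {
    // No cap configured: the field serializes to `null`.
    let json = serde_json::to_string(&InfoSketch { max_batch_size: None }).unwrap();
    assert_eq!(json, r#"{"max_batch_size":null}"#);

    // Cap configured: a plain JSON number.
    let json = serde_json::to_string(&InfoSketch { max_batch_size: Some(4) }).unwrap();
    assert_eq!(json, r#"{"max_batch_size":4}"#);
}
```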