feat(router): add max_batch_size (#1542)
Some hardware requires a maximum batch size.
This commit is contained in:
parent
a4e5801684
commit
532146338b
|
@ -197,6 +197,14 @@ Options:
|
|||
[env: MAX_WAITING_TOKENS=]
|
||||
[default: 20]
|
||||
|
||||
```
|
||||
## MAX_BATCH_SIZE
|
||||
```shell
|
||||
--max-batch-size <MAX_BATCH_SIZE>
|
||||
Enforce a maximum number of requests per batch. Specific flag for hardware targets that do not support unpadded inference
|
||||
|
||||
[env: MAX_BATCH_SIZE=]
|
||||
|
||||
```
|
||||
## HOSTNAME
|
||||
```shell
|
||||
|
|
|
@ -279,6 +279,11 @@ struct Args {
|
|||
#[clap(default_value = "20", long, env)]
|
||||
max_waiting_tokens: usize,
|
||||
|
||||
/// Enforce a maximum number of requests per batch
|
||||
/// Specific flag for hardware targets that do not support unpadded inference
|
||||
#[clap(long, env)]
|
||||
max_batch_size: Option<usize>,
|
||||
|
||||
/// The IP address to listen on
|
||||
#[clap(default_value = "0.0.0.0", long, env)]
|
||||
hostname: String,
|
||||
|
@ -1046,6 +1051,12 @@ fn spawn_webserver(
|
|||
router_args.push(max_batch_total_tokens.to_string());
|
||||
}
|
||||
|
||||
// Router optional max batch size
|
||||
if let Some(max_batch_size) = args.max_batch_size {
|
||||
router_args.push("--max-batch-size".to_string());
|
||||
router_args.push(max_batch_size.to_string());
|
||||
}
|
||||
|
||||
// Model optional revision
|
||||
if let Some(ref revision) = args.revision {
|
||||
router_args.push("--revision".to_string());
|
||||
|
|
|
@ -105,6 +105,7 @@ impl Client {
|
|||
max_input_length: u32,
|
||||
max_prefill_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
max_batch_size: Option<usize>,
|
||||
) -> Result<Option<u32>> {
|
||||
let mut n_tokens = 0;
|
||||
let mut requests = Vec::new();
|
||||
|
@ -137,6 +138,11 @@ impl Client {
|
|||
top_n_tokens: 20,
|
||||
});
|
||||
n_tokens += max_input_length;
|
||||
|
||||
// Check max_batch_size
|
||||
if Some(requests.len()) == max_batch_size {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let batch = Batch {
|
||||
|
|
|
@ -97,12 +97,18 @@ impl ShardedClient {
|
|||
max_input_length: u32,
|
||||
max_prefill_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
max_batch_size: Option<usize>,
|
||||
) -> Result<Option<u32>> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| {
|
||||
Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
|
||||
Box::pin(client.warmup(
|
||||
max_input_length,
|
||||
max_prefill_tokens,
|
||||
max_total_tokens,
|
||||
max_batch_size,
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
// Take the minimum value
|
||||
|
|
|
@ -61,6 +61,7 @@ impl Infer {
|
|||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
max_concurrent_requests: usize,
|
||||
requires_padding: bool,
|
||||
window_size: Option<u32>,
|
||||
|
@ -81,6 +82,7 @@ impl Infer {
|
|||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
queue.clone(),
|
||||
shared.clone(),
|
||||
generation_health,
|
||||
|
@ -338,6 +340,7 @@ async fn batching_task(
|
|||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
queue: Queue,
|
||||
shared: Arc<Shared>,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
|
@ -351,7 +354,12 @@ async fn batching_task(
|
|||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||
// waiting in the queue
|
||||
while let Some((mut entries, batch, span)) = queue
|
||||
.next_batch(None, max_batch_prefill_tokens, max_batch_total_tokens)
|
||||
.next_batch(
|
||||
None,
|
||||
max_batch_size,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
)
|
||||
.await
|
||||
{
|
||||
let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
|
||||
|
@ -379,10 +387,11 @@ async fn batching_task(
|
|||
};
|
||||
|
||||
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||
let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
|
||||
|
||||
// Try to get a new batch
|
||||
if let Some((mut new_entries, new_batch, span)) = queue
|
||||
.next_batch(min_size, max_batch_prefill_tokens, token_budget)
|
||||
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
|
||||
.await
|
||||
{
|
||||
// Tracking metrics
|
||||
|
|
|
@ -73,6 +73,8 @@ pub struct Info {
|
|||
pub max_batch_total_tokens: u32,
|
||||
#[schema(example = "20")]
|
||||
pub max_waiting_tokens: usize,
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub max_batch_size: Option<usize>,
|
||||
#[schema(example = "2")]
|
||||
pub validation_workers: usize,
|
||||
/// Router Info
|
||||
|
|
|
@ -45,6 +45,8 @@ struct Args {
|
|||
max_batch_total_tokens: Option<u32>,
|
||||
#[clap(default_value = "20", long, env)]
|
||||
max_waiting_tokens: usize,
|
||||
#[clap(long, env)]
|
||||
max_batch_size: Option<usize>,
|
||||
#[clap(default_value = "0.0.0.0", long, env)]
|
||||
hostname: String,
|
||||
#[clap(default_value = "3000", long, short, env)]
|
||||
|
@ -91,6 +93,7 @@ async fn main() -> Result<(), RouterError> {
|
|||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
hostname,
|
||||
port,
|
||||
master_shard_uds_path,
|
||||
|
@ -288,6 +291,7 @@ async fn main() -> Result<(), RouterError> {
|
|||
max_input_length as u32,
|
||||
max_batch_prefill_tokens,
|
||||
max_total_tokens as u32,
|
||||
max_batch_size,
|
||||
)
|
||||
.await
|
||||
.map_err(RouterError::Warmup)?
|
||||
|
@ -344,6 +348,7 @@ async fn main() -> Result<(), RouterError> {
|
|||
max_batch_prefill_tokens,
|
||||
max_supported_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
sharded_client,
|
||||
tokenizer,
|
||||
validation_workers,
|
||||
|
|
|
@ -70,6 +70,7 @@ impl Queue {
|
|||
pub(crate) async fn next_batch(
|
||||
&self,
|
||||
min_size: Option<usize>,
|
||||
max_size: Option<usize>,
|
||||
prefill_token_budget: u32,
|
||||
token_budget: u32,
|
||||
) -> Option<NextBatch> {
|
||||
|
@ -80,6 +81,7 @@ impl Queue {
|
|||
self.queue_sender
|
||||
.send(QueueCommand::NextBatch {
|
||||
min_size,
|
||||
max_size,
|
||||
prefill_token_budget,
|
||||
token_budget,
|
||||
response_sender,
|
||||
|
@ -110,12 +112,14 @@ async fn queue_task(
|
|||
}
|
||||
QueueCommand::NextBatch {
|
||||
min_size,
|
||||
max_size,
|
||||
prefill_token_budget,
|
||||
token_budget,
|
||||
response_sender,
|
||||
span,
|
||||
} => span.in_scope(|| {
|
||||
let next_batch = state.next_batch(min_size, prefill_token_budget, token_budget);
|
||||
let next_batch =
|
||||
state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
|
||||
response_sender.send(next_batch).unwrap();
|
||||
metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
|
||||
}),
|
||||
|
@ -181,6 +185,7 @@ impl State {
|
|||
fn next_batch(
|
||||
&mut self,
|
||||
min_size: Option<usize>,
|
||||
max_size: Option<usize>,
|
||||
prefill_token_budget: u32,
|
||||
token_budget: u32,
|
||||
) -> Option<NextBatch> {
|
||||
|
@ -274,6 +279,11 @@ impl State {
|
|||
entry.batch_time = Some(Instant::now());
|
||||
// Insert in batch_entries IntMap
|
||||
batch_entries.insert(id, entry);
|
||||
|
||||
// Check if max_size
|
||||
if Some(batch_requests.len()) == max_size {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Empty batch
|
||||
|
@ -322,6 +332,7 @@ enum QueueCommand {
|
|||
Append(Box<Entry>, Span),
|
||||
NextBatch {
|
||||
min_size: Option<usize>,
|
||||
max_size: Option<usize>,
|
||||
prefill_token_budget: u32,
|
||||
token_budget: u32,
|
||||
response_sender: oneshot::Sender<Option<NextBatch>>,
|
||||
|
@ -394,8 +405,8 @@ mod tests {
|
|||
fn test_next_batch_empty() {
|
||||
let mut state = State::new(false, 1, None, 0);
|
||||
|
||||
assert!(state.next_batch(None, 1, 1).is_none());
|
||||
assert!(state.next_batch(Some(1), 1, 1).is_none());
|
||||
assert!(state.next_batch(None, None, 1, 1).is_none());
|
||||
assert!(state.next_batch(Some(1), None, 1, 1).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -406,7 +417,7 @@ mod tests {
|
|||
state.append(entry1);
|
||||
state.append(entry2);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, 2, 2).unwrap();
|
||||
let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert!(entries.contains_key(&1));
|
||||
|
@ -422,7 +433,7 @@ mod tests {
|
|||
let (entry3, _guard3) = default_entry();
|
||||
state.append(entry3);
|
||||
|
||||
assert!(state.next_batch(Some(2), 2, 2).is_none());
|
||||
assert!(state.next_batch(Some(2), None, 2, 2).is_none());
|
||||
|
||||
assert_eq!(state.next_id, 3);
|
||||
assert_eq!(state.entries.len(), 1);
|
||||
|
@ -430,6 +441,26 @@ mod tests {
|
|||
assert_eq!(id, 2);
|
||||
}
|
||||
|
||||
// Exercises the `max_size` argument of `State::next_batch`: two entries are
// queued, the batch is requested with `max_size = Some(1)`, and the test
// expects exactly one entry to be batched while the other stays queued.
#[test]
|
||||
fn test_next_batch_max_size() {
|
||||
let mut state = State::new(false, 1, None, 0);
|
||||
// The `_guard*` bindings are kept alive for the duration of the test.
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
state.append(entry2);
|
||||
|
||||
// Arguments: min_size = None, max_size = Some(1), prefill_token_budget = 2, token_budget = 2.
let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap();
|
||||
// Only the first appended entry (id 0) is expected in the batch.
assert_eq!(entries.len(), 1);
|
||||
assert!(entries.contains_key(&0));
|
||||
// The batched entry must have had its `batch_time` set.
assert!(entries.get(&0).unwrap().batch_time.is_some());
|
||||
assert_eq!(batch.id, 0);
|
||||
assert_eq!(batch.size, 1);
|
||||
|
||||
// Two ids were handed out, one entry is still waiting in the state,
// and the batch id counter has advanced by one.
assert_eq!(state.next_id, 2);
|
||||
assert_eq!(state.entries.len(), 1);
|
||||
assert_eq!(state.next_batch_id, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_token_budget() {
|
||||
let mut state = State::new(false, 1, None, 0);
|
||||
|
@ -438,7 +469,7 @@ mod tests {
|
|||
state.append(entry1);
|
||||
state.append(entry2);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, 1, 1).unwrap();
|
||||
let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert_eq!(batch.id, 0);
|
||||
|
@ -451,7 +482,7 @@ mod tests {
|
|||
let (entry3, _guard3) = default_entry();
|
||||
state.append(entry3);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, 3, 3).unwrap();
|
||||
let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&1));
|
||||
assert!(entries.contains_key(&2));
|
||||
|
@ -474,8 +505,8 @@ mod tests {
|
|||
async fn test_queue_next_batch_empty() {
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
|
||||
assert!(queue.next_batch(None, 1, 1).await.is_none());
|
||||
assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
|
||||
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
||||
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
@ -486,7 +517,7 @@ mod tests {
|
|||
queue.append(entry1);
|
||||
queue.append(entry2);
|
||||
|
||||
let (entries, batch, _) = queue.next_batch(None, 2, 2).await.unwrap();
|
||||
let (entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert!(entries.contains_key(&1));
|
||||
|
@ -499,11 +530,11 @@ mod tests {
|
|||
queue.append(entry3);
|
||||
|
||||
// Not enough requests pending
|
||||
assert!(queue.next_batch(Some(2), 2, 2).await.is_none());
|
||||
assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none());
|
||||
// Not enough token budget
|
||||
assert!(queue.next_batch(Some(1), 0, 0).await.is_none());
|
||||
assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());
|
||||
// Ok
|
||||
let (entries2, batch2, _) = queue.next_batch(Some(1), 2, 2).await.unwrap();
|
||||
let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap();
|
||||
assert_eq!(entries2.len(), 1);
|
||||
assert!(entries2.contains_key(&2));
|
||||
assert!(entries2.get(&2).unwrap().batch_time.is_some());
|
||||
|
@ -511,6 +542,22 @@ mod tests {
|
|||
assert_eq!(batch2.size, 1);
|
||||
}
|
||||
|
||||
// Async counterpart of the `State`-level max-size test: goes through the
// `Queue` handle and expects `max_size = Some(1)` to yield a single-entry
// batch even though two entries were appended.
#[tokio::test]
|
||||
async fn test_queue_next_batch_max_size() {
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
// The `_guard*` bindings are kept alive for the duration of the test.
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
queue.append(entry2);
|
||||
|
||||
// Arguments: min_size = None, max_size = Some(1), prefill_token_budget = 2, token_budget = 2.
let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap();
|
||||
// Only the first appended entry (id 0) is expected in the batch.
assert_eq!(entries.len(), 1);
|
||||
assert!(entries.contains_key(&0));
|
||||
// The batched entry must have had its `batch_time` set.
assert!(entries.get(&0).unwrap().batch_time.is_some());
|
||||
assert_eq!(batch.id, 0);
|
||||
assert_eq!(batch.size, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_token_budget() {
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
|
@ -519,7 +566,7 @@ mod tests {
|
|||
queue.append(entry1);
|
||||
queue.append(entry2);
|
||||
|
||||
let (entries, batch, _) = queue.next_batch(None, 1, 1).await.unwrap();
|
||||
let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert_eq!(batch.id, 0);
|
||||
|
@ -528,7 +575,7 @@ mod tests {
|
|||
let (entry3, _guard3) = default_entry();
|
||||
queue.append(entry3);
|
||||
|
||||
let (entries, batch, _) = queue.next_batch(None, 3, 3).await.unwrap();
|
||||
let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&1));
|
||||
assert!(entries.contains_key(&2));
|
||||
|
@ -545,9 +592,9 @@ mod tests {
|
|||
queue.append(entry2);
|
||||
|
||||
// Budget of 1 is not enough
|
||||
assert!(queue.next_batch(None, 1, 1).await.is_none());
|
||||
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
||||
|
||||
let (entries, batch, _) = queue.next_batch(None, 6, 6).await.unwrap();
|
||||
let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert!(entries.contains_key(&1));
|
||||
|
@ -561,6 +608,6 @@ mod tests {
|
|||
let (entry, _) = default_entry();
|
||||
queue.append(entry);
|
||||
|
||||
assert!(queue.next_batch(None, 1, 1).await.is_none());
|
||||
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -768,6 +768,7 @@ pub async fn run(
|
|||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
client: ShardedClient,
|
||||
tokenizer: Option<Tokenizer>,
|
||||
validation_workers: usize,
|
||||
|
@ -849,6 +850,7 @@ pub async fn run(
|
|||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
max_concurrent_requests,
|
||||
shard_info.requires_padding,
|
||||
shard_info.window_size,
|
||||
|
@ -930,6 +932,7 @@ pub async fn run(
|
|||
waiting_served_ratio,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
validation_workers,
|
||||
version: env!("CARGO_PKG_VERSION"),
|
||||
sha: option_env!("VERGEN_GIT_SHA"),
|
||||
|
|
Loading…
Reference in New Issue