Fixing more correctly the invalid drop of the batch. (#2498)

This commit is contained in:
Nicolas Patry 2024-09-06 17:35:49 +02:00 committed by GitHub
parent aaea212d0f
commit c1fe28d694
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 58 additions and 50 deletions

View File

@ -122,7 +122,7 @@ impl Backend for BackendV3 {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub(crate) async fn batching_task( pub(crate) async fn batching_task(
mut client: ShardedClient, mut client: ShardedClient,
_waiting_served_ratio: f32, waiting_served_ratio: f32,
max_batch_prefill_tokens: u32, max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
max_waiting_tokens: usize, max_waiting_tokens: usize,
@ -170,8 +170,7 @@ pub(crate) async fn batching_task(
// Minimum batch size // Minimum batch size
// TODO: temporarily disable to avoid incorrect deallocation + // TODO: temporarily disable to avoid incorrect deallocation +
// reallocation when using prefix caching. // reallocation when using prefix caching.
// Some((batch_size as f32 * waiting_served_ratio).floor() as usize) Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
None
}; };
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);

View File

@ -252,17 +252,14 @@ impl State {
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
next_batch_span.follows_from(Span::current()); next_batch_span.follows_from(Span::current());
let mut batch_requests = Vec::with_capacity(self.entries.len()); let mut batch = Vec::with_capacity(self.entries.len());
let mut batch_entries =
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
let mut max_input_length = 0; let mut max_input_length = 0;
let mut prefill_tokens: u32 = 0; let mut prefill_tokens: u32 = 0;
let mut decode_tokens: u32 = 0; let mut decode_tokens: u32 = 0;
let mut max_blocks = 0; let mut max_blocks = 0;
// Pop entries starting from the front of the queue // Pop entries starting from the front of the queue
'entry_loop: while let Some((id, mut entry)) = self.entries.pop_front() { 'entry_loop: while let Some((id, entry)) = self.entries.pop_front() {
// Filter entries where the response receiver was dropped (== entries where the request // Filter entries where the response receiver was dropped (== entries where the request
// was dropped by the client) // was dropped by the client)
if entry.response_tx.is_closed() { if entry.response_tx.is_closed() {
@ -276,7 +273,7 @@ impl State {
// We pad to max input length in the Python shards // We pad to max input length in the Python shards
// We need to take these padding tokens into the equation // We need to take these padding tokens into the equation
max_input_length = max_input_length.max(entry.request.input_length); max_input_length = max_input_length.max(entry.request.input_length);
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length; prefill_tokens = (batch.len() + 1) as u32 * max_input_length;
decode_tokens += entry.request.stopping_parameters.max_new_tokens; decode_tokens += entry.request.stopping_parameters.max_new_tokens;
let total_tokens = prefill_tokens + decode_tokens + self.speculate; let total_tokens = prefill_tokens + decode_tokens + self.speculate;
@ -290,7 +287,7 @@ impl State {
} }
None None
} }
Some(block_allocator) => { Some(_block_allocator) => {
prefill_tokens += entry.request.input_length; prefill_tokens += entry.request.input_length;
let max_new_tokens = match self.window_size { let max_new_tokens = match self.window_size {
None => entry.request.stopping_parameters.max_new_tokens, None => entry.request.stopping_parameters.max_new_tokens,
@ -324,23 +321,59 @@ impl State {
entry.request.input_ids.clone() entry.request.input_ids.clone()
}; };
match block_allocator.allocate(tokens, input_ids).await { Some((tokens, input_ids))
None => {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: not enough free blocks");
self.entries.push_front((id, entry));
break 'entry_loop;
}
Some(block_allocation) => {
tracing::debug!("Allocation: {block_allocation:?}");
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
Some(block_allocation)
}
}
} }
}; };
batch.push((id, entry, block_allocation));
if Some(batch.len()) == max_size {
break;
}
}
// Empty batch
if batch.is_empty() {
tracing::debug!("Filterered out all entries");
return None;
}
// XXX We haven't allocated yet, so we're allowed to ditch the results.
// Check if our batch is big enough
if let Some(min_size) = min_size {
// Batch is too small
if batch.len() < min_size {
// Add back entries to the queue in the correct order
for (id, entry, _) in batch.into_iter().rev() {
self.entries.push_front((id, entry));
}
return None;
}
}
let mut batch_requests = Vec::with_capacity(self.entries.len());
let mut batch_entries =
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
for (id, mut entry, block_allocation) in batch {
let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) =
(block_allocation, &self.block_allocator)
{
match block_allocator.allocate(tokens, input_ids).await {
None => {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: not enough free blocks");
self.entries.push_front((id, entry));
break;
}
Some(block_allocation) => {
tracing::debug!("Allocation: {block_allocation:?}");
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
Some(block_allocation)
}
}
} else {
None
};
tracing::debug!("Accepting entry"); tracing::debug!("Accepting entry");
// Create a new span to link the batch back to this entry // Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer"); let entry_batch_span = info_span!(parent: &entry.span, "infer");
@ -400,32 +433,6 @@ impl State {
entry.batch_time = Some(Instant::now()); entry.batch_time = Some(Instant::now());
// Insert in batch_entries IntMap // Insert in batch_entries IntMap
batch_entries.insert(id, entry); batch_entries.insert(id, entry);
// Check if max_size
if Some(batch_requests.len()) == max_size {
break;
}
}
// Empty batch
if batch_requests.is_empty() {
tracing::debug!("Filterered out all entries");
return None;
}
// Check if our batch is big enough
if let Some(min_size) = min_size {
// Batch is too small
if batch_requests.len() < min_size {
// Add back entries to the queue in the correct order
for r in batch_requests.into_iter().rev() {
let id = r.id;
let entry = batch_entries.remove(&id).unwrap();
self.entries.push_front((id, entry));
}
return None;
}
} }
// Final batch size // Final batch size

View File

@ -89,6 +89,8 @@ impl Allocator for RadixAllocator {
let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size; let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size;
tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");
match self.alloc_or_reclaim(suffix_blocks as usize) { match self.alloc_or_reclaim(suffix_blocks as usize) {
Some(suffix_blocks) => blocks.extend(suffix_blocks), Some(suffix_blocks) => blocks.extend(suffix_blocks),
None => { None => {