diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index a47e62dc..935f7980 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -122,7 +122,7 @@ impl Backend for BackendV3 { #[allow(clippy::too_many_arguments)] pub(crate) async fn batching_task( mut client: ShardedClient, - _waiting_served_ratio: f32, + waiting_served_ratio: f32, max_batch_prefill_tokens: u32, max_batch_total_tokens: u32, max_waiting_tokens: usize, @@ -170,8 +170,7 @@ pub(crate) async fn batching_task( // Minimum batch size // TODO: temporarily disable to avoid incorrect deallocation + // reallocation when using prefix caching. - // Some((batch_size as f32 * waiting_served_ratio).floor() as usize) - None + Some((batch_size as f32 * waiting_served_ratio).floor() as usize) }; let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index 2a8c4c53..978a495c 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -252,17 +252,14 @@ impl State { let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); next_batch_span.follows_from(Span::current()); - let mut batch_requests = Vec::with_capacity(self.entries.len()); - let mut batch_entries = - IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default()); - + let mut batch = Vec::with_capacity(self.entries.len()); let mut max_input_length = 0; let mut prefill_tokens: u32 = 0; let mut decode_tokens: u32 = 0; let mut max_blocks = 0; // Pop entries starting from the front of the queue - 'entry_loop: while let Some((id, mut entry)) = self.entries.pop_front() { + 'entry_loop: while let Some((id, entry)) = self.entries.pop_front() { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) if entry.response_tx.is_closed() { @@ -276,7 +273,7 @@ impl State { // We pad to max input length in the Python shards // We need to take these padding tokens into the equation max_input_length = max_input_length.max(entry.request.input_length); - prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length; + prefill_tokens = (batch.len() + 1) as u32 * max_input_length; decode_tokens += entry.request.stopping_parameters.max_new_tokens; let total_tokens = prefill_tokens + decode_tokens + self.speculate; @@ -290,7 +287,7 @@ impl State { } None } - Some(block_allocator) => { + Some(_block_allocator) => { prefill_tokens += entry.request.input_length; let max_new_tokens = match self.window_size { None => entry.request.stopping_parameters.max_new_tokens, @@ -324,23 +321,59 @@ impl State { entry.request.input_ids.clone() }; - match block_allocator.allocate(tokens, input_ids).await { - None => { - // Entry is over budget - // Add it back to the front - tracing::debug!("Over budget: not enough free blocks"); - self.entries.push_front((id, entry)); - break 'entry_loop; - } - Some(block_allocation) => { - tracing::debug!("Allocation: {block_allocation:?}"); - max_blocks = max(max_blocks, block_allocation.blocks.len() as u32); - Some(block_allocation) - } - } + Some((tokens, input_ids)) } }; + batch.push((id, entry, block_allocation)); + if Some(batch.len()) == max_size { + break; + } + } + // Empty batch + if batch.is_empty() { + tracing::debug!("Filterered out all entries"); + return None; + } + + // XXX We haven't allocated yet, so we're allowed to ditch the results. + // Check if our batch is big enough + if let Some(min_size) = min_size { + // Batch is too small + if batch.len() < min_size { + // Add back entries to the queue in the correct order + for (id, entry, _) in batch.into_iter().rev() { + self.entries.push_front((id, entry)); + } + return None; + } + } + + let mut batch_requests = Vec::with_capacity(self.entries.len()); + let mut batch_entries = + IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default()); + + for (id, mut entry, block_allocation) in batch { + let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) = + (block_allocation, &self.block_allocator) + { + match block_allocator.allocate(tokens, input_ids).await { + None => { + // Entry is over budget + // Add it back to the front + tracing::debug!("Over budget: not enough free blocks"); + self.entries.push_front((id, entry)); + break; + } + Some(block_allocation) => { + tracing::debug!("Allocation: {block_allocation:?}"); + max_blocks = max(max_blocks, block_allocation.blocks.len() as u32); + Some(block_allocation) + } + } + } else { + None + }; tracing::debug!("Accepting entry"); // Create a new span to link the batch back to this entry let entry_batch_span = info_span!(parent: &entry.span, "infer"); @@ -400,32 +433,6 @@ impl State { entry.batch_time = Some(Instant::now()); // Insert in batch_entries IntMap batch_entries.insert(id, entry); - - // Check if max_size - if Some(batch_requests.len()) == max_size { - break; - } - } - - // Empty batch - if batch_requests.is_empty() { - tracing::debug!("Filterered out all entries"); - return None; - } - - // Check if our batch is big enough - if let Some(min_size) = min_size { - // Batch is too small - if batch_requests.len() < min_size { - // Add back entries to the queue in the correct order - for r in batch_requests.into_iter().rev() { - let id = r.id; - let entry = batch_entries.remove(&id).unwrap(); - self.entries.push_front((id, entry)); - } - - return None; - } } // Final batch size diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index bb6582b0..1f3bef15 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -89,6 +89,8 @@ impl Allocator for RadixAllocator { let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size; + tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}"); + match self.alloc_or_reclaim(suffix_blocks as usize) { Some(suffix_blocks) => blocks.extend(suffix_blocks), None => {