From c2fb459bc1a3a207308243f5fcc32bf6781618d0 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 11 Jun 2024 18:40:38 +0200 Subject: [PATCH] fix windowing --- router/src/infer/v3/block_allocator.rs | 99 ++++++++++++++------------ router/src/infer/v3/scheduler.rs | 9 ++- 2 files changed, 61 insertions(+), 47 deletions(-) diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs index 563f173f..18480dbb 100644 --- a/router/src/infer/v3/block_allocator.rs +++ b/router/src/infer/v3/block_allocator.rs @@ -24,6 +24,9 @@ impl BlockAllocation { &self.allocated_slots } + /// Extend an allocation by adding a new block + /// If the allocation length > window size, repeats blocks and slots to cover the + /// whole `required_blocks` and `required_slots` pub(crate) fn extend(&mut self) -> Result<(), AllocationError> { let (block, slots) = self.block_allocator.allocate_block()?; // Add block and slots to current allocation @@ -48,6 +51,7 @@ impl BlockAllocation { } impl Drop for BlockAllocation { + /// Free the blocks fn drop(&mut self) { let allocated_blocks = std::mem::take(&mut self.allocated_blocks); self.block_allocator.free(allocated_blocks) @@ -114,66 +118,71 @@ impl BlockAllocator { let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size; // prompt blocks + a single block for decode let required_blocks = required_prompt_blocks + 1; + let required_slots = required_blocks * self.block_size; + + // Slots and blocks required for the whole request + let total_slots = prompt_tokens + decode_tokens; + let total_required_blocks = (total_slots + self.block_size - 1) / self.block_size; let (clipped_required_blocks, repeats) = match self.window_size { - // Nothing to do - None => (required_blocks, 1), - Some(window_size) => { + Some(window_size) if required_slots >= window_size => { // Number of blocks for this window size let window_size_blocks = (window_size + self.block_size - 1) / self.block_size; - - if required_blocks > window_size_blocks { - // Number of times we will need to repeat blocks to cover the required allocation - let repeats = (required_blocks + window_size_blocks - 1) / window_size_blocks; - (window_size_blocks, repeats) - } else { - (required_blocks, 1) - } + // Number of times we will need to repeat blocks to cover the total allocation + let repeats = (total_slots + window_size - 1) / window_size; + (window_size_blocks, repeats) } + // Nothing to do + _ => (required_blocks, 1), + }; + + // Scoped to drop the lock early + let allocated_blocks = { + let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired"); + let clipped_required_blocks = clipped_required_blocks as usize; + + if clipped_required_blocks > free_blocks.len() { + // Not enough blocks to cover this allocation + // Early return + return Err(AllocationError::NotEnoughPages); + } + + // Take the blocks + let n_free_blocks = free_blocks.len(); + free_blocks.split_off(n_free_blocks - clipped_required_blocks) }; let repeats = repeats as usize; - let required_blocks = required_blocks as usize; - let clipped_required_blocks = clipped_required_blocks as usize; + let total_slots = total_slots as usize; + let total_required_blocks = total_required_blocks as usize; - let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired"); - - if clipped_required_blocks > free_blocks.len() { - Err(AllocationError::NotEnoughPages) + let allocated_blocks = if repeats != 1 { + let mut allocated_blocks = allocated_blocks.repeat(repeats); + allocated_blocks.truncate(total_required_blocks); + allocated_blocks } else { - let n_free_blocks = free_blocks.len(); - let allocated_blocks = free_blocks.split_off(n_free_blocks - clipped_required_blocks); + allocated_blocks + }; - let allocated_blocks = if repeats != 1 { - let mut allocated_blocks = allocated_blocks.repeat(repeats); - allocated_blocks.truncate(required_blocks); - allocated_blocks - } else { - allocated_blocks - }; + let mut allocated_slots = + Vec::with_capacity(allocated_blocks.len() * self.block_size as usize * repeats); - let mut allocated_slots = - Vec::with_capacity(allocated_blocks.len() * self.block_size as usize * repeats); - - let required_slots = (prompt_tokens + decode_tokens) as usize; - - 'slots: for block_id in allocated_blocks.iter() { - for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { - allocated_slots.push(s); - if allocated_slots.len() > required_slots { - break 'slots; - } + 'slots: for block_id in allocated_blocks.iter() { + for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { + allocated_slots.push(s); + if allocated_slots.len() > total_slots { + break 'slots; } } - - Ok(BlockAllocation { - allocated_blocks, - allocated_slots, - required_blocks, - required_slots, - block_allocator: self.clone(), - }) } + + Ok(BlockAllocation { + allocated_blocks, + allocated_slots, + required_blocks: total_required_blocks, + required_slots: total_slots, + block_allocator: self.clone(), + }) } pub(crate) fn free(&self, blocks: Vec) { diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index c03328b2..6e5ffa7e 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -361,6 +361,7 @@ async fn decode( } /// Filter a `batch` and remove all requests not present in `entries` +/// Ask the server to generate the full texts for entries in `terminated_entries` #[instrument(skip_all)] async fn filter_batch( client: &mut ShardedClient, @@ -408,7 +409,10 @@ async fn filter_batch( } } -/// +/// Send `InferStreamResponse::Intermediate` and the final `InferStreamResponse::End` messages +/// to terminated requests +/// It modifies the last `InferStreamResponse::Intermediate` to add the final full text in +/// `terminated_generations` #[instrument(skip_all)] fn send_terminated_generations( terminated_generations: Vec, @@ -530,7 +534,7 @@ fn send_stream_responses( } /// Check if block allocations need to be extended -/// If we don't have enough blocks, request will be filtered with be added to an IntMap of +/// If we don't have enough blocks, request will be filtered and added to an IntMap of /// terminated entries. /// If at least one entry allocation was extended, we return true to force an update #[instrument(skip_all)] @@ -592,6 +596,7 @@ fn filter_send_update_allocations( } /// Map `Generation` to `<(bool, Vec<(u64, InferStreamResponse)>)>` +/// `bool` is `true` if the generation is finished fn map_generation(generation: Generation, entry: &Entry) -> (bool, Vec) { let mut finished = false; let mut stream_responses = Vec::with_capacity(16);