diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs
index e19450e3..4a60dae6 100644
--- a/router/src/infer/v3/block_allocator.rs
+++ b/router/src/infer/v3/block_allocator.rs
@@ -21,13 +21,28 @@ impl BlockAllocation {
         &self.allocated_blocks
     }
-    /// Extend an allocation by adding a new block
+    /// Extend an allocation by adding new blocks
     /// If the allocation length > window size, repeats blocks and slots to cover the
     /// whole `required_blocks` and `required_slots`
     pub(crate) fn extend(&mut self) -> Result<(), AllocationError> {
-        let block = self.block_allocator.allocate_block()?;
+        let required_blocks = match self.block_allocator.window_size {
+            None => self.required_blocks,
+            Some(window_size) => min(
+                (window_size as usize + self.block_size - 1) / self.block_size,
+                self.required_blocks,
+            ),
+        };
+        let remaining_blocks = required_blocks - self.allocated_blocks.len();
+        let new_blocks = min(remaining_blocks, 16);
+
+        // Try to allocate all remaining blocks
+        let blocks = match self.block_allocator.allocate_blocks(new_blocks) {
+            Ok(blocks) => blocks,
+            // Failed, try to allocate one block
+            Err(_) => self.block_allocator.allocate_blocks(1)?,
+        };
 
         // Add block and slots to current allocation
-        self.allocated_blocks.push(block);
+        self.allocated_blocks.extend(blocks);
 
         if let Some(window_size) = self.block_allocator.window_size {
             // if we have more slots than the window size,
@@ -87,14 +102,18 @@ impl BlockAllocator {
         }
     }
 
-    fn allocate_block(&self) -> Result<u32, AllocationError> {
+    fn allocate_blocks(&self, blocks: usize) -> Result<Vec<u32>, AllocationError> {
        let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
 
-        if free_blocks.is_empty() {
+        if blocks > free_blocks.len() {
+            // Not enough blocks to cover this allocation
+            // Early return
             return Err(AllocationError::NotEnoughPages);
         }
 
-        Ok(free_blocks.pop().unwrap())
+        // Take the blocks
+        let n_free_blocks = free_blocks.len();
+        Ok(free_blocks.split_off(n_free_blocks - blocks))
     }
 
     /// For prompt tokens, we allocate enough blocks to cover all tokens
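To make the new growth policy concrete, here is a minimal, self-contained sketch of what the `extend`/`allocate_blocks` pair above implements: compute how many blocks the allocation still needs (capped by the attention window, then capped at 16), try to grab them from the free list in one shot, and fall back to a single block if the batch grab fails. The `FreeList` wrapper, `grow` function, and `Err(())` error are stand-ins invented for this example; the real allocator keeps its free list behind a `Mutex` and returns `AllocationError::NotEnoughPages`.

```rust
use std::cmp::min;

/// Simplified stand-in for `BlockAllocator::free_blocks`
/// (the real one lives behind a `Mutex`).
struct FreeList(Vec<u32>);

impl FreeList {
    /// Mirror of the new `allocate_blocks`: an all-or-nothing grab of
    /// `blocks` ids off the tail of the free list.
    fn allocate_blocks(&mut self, blocks: usize) -> Result<Vec<u32>, ()> {
        if blocks > self.0.len() {
            return Err(()); // NotEnoughPages in the real code
        }
        let n_free_blocks = self.0.len();
        Ok(self.0.split_off(n_free_blocks - blocks))
    }
}

/// Mirror of the new `extend` policy: request up to 16 blocks at once,
/// degrading to a single block when the batch grab fails.
fn grow(
    free: &mut FreeList,
    allocated: &mut Vec<u32>,
    required_blocks: usize,
    window_size: Option<usize>,
    block_size: usize,
) -> Result<(), ()> {
    // With a sliding window, we never need more blocks than the window covers.
    let required = match window_size {
        None => required_blocks,
        Some(w) => min((w + block_size - 1) / block_size, required_blocks),
    };
    let remaining_blocks = required - allocated.len();
    let new_blocks = min(remaining_blocks, 16);

    // Try the whole chunk first, then fall back to one block.
    let blocks = match free.allocate_blocks(new_blocks) {
        Ok(blocks) => blocks,
        Err(()) => free.allocate_blocks(1)?,
    };
    allocated.extend(blocks);
    Ok(())
}

fn main() {
    let mut free = FreeList((0..20).collect());
    let mut allocated = vec![];
    // 40 required blocks, no window: the first call grabs a chunk of 16.
    grow(&mut free, &mut allocated, 40, None, 16).unwrap();
    assert_eq!(allocated.len(), 16);
    // Only 4 blocks are left, so the batch grab of 16 fails and we get 1.
    grow(&mut free, &mut allocated, 40, None, 16).unwrap();
    assert_eq!(allocated.len(), 17);
    println!("allocated {} blocks", allocated.len());
}
```

Grabbing up to 16 blocks per call means `extend` touches the free-list lock roughly once every 16 blocks instead of once per block, presumably at the cost of reserving some slots a sequence may never use.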
diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs
index 50b33951..a901ba69 100644
--- a/router/src/infer/v3/scheduler.rs
+++ b/router/src/infer/v3/scheduler.rs
@@ -164,7 +164,8 @@ pub(crate) async fn batching_task(
                };
 
                let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
-                let max_size = max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
+                let max_size =
+                    max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
 
                // Try to get a new batch
                if let Some((mut new_entries, new_batch, span)) = queue
@@ -990,10 +991,10 @@ mod tests {
             content: "You are a friendly chatbot who always responds in the style of a pirate"
                 .to_string(),
         }]
-            .iter()
-            .chain(&example_chat)
-            .cloned()
-            .collect::<Vec<_>>();
+        .iter()
+        .chain(&example_chat)
+        .cloned()
+        .collect::<Vec<_>>();
 
         let test_default_templates = vec![
             ChatTemplateTestItem {
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 1bf9b7a5..47963aba 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -133,12 +133,9 @@ class FlashCausalLMBatch(Batch):
             batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
             max_truncation = max(max_truncation, r.truncate)
 
-        logger.error(batch_inputs)
-
         batch_tokenized_inputs = tokenizer(
             batch_inputs, truncation=True, max_length=max_truncation
         )["input_ids"]
-        logger.error(batch_tokenized_inputs)
         return batch_tokenized_inputs
 
     @classmethod
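One detail of the `allocate_blocks` change above that is easy to misread: `Vec::split_off(at)` leaves elements `[0, at)` in place and returns elements `[at, len)` as a new vector, so `split_off(n_free_blocks - blocks)` pops exactly `blocks` ids off the tail of the free list in one call. A standalone demonstration with made-up block ids:

```rust
fn main() {
    // Hypothetical free list; the real one holds u32 block ids behind a Mutex.
    let mut free_blocks: Vec<u32> = vec![10, 11, 12, 13, 14];

    // Take the last 2 blocks, exactly as `allocate_blocks` does.
    let blocks = 2;
    let n_free_blocks = free_blocks.len();
    let taken = free_blocks.split_off(n_free_blocks - blocks);

    assert_eq!(taken, vec![13, 14]); // handed to the allocation
    assert_eq!(free_blocks, vec![10, 11, 12]); // still free
    println!("took {:?}, {} blocks remain free", taken, free_blocks.len());
}
```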