allocate 16 by 16

This commit is contained in:
OlivierDehaene 2024-06-12 18:53:14 +02:00
parent 9ac7b7bc52
commit 05eb4dcb17
3 changed files with 31 additions and 14 deletions

View File

@ -21,13 +21,28 @@ impl BlockAllocation {
&self.allocated_blocks &self.allocated_blocks
} }
/// Extend an allocation by adding a new block /// Extend an allocation by adding new blocks
/// If the allocation length > window size, repeats blocks and slots to cover the /// If the allocation length > window size, repeats blocks and slots to cover the
/// whole `required_blocks` and `required_slots` /// whole `required_blocks` and `required_slots`
pub(crate) fn extend(&mut self) -> Result<(), AllocationError> { pub(crate) fn extend(&mut self) -> Result<(), AllocationError> {
let block = self.block_allocator.allocate_block()?; let required_blocks = match self.block_allocator.window_size {
None => self.required_blocks,
Some(window_size) => min(
(window_size as usize + self.block_size - 1) / self.block_size,
self.required_blocks,
),
};
let remaining_blocks = required_blocks - self.allocated_blocks.len();
let new_blocks = min(remaining_blocks, 16);
// Try to allocate the next batch of remaining blocks (at most 16 per call)
let blocks = match self.block_allocator.allocate_blocks(new_blocks) {
Ok(blocks) => blocks,
// Failed, try to allocate one block
Err(_) => self.block_allocator.allocate_blocks(1)?,
};
// Add block and slots to current allocation // Add block and slots to current allocation
self.allocated_blocks.push(block); self.allocated_blocks.extend(blocks);
if let Some(window_size) = self.block_allocator.window_size { if let Some(window_size) = self.block_allocator.window_size {
// if we have more slots than the window size, // if we have more slots than the window size,
@ -87,14 +102,18 @@ impl BlockAllocator {
} }
} }
fn allocate_block(&self) -> Result<u32, AllocationError> { fn allocate_blocks(&self, blocks: usize) -> Result<Vec<u32>, AllocationError> {
let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired"); let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
if free_blocks.is_empty() { if blocks > free_blocks.len() {
// Not enough blocks to cover this allocation
// Early return
return Err(AllocationError::NotEnoughPages); return Err(AllocationError::NotEnoughPages);
} }
Ok(free_blocks.pop().unwrap()) // Take the blocks
let n_free_blocks = free_blocks.len();
Ok(free_blocks.split_off(n_free_blocks - blocks))
} }
/// For prompt tokens, we allocate enough blocks to cover all tokens /// For prompt tokens, we allocate enough blocks to cover all tokens

View File

@ -164,7 +164,8 @@ pub(crate) async fn batching_task(
}; };
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
let max_size = max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize)); let max_size =
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
// Try to get a new batch // Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue if let Some((mut new_entries, new_batch, span)) = queue
@ -990,10 +991,10 @@ mod tests {
content: "You are a friendly chatbot who always responds in the style of a pirate" content: "You are a friendly chatbot who always responds in the style of a pirate"
.to_string(), .to_string(),
}] }]
.iter() .iter()
.chain(&example_chat) .chain(&example_chat)
.cloned() .cloned()
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let test_default_templates = vec![ let test_default_templates = vec![
ChatTemplateTestItem { ChatTemplateTestItem {

View File

@ -133,12 +133,9 @@ class FlashCausalLMBatch(Batch):
batch_inputs.append(concat_text_chunks(r.input_chunks.chunks)) batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
max_truncation = max(max_truncation, r.truncate) max_truncation = max(max_truncation, r.truncate)
logger.error(batch_inputs)
batch_tokenized_inputs = tokenizer( batch_tokenized_inputs = tokenizer(
batch_inputs, truncation=True, max_length=max_truncation batch_inputs, truncation=True, max_length=max_truncation
)["input_ids"] )["input_ids"]
logger.error(batch_tokenized_inputs)
return batch_tokenized_inputs return batch_tokenized_inputs
@classmethod @classmethod