allocate 16 by 16

This commit is contained in:
OlivierDehaene 2024-06-12 18:53:14 +02:00
parent 9ac7b7bc52
commit 05eb4dcb17
3 changed files with 31 additions and 14 deletions

View File

@ -21,13 +21,28 @@ impl BlockAllocation {
&self.allocated_blocks
}
/// Extend an allocation by adding a new block
/// Extend an allocation by adding new blocks
/// If the allocation length > window size, repeats blocks and slots to cover the
/// whole `required_blocks` and `required_slots`
pub(crate) fn extend(&mut self) -> Result<(), AllocationError> {
let block = self.block_allocator.allocate_block()?;
let required_blocks = match self.block_allocator.window_size {
None => self.required_blocks,
Some(window_size) => min(
(window_size as usize + self.block_size - 1) / self.block_size,
self.required_blocks,
),
};
let remaining_blocks = required_blocks - self.allocated_blocks.len();
let new_blocks = min(remaining_blocks, 16);
// Try to allocate all remaining blocks
let blocks = match self.block_allocator.allocate_blocks(new_blocks) {
Ok(blocks) => blocks,
// Failed, try to allocate one block
Err(_) => self.block_allocator.allocate_blocks(1)?,
};
// Add block and slots to current allocation
self.allocated_blocks.push(block);
self.allocated_blocks.extend(blocks);
if let Some(window_size) = self.block_allocator.window_size {
// if we have more slots than the window size,
@ -87,14 +102,18 @@ impl BlockAllocator {
}
}
fn allocate_block(&self) -> Result<u32, AllocationError> {
fn allocate_blocks(&self, blocks: usize) -> Result<Vec<u32>, AllocationError> {
let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
if free_blocks.is_empty() {
if blocks > free_blocks.len() {
// Not enough blocks to cover this allocation
// Early return
return Err(AllocationError::NotEnoughPages);
}
Ok(free_blocks.pop().unwrap())
// Take the blocks
let n_free_blocks = free_blocks.len();
Ok(free_blocks.split_off(n_free_blocks - blocks))
}
/// For prompt tokens, we allocate enough blocks to cover all tokens

View File

@ -164,7 +164,8 @@ pub(crate) async fn batching_task(
};
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
let max_size = max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
let max_size =
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
// Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue

View File

@ -133,12 +133,9 @@ class FlashCausalLMBatch(Batch):
batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
max_truncation = max(max_truncation, r.truncate)
logger.error(batch_inputs)
batch_tokenized_inputs = tokenizer(
batch_inputs, truncation=True, max_length=max_truncation
)["input_ids"]
logger.error(batch_tokenized_inputs)
return batch_tokenized_inputs
@classmethod