allocate 16 by 16
This commit is contained in:
parent
9ac7b7bc52
commit
05eb4dcb17
|
@ -21,13 +21,28 @@ impl BlockAllocation {
|
||||||
&self.allocated_blocks
|
&self.allocated_blocks
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Extend an allocation by adding a new block
|
/// Extend an allocation by adding new blocks
|
||||||
/// If the allocation length > window size, repeats blocks and slots to cover the
|
/// If the allocation length > window size, repeats blocks and slots to cover the
|
||||||
/// whole `required_blocks` and `required_slots`
|
/// whole `required_blocks` and `required_slots`
|
||||||
pub(crate) fn extend(&mut self) -> Result<(), AllocationError> {
|
pub(crate) fn extend(&mut self) -> Result<(), AllocationError> {
|
||||||
let block = self.block_allocator.allocate_block()?;
|
let required_blocks = match self.block_allocator.window_size {
|
||||||
|
None => self.required_blocks,
|
||||||
|
Some(window_size) => min(
|
||||||
|
(window_size as usize + self.block_size - 1) / self.block_size,
|
||||||
|
self.required_blocks,
|
||||||
|
),
|
||||||
|
};
|
||||||
|
let remaining_blocks = required_blocks - self.allocated_blocks.len();
|
||||||
|
let new_blocks = min(remaining_blocks, 16);
|
||||||
|
|
||||||
|
// Try to allocate the remaining blocks, capped at 16 per call
|
||||||
|
let blocks = match self.block_allocator.allocate_blocks(new_blocks) {
|
||||||
|
Ok(blocks) => blocks,
|
||||||
|
// Failed, try to allocate one block
|
||||||
|
Err(_) => self.block_allocator.allocate_blocks(1)?,
|
||||||
|
};
|
||||||
// Add block and slots to current allocation
|
// Add blocks and slots to current allocation
|
||||||
self.allocated_blocks.push(block);
|
self.allocated_blocks.extend(blocks);
|
||||||
|
|
||||||
if let Some(window_size) = self.block_allocator.window_size {
|
if let Some(window_size) = self.block_allocator.window_size {
|
||||||
// if we have more slots than the window size,
|
// if we have more slots than the window size,
|
||||||
|
@ -87,14 +102,18 @@ impl BlockAllocator {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn allocate_block(&self) -> Result<u32, AllocationError> {
|
fn allocate_blocks(&self, blocks: usize) -> Result<Vec<u32>, AllocationError> {
|
||||||
let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
|
let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
|
||||||
|
|
||||||
if free_blocks.is_empty() {
|
if blocks > free_blocks.len() {
|
||||||
|
// Not enough blocks to cover this allocation
|
||||||
|
// Early return
|
||||||
return Err(AllocationError::NotEnoughPages);
|
return Err(AllocationError::NotEnoughPages);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(free_blocks.pop().unwrap())
|
// Take the blocks
|
||||||
|
let n_free_blocks = free_blocks.len();
|
||||||
|
Ok(free_blocks.split_off(n_free_blocks - blocks))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// For prompt tokens, we allocate enough blocks to cover all tokens
|
/// For prompt tokens, we allocate enough blocks to cover all tokens
|
||||||
|
|
|
@ -164,7 +164,8 @@ pub(crate) async fn batching_task(
|
||||||
};
|
};
|
||||||
|
|
||||||
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||||
let max_size = max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
|
let max_size =
|
||||||
|
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
|
||||||
|
|
||||||
// Try to get a new batch
|
// Try to get a new batch
|
||||||
if let Some((mut new_entries, new_batch, span)) = queue
|
if let Some((mut new_entries, new_batch, span)) = queue
|
||||||
|
|
|
@ -133,12 +133,9 @@ class FlashCausalLMBatch(Batch):
|
||||||
batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
|
batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
|
||||||
max_truncation = max(max_truncation, r.truncate)
|
max_truncation = max(max_truncation, r.truncate)
|
||||||
|
|
||||||
logger.error(batch_inputs)
|
|
||||||
|
|
||||||
batch_tokenized_inputs = tokenizer(
|
batch_tokenized_inputs = tokenizer(
|
||||||
batch_inputs, truncation=True, max_length=max_truncation
|
batch_inputs, truncation=True, max_length=max_truncation
|
||||||
)["input_ids"]
|
)["input_ids"]
|
||||||
logger.error(batch_tokenized_inputs)
|
|
||||||
return batch_tokenized_inputs
|
return batch_tokenized_inputs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
Loading…
Reference in New Issue