small refactor

This commit is contained in:
OlivierDehaene 2024-06-10 11:44:50 +02:00
parent 713d70b443
commit 6983ec9537
3 changed files with 97 additions and 95 deletions

View File

@ -1,44 +1,55 @@
use std::cmp::min;
use std::sync::{Arc, Mutex};
use thiserror::Error;
#[derive(Debug, Clone)]
pub(crate) struct BlockAllocation {
pub blocks: Vec<u32>,
pub slots: Vec<u32>,
prompt_tokens: u32,
decode_tokens: u32,
allocated_blocks: Vec<u32>,
allocated_slots: Vec<u32>,
required_blocks: usize,
required_slots: usize,
block_allocator: BlockAllocator,
}
impl BlockAllocation {
pub(crate) fn len(&self) -> usize {
self.slots.len()
self.allocated_slots.len()
}
pub(crate) fn blocks(&self) -> &[u32] {
&self.allocated_blocks
}
pub(crate) fn slots(&self) -> &[u32] {
&self.allocated_slots
}
pub(crate) fn extend(&mut self) -> Result<(), AllocationError> {
let (block, slots) = self.block_allocator.allocate_block()?;
// Add block and slots to current allocation
self.allocated_blocks.push(block);
self.allocated_slots.extend(slots);
match self.block_allocator.window_size {
None => {
self.blocks.push(block);
self.slots.extend(slots);
}
Some(window_size) => {
if self.len() as u32 > window_size {
let total_tokens = self.prompt_tokens + self.decode_tokens;
let repeats = (total_tokens + window_size - 1) / window_size;
}
if let Some(window_size) = self.block_allocator.window_size {
// if we have more slots than the window size,
// we will never need to re-allocate and we can just repeat the blocks/slots
let window_size = window_size as usize;
if self.len() > window_size {
let repeats = (self.required_slots + window_size - 1) / window_size;
self.allocated_blocks = self.allocated_blocks.repeat(repeats);
self.allocated_blocks.truncate(self.required_blocks);
self.allocated_slots = self.allocated_slots.repeat(repeats);
self.allocated_slots.truncate(self.required_slots);
}
}
Ok(())
}
}
impl Drop for BlockAllocation {
fn drop(&mut self) {
self.block_allocator.free(self.blocks.clone())
let allocated_blocks = std::mem::take(&mut self.allocated_blocks);
self.block_allocator.free(allocated_blocks)
}
}
@ -82,85 +93,76 @@ impl BlockAllocator {
/// For decode tokens, we allocate block by block
///
/// If prompt tokens + min(decode_tokens, block_size) > window size, we repeat blocks and slots
fn allocate(
&self,
prompt_tokens: u32,
decode_tokens: u32,
) -> Result<(Vec<u32>, Vec<u32>), AllocationError> {
// let decode_tokens = min(decode_tokens, self.block_size);
// let tokens = prompt_tokens + decode_tokens;
let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size;
// prompt blocks + a single block for decode
let required_blocks = required_prompt_blocks + 1;
let (required_blocks, repeats) = match self.window_size {
// Nothing to do
None => (required_blocks, 1),
Some(window_size) => {
// Number of blocks needed for this window size
let window_size_required_blocks = (window_size + self.block_size - 1) / self.block_size;
// Number of times we will need to repeat blocks to cover the required allocation
let repeats = (required_blocks + window_size_required_blocks -1) / window_size_required_blocks;
let required_blocks = min(required_blocks, window_size_required_blocks);
(required_blocks, repeats)
}
};
/// if prompt + decode < window size => do nothing
/// if prompt + decode > window size => do normal until we reach window size then
// Apply window size
let (required_blocks, repeats) = {
let (tokens, repeats) = match self.window_size {
None => (tokens, 1),
Some(window_size) => {
let repeats = (tokens + window_size - 1) / window_size;
let tokens = min(tokens, window_size);
(tokens, repeats as usize)
}
};
// Pad to a multiple of block size
let required_blocks = (tokens + self.block_size - 1) / self.block_size;
(required_blocks, repeats)
};
let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
if required_blocks > free_blocks.len() as u32 {
Err(AllocationError::NotEnoughPages)
} else {
let n_free_blocks = free_blocks.len();
let blocks =
free_blocks.split_off(n_free_blocks - required_blocks as usize);
let mut slots = Vec::with_capacity(
(required_blocks * self.block_size * repeats as u32) as usize,
);
for block_id in blocks.repeat(repeats).iter() {
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
slots.push(s);
}
}
Ok((blocks, slots))
}
}
pub(crate) fn block_allocation(
&self,
prompt_tokens: u32,
decode_tokens: u32,
) -> Result<BlockAllocation, AllocationError> {
self.allocate_inner(prompt_tokens, decode_tokens)
.map(|(blocks, slots)| BlockAllocation {
blocks,
slots,
prompt_tokens,
decode_tokens,
let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size;
// prompt blocks + a single block for decode
let required_blocks = required_prompt_blocks + 1;
let (clipped_required_blocks, repeats) = match self.window_size {
// Nothing to do
None => (required_blocks, 1),
Some(window_size) => {
// Number of blocks for this window size
let window_size_blocks = (window_size + self.block_size - 1) / self.block_size;
if required_blocks > window_size_blocks {
// Number of times we will need to repeat blocks to cover the required allocation
let repeats = (required_blocks + window_size_blocks - 1) / window_size_blocks;
(window_size_blocks, repeats)
} else {
(required_blocks, 1)
}
}
};
let repeats = repeats as usize;
let required_blocks = required_blocks as usize;
let clipped_required_blocks = clipped_required_blocks as usize;
let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
if clipped_required_blocks > free_blocks.len() {
Err(AllocationError::NotEnoughPages)
} else {
let n_free_blocks = free_blocks.len();
let allocated_blocks =
free_blocks.split_off(n_free_blocks - clipped_required_blocks);
let allocated_blocks = if repeats != 1 {
let mut allocated_blocks = allocated_blocks.repeat(repeats);
allocated_blocks.truncate(required_blocks);
allocated_blocks
} else {
allocated_blocks
};
let mut allocated_slots = Vec::with_capacity(
allocated_blocks.len() * self.block_size as usize * repeats,
);
let required_slots = (prompt_tokens + decode_tokens) as usize;
'slots: for block_id in allocated_blocks.iter() {
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
allocated_slots.push(s);
if allocated_slots.len() > required_slots {
break 'slots;
}
}
}
Ok(BlockAllocation {
allocated_blocks,
allocated_slots,
required_blocks,
required_slots,
block_allocator: self.clone(),
})
}
}
pub(crate) fn free(&self, blocks: Vec<u32>) {

View File

@ -283,7 +283,7 @@ impl State {
let decode_tokens =
entry.request.stopping_parameters.max_new_tokens + self.speculate;
match block_allocator
.allocate(entry.request.input_length, decode_tokens)
.block_allocation(entry.request.input_length, decode_tokens)
{
Err(_) => {
// Entry is over budget
@ -294,7 +294,7 @@ impl State {
}
Ok(block_allocation) => {
tracing::debug!("Allocation: {block_allocation:?}");
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
max_blocks = max(max_blocks, block_allocation.blocks().len() as u32);
Some(block_allocation)
}
}
@ -313,8 +313,8 @@ impl State {
let (blocks, slots) = match &block_allocation {
None => (Vec::new(), Vec::new()),
Some(block_allocation) => (
block_allocation.blocks.clone(),
block_allocation.slots.clone(),
block_allocation.blocks().to_vec(),
block_allocation.slots().to_vec(),
),
};

View File

@ -347,7 +347,7 @@ async fn filter_batch(
let (blocks, slots) = entry
.block_allocation
.as_ref()
.map(|alloc| (alloc.blocks.clone(), alloc.slots.clone()))
.map(|alloc| (alloc.blocks().to_vec(), alloc.slots().to_vec()))
.unwrap_or((Vec::new(), Vec::new()));
KeptRequest {