small refactor
This commit is contained in:
parent
713d70b443
commit
6983ec9537
|
@ -1,44 +1,55 @@
|
|||
use std::cmp::min;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct BlockAllocation {
|
||||
pub blocks: Vec<u32>,
|
||||
pub slots: Vec<u32>,
|
||||
prompt_tokens: u32,
|
||||
decode_tokens: u32,
|
||||
allocated_blocks: Vec<u32>,
|
||||
allocated_slots: Vec<u32>,
|
||||
required_blocks: usize,
|
||||
required_slots: usize,
|
||||
block_allocator: BlockAllocator,
|
||||
}
|
||||
|
||||
impl BlockAllocation {
|
||||
pub(crate) fn len(&self) -> usize {
|
||||
self.slots.len()
|
||||
self.allocated_slots.len()
|
||||
}
|
||||
|
||||
pub(crate) fn blocks(&self) -> &[u32] {
|
||||
&self.allocated_blocks
|
||||
}
|
||||
|
||||
pub(crate) fn slots(&self) -> &[u32] {
|
||||
&self.allocated_slots
|
||||
}
|
||||
|
||||
pub(crate) fn extend(&mut self) -> Result<(), AllocationError> {
|
||||
let (block, slots) = self.block_allocator.allocate_block()?;
|
||||
// Add block and slots to current allocation
|
||||
self.allocated_blocks.push(block);
|
||||
self.allocated_slots.extend(slots);
|
||||
|
||||
match self.block_allocator.window_size {
|
||||
None => {
|
||||
self.blocks.push(block);
|
||||
self.slots.extend(slots);
|
||||
}
|
||||
Some(window_size) => {
|
||||
if self.len() as u32 > window_size {
|
||||
let total_tokens = self.prompt_tokens + self.decode_tokens;
|
||||
|
||||
let repeats = (total_tokens + window_size - 1) / window_size;
|
||||
}
|
||||
if let Some(window_size) = self.block_allocator.window_size {
|
||||
// if we have more slots than the window size,
|
||||
// we will never need to re-allocate and we can just repeat the blocks/slots
|
||||
let window_size = window_size as usize;
|
||||
if self.len() > window_size {
|
||||
let repeats = (self.required_slots + window_size - 1) / window_size;
|
||||
self.allocated_blocks = self.allocated_blocks.repeat(repeats);
|
||||
self.allocated_blocks.truncate(self.required_blocks);
|
||||
self.allocated_slots = self.allocated_slots.repeat(repeats);
|
||||
self.allocated_slots.truncate(self.required_slots);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for BlockAllocation {
|
||||
fn drop(&mut self) {
|
||||
self.block_allocator.free(self.blocks.clone())
|
||||
let allocated_blocks = std::mem::take(&mut self.allocated_blocks);
|
||||
self.block_allocator.free(allocated_blocks)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -82,85 +93,76 @@ impl BlockAllocator {
|
|||
/// For decode tokens, we allocate block by block
|
||||
///
|
||||
/// If prompt tokens + min(decode_tokens, block_size) > window size, we repeat blocks and slots
|
||||
fn allocate(
|
||||
&self,
|
||||
prompt_tokens: u32,
|
||||
decode_tokens: u32,
|
||||
) -> Result<(Vec<u32>, Vec<u32>), AllocationError> {
|
||||
// let decode_tokens = min(decode_tokens, self.block_size);
|
||||
// let tokens = prompt_tokens + decode_tokens;
|
||||
|
||||
let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size;
|
||||
// prompt blocks + a single block for decode
|
||||
let required_blocks = required_prompt_blocks + 1;
|
||||
|
||||
let (required_blocks, repeats) = match self.window_size {
|
||||
// Nothing to do
|
||||
None => (required_blocks, 1),
|
||||
Some(window_size) => {
|
||||
// Number of blocks needed for this window size
|
||||
let window_size_required_blocks = (window_size + self.block_size - 1) / self.block_size;
|
||||
// Number of times we will need to repeat blocks to cover the required allocation
|
||||
let repeats = (required_blocks + window_size_required_blocks -1) / window_size_required_blocks;
|
||||
let required_blocks = min(required_blocks, window_size_required_blocks);
|
||||
|
||||
(required_blocks, repeats)
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// if prompt + decode < window size => do nothing
|
||||
/// if prompt + decode > window size => do normal until we reach window size then
|
||||
|
||||
// Apply window size
|
||||
let (required_blocks, repeats) = {
|
||||
let (tokens, repeats) = match self.window_size {
|
||||
None => (tokens, 1),
|
||||
Some(window_size) => {
|
||||
let repeats = (tokens + window_size - 1) / window_size;
|
||||
let tokens = min(tokens, window_size);
|
||||
(tokens, repeats as usize)
|
||||
}
|
||||
};
|
||||
// Pad to a multiple of block size
|
||||
let required_blocks = (tokens + self.block_size - 1) / self.block_size;
|
||||
(required_blocks, repeats)
|
||||
};
|
||||
|
||||
let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
|
||||
|
||||
if required_blocks > free_blocks.len() as u32 {
|
||||
Err(AllocationError::NotEnoughPages)
|
||||
} else {
|
||||
let n_free_blocks = free_blocks.len();
|
||||
let blocks =
|
||||
free_blocks.split_off(n_free_blocks - required_blocks as usize);
|
||||
let mut slots = Vec::with_capacity(
|
||||
(required_blocks * self.block_size * repeats as u32) as usize,
|
||||
);
|
||||
|
||||
for block_id in blocks.repeat(repeats).iter() {
|
||||
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
|
||||
slots.push(s);
|
||||
}
|
||||
}
|
||||
Ok((blocks, slots))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn block_allocation(
|
||||
&self,
|
||||
prompt_tokens: u32,
|
||||
decode_tokens: u32,
|
||||
) -> Result<BlockAllocation, AllocationError> {
|
||||
self.allocate_inner(prompt_tokens, decode_tokens)
|
||||
.map(|(blocks, slots)| BlockAllocation {
|
||||
blocks,
|
||||
slots,
|
||||
prompt_tokens,
|
||||
decode_tokens,
|
||||
let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size;
|
||||
// prompt blocks + a single block for decode
|
||||
let required_blocks = required_prompt_blocks + 1;
|
||||
|
||||
let (clipped_required_blocks, repeats) = match self.window_size {
|
||||
// Nothing to do
|
||||
None => (required_blocks, 1),
|
||||
Some(window_size) => {
|
||||
// Number of blocks for this window size
|
||||
let window_size_blocks = (window_size + self.block_size - 1) / self.block_size;
|
||||
|
||||
if required_blocks > window_size_blocks {
|
||||
// Number of times we will need to repeat blocks to cover the required allocation
|
||||
let repeats = (required_blocks + window_size_blocks - 1) / window_size_blocks;
|
||||
(window_size_blocks, repeats)
|
||||
} else {
|
||||
(required_blocks, 1)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let repeats = repeats as usize;
|
||||
let required_blocks = required_blocks as usize;
|
||||
let clipped_required_blocks = clipped_required_blocks as usize;
|
||||
|
||||
let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
|
||||
|
||||
if clipped_required_blocks > free_blocks.len() {
|
||||
Err(AllocationError::NotEnoughPages)
|
||||
} else {
|
||||
let n_free_blocks = free_blocks.len();
|
||||
let allocated_blocks =
|
||||
free_blocks.split_off(n_free_blocks - clipped_required_blocks);
|
||||
|
||||
let allocated_blocks = if repeats != 1 {
|
||||
let mut allocated_blocks = allocated_blocks.repeat(repeats);
|
||||
allocated_blocks.truncate(required_blocks);
|
||||
allocated_blocks
|
||||
} else {
|
||||
allocated_blocks
|
||||
};
|
||||
|
||||
let mut allocated_slots = Vec::with_capacity(
|
||||
allocated_blocks.len() * self.block_size as usize * repeats,
|
||||
);
|
||||
|
||||
let required_slots = (prompt_tokens + decode_tokens) as usize;
|
||||
|
||||
'slots: for block_id in allocated_blocks.iter() {
|
||||
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
|
||||
allocated_slots.push(s);
|
||||
if allocated_slots.len() > required_slots {
|
||||
break 'slots;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(BlockAllocation {
|
||||
allocated_blocks,
|
||||
allocated_slots,
|
||||
required_blocks,
|
||||
required_slots,
|
||||
block_allocator: self.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn free(&self, blocks: Vec<u32>) {
|
||||
|
|
|
@ -283,7 +283,7 @@ impl State {
|
|||
let decode_tokens =
|
||||
entry.request.stopping_parameters.max_new_tokens + self.speculate;
|
||||
match block_allocator
|
||||
.allocate(entry.request.input_length, decode_tokens)
|
||||
.block_allocation(entry.request.input_length, decode_tokens)
|
||||
{
|
||||
Err(_) => {
|
||||
// Entry is over budget
|
||||
|
@ -294,7 +294,7 @@ impl State {
|
|||
}
|
||||
Ok(block_allocation) => {
|
||||
tracing::debug!("Allocation: {block_allocation:?}");
|
||||
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
|
||||
max_blocks = max(max_blocks, block_allocation.blocks().len() as u32);
|
||||
Some(block_allocation)
|
||||
}
|
||||
}
|
||||
|
@ -313,8 +313,8 @@ impl State {
|
|||
let (blocks, slots) = match &block_allocation {
|
||||
None => (Vec::new(), Vec::new()),
|
||||
Some(block_allocation) => (
|
||||
block_allocation.blocks.clone(),
|
||||
block_allocation.slots.clone(),
|
||||
block_allocation.blocks().to_vec(),
|
||||
block_allocation.slots().to_vec(),
|
||||
),
|
||||
};
|
||||
|
||||
|
|
|
@ -347,7 +347,7 @@ async fn filter_batch(
|
|||
let (blocks, slots) = entry
|
||||
.block_allocation
|
||||
.as_ref()
|
||||
.map(|alloc| (alloc.blocks.clone(), alloc.slots.clone()))
|
||||
.map(|alloc| (alloc.blocks().to_vec(), alloc.slots().to_vec()))
|
||||
.unwrap_or((Vec::new(), Vec::new()));
|
||||
|
||||
KeptRequest {
|
||||
|
|
Loading…
Reference in New Issue