re-working logic, wip

This commit is contained in:
OlivierDehaene 2024-06-07 13:39:42 +02:00
parent 298bf31e69
commit 713d70b443
3 changed files with 113 additions and 130 deletions

View File

@ -1,6 +1,6 @@
use std::cmp::min;
use std::sync::{Arc, Mutex};
use thiserror::Error;
use tokio::sync::{mpsc, oneshot};
#[derive(Debug, Clone)]
pub(crate) struct BlockAllocation {
@ -16,13 +16,23 @@ impl BlockAllocation {
self.slots.len()
}
pub(crate) async fn extend(&mut self, cache_length: u32) -> Result<(), AllocationError> {
let remaining_tokens =
(self.prompt_tokens + self.decode_tokens).saturating_sub(cache_length);
self.block_allocator
.clone()
.extend(self, remaining_tokens)
.await
pub(crate) fn extend(&mut self) -> Result<(), AllocationError> {
let (block, slots) = self.block_allocator.allocate_block()?;
match self.block_allocator.window_size {
None => {
self.blocks.push(block);
self.slots.extend(slots);
}
Some(window_size) => {
if self.len() as u32 > window_size {
let total_tokens = self.prompt_tokens + self.decode_tokens;
let repeats = (total_tokens + window_size - 1) / window_size;
}
}
}
Ok(())
}
}
@ -34,8 +44,9 @@ impl Drop for BlockAllocation {
#[derive(Debug, Clone)]
pub(crate) struct BlockAllocator {
/// Channel to communicate with the background task
block_allocator: mpsc::UnboundedSender<BlockAllocatorCommand>,
free_blocks: Arc<Mutex<Vec<u32>>>,
block_size: u32,
window_size: Option<u32>,
}
impl BlockAllocator {
@ -44,39 +55,105 @@ impl BlockAllocator {
block_size: u32,
window_size: Option<u32>,
) -> Self {
// Create channel
let (sender, receiver) = mpsc::unbounded_channel();
// Launch background queue task
tokio::spawn(block_allocator_task(
max_batch_total_tokens / block_size,
block_size,
window_size,
receiver,
));
let blocks = max_batch_total_tokens / block_size;
// Block 0 is reserved for health checks
let free_blocks: Vec<u32> = (1..blocks).collect();
Self {
block_allocator: sender,
free_blocks: Arc::new(Mutex::new(free_blocks)),
block_size,
window_size,
}
}
pub(crate) async fn allocate(
fn allocate_block(&self) -> Result<(u32, Vec<u32>), AllocationError> {
let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
if free_blocks.is_empty() {
return Err(AllocationError::NotEnoughPages);
}
let block_id = free_blocks.pop().unwrap();
let slots = ((block_id * self.block_size)..((block_id + 1) * self.block_size)).collect();
Ok((block_id, slots))
}
/// For prompt tokens, we allocate enough blocks to cover all tokens
/// For decode tokens, we allocate block by block
///
/// If prompt tokens + min(decode_tokens, block_size) > window size, we repeat blocks and slots
fn allocate(
&self,
prompt_tokens: u32,
decode_tokens: u32,
) -> Result<(Vec<u32>, Vec<u32>), AllocationError> {
// let decode_tokens = min(decode_tokens, self.block_size);
// let tokens = prompt_tokens + decode_tokens;
let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size;
// prompt blocks + a single block for decode
let required_blocks = required_prompt_blocks + 1;
let (required_blocks, repeats) = match self.window_size {
// Nothing to do
None => (required_blocks, 1),
Some(window_size) => {
// Number of blocks needed for this window size
let window_size_required_blocks = (window_size + self.block_size - 1) / self.block_size;
// Number of times we will need to repeat blocks to cover the required allocation
let repeats = (required_blocks + window_size_required_blocks -1) / window_size_required_blocks;
let required_blocks = min(required_blocks, window_size_required_blocks);
(required_blocks, repeats)
}
};
/// if prompt + decode < window size => do nothing
/// if prompt + decode > window size => do normal until we reach window size then
// Apply window size
let (required_blocks, repeats) = {
let (tokens, repeats) = match self.window_size {
None => (tokens, 1),
Some(window_size) => {
let repeats = (tokens + window_size - 1) / window_size;
let tokens = min(tokens, window_size);
(tokens, repeats as usize)
}
};
// Pad to a multiple of block size
let required_blocks = (tokens + self.block_size - 1) / self.block_size;
(required_blocks, repeats)
};
let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
if required_blocks > free_blocks.len() as u32 {
Err(AllocationError::NotEnoughPages)
} else {
let n_free_blocks = free_blocks.len();
let blocks =
free_blocks.split_off(n_free_blocks - required_blocks as usize);
let mut slots = Vec::with_capacity(
(required_blocks * self.block_size * repeats as u32) as usize,
);
for block_id in blocks.repeat(repeats).iter() {
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
slots.push(s);
}
}
Ok((blocks, slots))
}
}
pub(crate) fn block_allocation(
&self,
prompt_tokens: u32,
decode_tokens: u32,
) -> Result<BlockAllocation, AllocationError> {
let (response_sender, response_receiver) = oneshot::channel();
self.block_allocator
.send(BlockAllocatorCommand::Allocate {
prompt_tokens,
decode_tokens,
response_sender,
})
.unwrap();
response_receiver
.await
.unwrap()
self.allocate_inner(prompt_tokens, decode_tokens)
.map(|(blocks, slots)| BlockAllocation {
blocks,
slots,
@ -86,103 +163,11 @@ impl BlockAllocator {
})
}
pub(crate) async fn extend(
&self,
block_allocation: &mut BlockAllocation,
tokens: u32,
) -> Result<(), AllocationError> {
let (response_sender, response_receiver) = oneshot::channel();
self.block_allocator
.send(BlockAllocatorCommand::Allocate {
prompt_tokens: 0,
decode_tokens: tokens,
response_sender,
})
.unwrap();
let (blocks, slots) = response_receiver.await.unwrap()?;
block_allocation.blocks.extend(blocks);
block_allocation.slots.extend(slots);
Ok(())
}
pub(crate) fn free(&self, blocks: Vec<u32>) {
self.block_allocator
.send(BlockAllocatorCommand::Free { blocks })
.unwrap();
self.free_blocks.lock().expect("Lock could not be acquired. This is a bug.").extend(blocks)
}
}
async fn block_allocator_task(
blocks: u32,
block_size: u32,
window_size: Option<u32>,
mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
) {
// Block 0 is reserved for health checks
let mut free_blocks: Vec<u32> = (1..blocks).collect();
while let Some(cmd) = receiver.recv().await {
match cmd {
BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
BlockAllocatorCommand::Allocate {
prompt_tokens,
decode_tokens,
response_sender,
} => {
let decode_tokens = min(decode_tokens, block_size);
let tokens = prompt_tokens + decode_tokens;
// FIXME: window size is not working
// Apply window size
let (required_blocks, repeats) = {
let (tokens, repeats) = match window_size {
None => (tokens, 1),
Some(window_size) => {
let repeats = (tokens + window_size - 1) / window_size;
let tokens = min(tokens, window_size);
(tokens, repeats as usize)
}
};
// Pad to a multiple of block size
let required_blocks = (tokens + block_size - 1) / block_size;
(required_blocks, repeats)
};
let allocation = if required_blocks > free_blocks.len() as u32 {
Err(AllocationError::NotEnoughPages)
} else {
let blocks =
free_blocks.split_off(free_blocks.len() - required_blocks as usize);
let mut slots = Vec::with_capacity(
(required_blocks * block_size * repeats as u32) as usize,
);
for block_id in blocks.repeat(repeats).iter() {
for s in (block_id * block_size)..((block_id + 1) * block_size) {
slots.push(s);
}
}
Ok((blocks, slots))
};
response_sender.send(allocation).unwrap();
}
}
}
}
#[derive(Debug)]
enum BlockAllocatorCommand {
Free {
blocks: Vec<u32>,
},
Allocate {
prompt_tokens: u32,
decode_tokens: u32,
#[allow(clippy::type_complexity)]
response_sender: oneshot::Sender<Result<(Vec<u32>, Vec<u32>), AllocationError>>,
},
}
#[derive(Error, Debug)]
pub enum AllocationError {
#[error("Not enough pages")]

View File

@ -284,7 +284,6 @@ impl State {
entry.request.stopping_parameters.max_new_tokens + self.speculate;
match block_allocator
.allocate(entry.request.input_length, decode_tokens)
.await
{
Err(_) => {
// Entry is over budget

View File

@ -428,8 +428,7 @@ async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) -> bool {
.block_allocation
.as_mut()
.expect("We checked that the block allocation exists above")
.extend(entry.cache_length)
.await
.extend()
};
if extension.is_err() {