re-working logic, wip

This commit is contained in:
parent 298bf31e69
commit 713d70b443
@@ -1,6 +1,6 @@
 use std::cmp::min;
+use std::sync::{Arc, Mutex};
 use thiserror::Error;
-use tokio::sync::{mpsc, oneshot};

 #[derive(Debug, Clone)]
 pub(crate) struct BlockAllocation {
@@ -16,13 +16,23 @@ impl BlockAllocation {
         self.slots.len()
     }

-    pub(crate) async fn extend(&mut self, cache_length: u32) -> Result<(), AllocationError> {
-        let remaining_tokens =
-            (self.prompt_tokens + self.decode_tokens).saturating_sub(cache_length);
-        self.block_allocator
-            .clone()
-            .extend(self, remaining_tokens)
-            .await
+    pub(crate) fn extend(&mut self) -> Result<(), AllocationError> {
+        let (block, slots) = self.block_allocator.allocate_block()?;
+
+        match self.block_allocator.window_size {
+            None => {
+                self.blocks.push(block);
+                self.slots.extend(slots);
+            }
+            Some(window_size) => {
+                if self.len() as u32 > window_size {
+                    let total_tokens = self.prompt_tokens + self.decode_tokens;
+
+                    let repeats = (total_tokens + window_size - 1) / window_size;
+                }
+            }
+        }
+        Ok(())
     }
 }

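The `repeats` expression introduced above is a ceiling division (in this wip state it is computed but not yet applied). A minimal spot-check of the formula in isolation; the free function and the numbers are illustrative, not part of the commit:

    /// Ceiling division: how many `window_size`-token windows are needed
    /// to cover `total_tokens`, i.e. ceil(total_tokens / window_size).
    fn repeats(total_tokens: u32, window_size: u32) -> u32 {
        (total_tokens + window_size - 1) / window_size
    }

    fn main() {
        // 300 tokens over a 256-token window: ceil(300 / 256) = 2 repeats.
        assert_eq!(repeats(300, 256), 2);
        // Exactly one window when the tokens fit inside it.
        assert_eq!(repeats(256, 256), 1);
    }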
@@ -34,8 +44,9 @@ impl Drop for BlockAllocation {

 #[derive(Debug, Clone)]
 pub(crate) struct BlockAllocator {
-    /// Channel to communicate with the background task
-    block_allocator: mpsc::UnboundedSender<BlockAllocatorCommand>,
+    free_blocks: Arc<Mutex<Vec<u32>>>,
+    block_size: u32,
+    window_size: Option<u32>,
 }

 impl BlockAllocator {
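The allocator's state now lives on the struct itself instead of behind a channel. A minimal model of the shared free list, assuming the same `Arc<Mutex<Vec<u32>>>` layout the commit uses (block 0 reserved for health checks; the rest is illustrative):

    use std::sync::{Arc, Mutex};

    fn main() {
        // 8 blocks total; block 0 stays reserved for health checks.
        let free_blocks: Arc<Mutex<Vec<u32>>> = Arc::new(Mutex::new((1..8).collect()));

        // Allocate: pop a free block id under the lock, as in allocate_block.
        let block = free_blocks.lock().expect("Lock could not be acquired").pop();

        // Free: push ids back under the lock, as in BlockAllocator::free.
        if let Some(block_id) = block {
            free_blocks.lock().expect("Lock could not be acquired").extend([block_id]);
        }
    }

Since every caller mutates the free list in place under the lock, allocation and freeing no longer require a background task or `async` methods.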
@@ -44,39 +55,105 @@ impl BlockAllocator {
         block_size: u32,
         window_size: Option<u32>,
     ) -> Self {
-        // Create channel
-        let (sender, receiver) = mpsc::unbounded_channel();
-
-        // Launch background queue task
-        tokio::spawn(block_allocator_task(
-            max_batch_total_tokens / block_size,
-            block_size,
-            window_size,
-            receiver,
-        ));
+        let blocks = max_batch_total_tokens / block_size;
+        // Block 0 is reserved for health checks
+        let free_blocks: Vec<u32> = (1..blocks).collect();

         Self {
-            block_allocator: sender,
+            free_blocks: Arc::new(Mutex::new(free_blocks)),
+            block_size,
+            window_size,
         }
     }

-    pub(crate) async fn allocate(
+    fn allocate_block(&self) -> Result<(u32, Vec<u32>), AllocationError> {
+        let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
+
+        if free_blocks.is_empty() {
+            return Err(AllocationError::NotEnoughPages);
+        }
+
+        let block_id = free_blocks.pop().unwrap();
+        let slots = ((block_id * self.block_size)..((block_id + 1) * self.block_size)).collect();
+        Ok((block_id, slots))
+    }
+
+    /// For prompt tokens, we allocate enough blocks to cover all tokens
+    /// For decode tokens, we allocate block by block
+    ///
+    /// If prompt tokens + min(decode_tokens, block_size) > window size, we repeat blocks and slots
+    fn allocate(
+        &self,
+        prompt_tokens: u32,
+        decode_tokens: u32,
+    ) -> Result<(Vec<u32>, Vec<u32>), AllocationError> {
+        // let decode_tokens = min(decode_tokens, self.block_size);
+        // let tokens = prompt_tokens + decode_tokens;
+
+        let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size;
+        // prompt blocks + a single block for decode
+        let required_blocks = required_prompt_blocks + 1;
+
+        let (required_blocks, repeats) = match self.window_size {
+            // Nothing to do
+            None => (required_blocks, 1),
+            Some(window_size) => {
+                // Number of blocks needed for this window size
+                let window_size_required_blocks = (window_size + self.block_size - 1) / self.block_size;
+                // Number of times we will need to repeat blocks to cover the required allocation
+                let repeats = (required_blocks + window_size_required_blocks - 1) / window_size_required_blocks;
+                let required_blocks = min(required_blocks, window_size_required_blocks);
+
+                (required_blocks, repeats)
+            }
+        };
+
+        /// if prompt + decode < window size => do nothing
+        /// if prompt + decode > window size => do normal until we reach window size then
+
+        // Apply window size
+        let (required_blocks, repeats) = {
+            let (tokens, repeats) = match self.window_size {
+                None => (tokens, 1),
+                Some(window_size) => {
+                    let repeats = (tokens + window_size - 1) / window_size;
+                    let tokens = min(tokens, window_size);
+                    (tokens, repeats as usize)
+                }
+            };
+            // Pad to a multiple of block size
+            let required_blocks = (tokens + self.block_size - 1) / self.block_size;
+            (required_blocks, repeats)
+        };
+
+        let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
+
+        if required_blocks > free_blocks.len() as u32 {
+            Err(AllocationError::NotEnoughPages)
+        } else {
+            let n_free_blocks = free_blocks.len();
+            let blocks =
+                free_blocks.split_off(n_free_blocks - required_blocks as usize);
+            let mut slots = Vec::with_capacity(
+                (required_blocks * self.block_size * repeats as u32) as usize,
+            );
+
+            for block_id in blocks.repeat(repeats).iter() {
+                for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
+                    slots.push(s);
+                }
+            }
+            Ok((blocks, slots))
+        }
+    }
+
+    pub(crate) fn block_allocation(
         &self,
         prompt_tokens: u32,
         decode_tokens: u32,
     ) -> Result<BlockAllocation, AllocationError> {
-        let (response_sender, response_receiver) = oneshot::channel();
-        self.block_allocator
-            .send(BlockAllocatorCommand::Allocate {
-                prompt_tokens,
-                decode_tokens,
-                response_sender,
-            })
-            .unwrap();
-
-        response_receiver
-            .await
-            .unwrap()
+        self.allocate_inner(prompt_tokens, decode_tokens)
             .map(|(blocks, slots)| BlockAllocation {
                 blocks,
                 slots,
@@ -86,103 +163,11 @@ impl BlockAllocator {
             })
     }

-    pub(crate) async fn extend(
-        &self,
-        block_allocation: &mut BlockAllocation,
-        tokens: u32,
-    ) -> Result<(), AllocationError> {
-        let (response_sender, response_receiver) = oneshot::channel();
-        self.block_allocator
-            .send(BlockAllocatorCommand::Allocate {
-                prompt_tokens: 0,
-                decode_tokens: tokens,
-                response_sender,
-            })
-            .unwrap();
-
-        let (blocks, slots) = response_receiver.await.unwrap()?;
-        block_allocation.blocks.extend(blocks);
-        block_allocation.slots.extend(slots);
-        Ok(())
-    }
-
     pub(crate) fn free(&self, blocks: Vec<u32>) {
-        self.block_allocator
-            .send(BlockAllocatorCommand::Free { blocks })
-            .unwrap();
+        self.free_blocks.lock().expect("Lock could not be acquired. This is a bug.").extend(blocks)
     }
 }

-async fn block_allocator_task(
-    blocks: u32,
-    block_size: u32,
-    window_size: Option<u32>,
-    mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
-) {
-    // Block 0 is reserved for health checks
-    let mut free_blocks: Vec<u32> = (1..blocks).collect();
-    while let Some(cmd) = receiver.recv().await {
-        match cmd {
-            BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
-            BlockAllocatorCommand::Allocate {
-                prompt_tokens,
-                decode_tokens,
-                response_sender,
-            } => {
-                let decode_tokens = min(decode_tokens, block_size);
-                let tokens = prompt_tokens + decode_tokens;
-
-                // FIXME: window size is not working
-                // Apply window size
-                let (required_blocks, repeats) = {
-                    let (tokens, repeats) = match window_size {
-                        None => (tokens, 1),
-                        Some(window_size) => {
-                            let repeats = (tokens + window_size - 1) / window_size;
-                            let tokens = min(tokens, window_size);
-                            (tokens, repeats as usize)
-                        }
-                    };
-                    // Pad to a multiple of block size
-                    let required_blocks = (tokens + block_size - 1) / block_size;
-                    (required_blocks, repeats)
-                };
-
-                let allocation = if required_blocks > free_blocks.len() as u32 {
-                    Err(AllocationError::NotEnoughPages)
-                } else {
-                    let blocks =
-                        free_blocks.split_off(free_blocks.len() - required_blocks as usize);
-                    let mut slots = Vec::with_capacity(
-                        (required_blocks * block_size * repeats as u32) as usize,
-                    );
-
-                    for block_id in blocks.repeat(repeats).iter() {
-                        for s in (block_id * block_size)..((block_id + 1) * block_size) {
-                            slots.push(s);
-                        }
-                    }
-                    Ok((blocks, slots))
-                };
-                response_sender.send(allocation).unwrap();
-            }
-        }
-    }
-}
-
-#[derive(Debug)]
-enum BlockAllocatorCommand {
-    Free {
-        blocks: Vec<u32>,
-    },
-    Allocate {
-        prompt_tokens: u32,
-        decode_tokens: u32,
-        #[allow(clippy::type_complexity)]
-        response_sender: oneshot::Sender<Result<(Vec<u32>, Vec<u32>), AllocationError>>,
-    },
-}
-
 #[derive(Error, Debug)]
 pub enum AllocationError {
     #[error("Not enough pages")]
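Before the call-site hunks below, the new `allocate` arithmetic can be traced by hand. A worked example under assumed parameters (all numbers illustrative):

    fn main() {
        // Assumed parameters: 16-slot blocks, a 32-token sliding window,
        // and a 40-token prompt.
        let block_size: u32 = 16;
        let window_size: u32 = 32;
        let prompt_tokens: u32 = 40;

        // ceil(40 / 16) = 3 prompt blocks, plus a single block for decode.
        let required_prompt_blocks = (prompt_tokens + block_size - 1) / block_size;
        let required_blocks = required_prompt_blocks + 1; // 4

        // The window spans ceil(32 / 16) = 2 blocks, so the 4 required
        // blocks are covered by repeating a 2-block allocation twice.
        let window_blocks = (window_size + block_size - 1) / block_size; // 2
        let repeats = (required_blocks + window_blocks - 1) / window_blocks; // 2
        let allocated = required_blocks.min(window_blocks); // 2

        assert_eq!((allocated, repeats), (2, 2));

        // Slot capacity: 2 blocks * 16 slots * 2 repeats = 64 slots,
        // matching the Vec::with_capacity sizing in the diff.
        assert_eq!(allocated * block_size * repeats, 64);
    }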
@@ -284,7 +284,6 @@ impl State {
                 entry.request.stopping_parameters.max_new_tokens + self.speculate;
             match block_allocator
                 .allocate(entry.request.input_length, decode_tokens)
-                .await
             {
                 Err(_) => {
                     // Entry is over budget
@@ -428,8 +428,7 @@ async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) -> bool {
             .block_allocation
             .as_mut()
             .expect("We checked that the block allocation exists above")
-            .extend(entry.cache_length)
-            .await
+            .extend()
     };

     if extension.is_err() {