fix windowing

This commit is contained in:
OlivierDehaene 2024-06-11 18:40:38 +02:00
parent 37266e2dbb
commit c2fb459bc1
2 changed files with 61 additions and 47 deletions

View File

@ -24,6 +24,9 @@ impl BlockAllocation {
&self.allocated_slots
}
/// Extend an allocation by adding a new block
/// If the allocation length > window size, repeats blocks and slots to cover the
/// whole `required_blocks` and `required_slots`
pub(crate) fn extend(&mut self) -> Result<(), AllocationError> {
let (block, slots) = self.block_allocator.allocate_block()?;
// Add block and slots to current allocation
@ -48,6 +51,7 @@ impl BlockAllocation {
}
impl Drop for BlockAllocation {
/// Free the blocks
fn drop(&mut self) {
let allocated_blocks = std::mem::take(&mut self.allocated_blocks);
self.block_allocator.free(allocated_blocks)
@ -114,39 +118,47 @@ impl BlockAllocator {
let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size;
// prompt blocks + a single block for decode
let required_blocks = required_prompt_blocks + 1;
let required_slots = required_blocks * self.block_size;
// Slots and blocks required for the whole request
let total_slots = prompt_tokens + decode_tokens;
let total_required_blocks = (total_slots + self.block_size - 1) / self.block_size;
let (clipped_required_blocks, repeats) = match self.window_size {
// Nothing to do
None => (required_blocks, 1),
Some(window_size) => {
Some(window_size) if required_slots >= window_size => {
// Number of blocks for this window size
let window_size_blocks = (window_size + self.block_size - 1) / self.block_size;
if required_blocks > window_size_blocks {
// Number of times we will need to repeat blocks to cover the required allocation
let repeats = (required_blocks + window_size_blocks - 1) / window_size_blocks;
// Number of times we will need to repeat blocks to cover the total allocation
let repeats = (total_slots + window_size - 1) / window_size;
(window_size_blocks, repeats)
} else {
(required_blocks, 1)
}
// Nothing to do
_ => (required_blocks, 1),
};
// Scoped to drop the lock early
let allocated_blocks = {
let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
let clipped_required_blocks = clipped_required_blocks as usize;
if clipped_required_blocks > free_blocks.len() {
// Not enough blocks to cover this allocation
// Early return
return Err(AllocationError::NotEnoughPages);
}
// Take the blocks
let n_free_blocks = free_blocks.len();
free_blocks.split_off(n_free_blocks - clipped_required_blocks)
};
let repeats = repeats as usize;
let required_blocks = required_blocks as usize;
let clipped_required_blocks = clipped_required_blocks as usize;
let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
if clipped_required_blocks > free_blocks.len() {
Err(AllocationError::NotEnoughPages)
} else {
let n_free_blocks = free_blocks.len();
let allocated_blocks = free_blocks.split_off(n_free_blocks - clipped_required_blocks);
let total_slots = total_slots as usize;
let total_required_blocks = total_required_blocks as usize;
let allocated_blocks = if repeats != 1 {
let mut allocated_blocks = allocated_blocks.repeat(repeats);
allocated_blocks.truncate(required_blocks);
allocated_blocks.truncate(total_required_blocks);
allocated_blocks
} else {
allocated_blocks
@ -155,12 +167,10 @@ impl BlockAllocator {
let mut allocated_slots =
Vec::with_capacity(allocated_blocks.len() * self.block_size as usize * repeats);
let required_slots = (prompt_tokens + decode_tokens) as usize;
'slots: for block_id in allocated_blocks.iter() {
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
allocated_slots.push(s);
if allocated_slots.len() > required_slots {
if allocated_slots.len() > total_slots {
break 'slots;
}
}
@ -169,12 +179,11 @@ impl BlockAllocator {
Ok(BlockAllocation {
allocated_blocks,
allocated_slots,
required_blocks,
required_slots,
required_blocks: total_required_blocks,
required_slots: total_slots,
block_allocator: self.clone(),
})
}
}
pub(crate) fn free(&self, blocks: Vec<u32>) {
self.free_blocks

View File

@ -361,6 +361,7 @@ async fn decode(
}
/// Filter a `batch` and remove all requests not present in `entries`
/// Ask the server to generate the full texts for entries in `terminated_entries`
#[instrument(skip_all)]
async fn filter_batch(
client: &mut ShardedClient,
@ -408,7 +409,10 @@ async fn filter_batch(
}
}
///
/// Send `InferStreamResponse::Intermediate` and the final `InferStreamResponse::End` messages
/// to terminated requests
/// It modifies the last `InferStreamResponse::Intermediate` to add the final full text in
/// `terminated_generations`
#[instrument(skip_all)]
fn send_terminated_generations(
terminated_generations: Vec<TerminatedGeneration>,
@ -530,7 +534,7 @@ fn send_stream_responses(
}
/// Check if block allocations need to be extended
/// If we don't have enough blocks, request will be filtered with be added to an IntMap of
/// If we don't have enough blocks, request will be filtered and added to an IntMap of
/// terminated entries.
/// If at least one entry allocation was extended, we return true to force an update
#[instrument(skip_all)]
@ -592,6 +596,7 @@ fn filter_send_update_allocations(
}
/// Map `Generation` to `<(bool, Vec<(u64, InferStreamResponse)>)>`
/// `bool` is `true` if the generation is finished
fn map_generation(generation: Generation, entry: &Entry) -> (bool, Vec<InferStreamResponse>) {
let mut finished = false;
let mut stream_responses = Vec::with_capacity(16);