fix windowing
This commit is contained in:
parent
37266e2dbb
commit
c2fb459bc1
|
@ -24,6 +24,9 @@ impl BlockAllocation {
|
|||
&self.allocated_slots
|
||||
}
|
||||
|
||||
/// Extend an allocation by adding a new block
|
||||
/// If the allocation length > window size, repeats blocks and slots to cover the
|
||||
/// whole `required_blocks` and `required_slots`
|
||||
pub(crate) fn extend(&mut self) -> Result<(), AllocationError> {
|
||||
let (block, slots) = self.block_allocator.allocate_block()?;
|
||||
// Add block and slots to current allocation
|
||||
|
@ -48,6 +51,7 @@ impl BlockAllocation {
|
|||
}
|
||||
|
||||
impl Drop for BlockAllocation {
|
||||
/// Free the blocks
|
||||
fn drop(&mut self) {
|
||||
let allocated_blocks = std::mem::take(&mut self.allocated_blocks);
|
||||
self.block_allocator.free(allocated_blocks)
|
||||
|
@ -114,39 +118,47 @@ impl BlockAllocator {
|
|||
let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size;
|
||||
// prompt blocks + a single block for decode
|
||||
let required_blocks = required_prompt_blocks + 1;
|
||||
let required_slots = required_blocks * self.block_size;
|
||||
|
||||
// Slots and blocks required for the whole request
|
||||
let total_slots = prompt_tokens + decode_tokens;
|
||||
let total_required_blocks = (total_slots + self.block_size - 1) / self.block_size;
|
||||
|
||||
let (clipped_required_blocks, repeats) = match self.window_size {
|
||||
// Nothing to do
|
||||
None => (required_blocks, 1),
|
||||
Some(window_size) => {
|
||||
Some(window_size) if required_slots >= window_size => {
|
||||
// Number of blocks for this window size
|
||||
let window_size_blocks = (window_size + self.block_size - 1) / self.block_size;
|
||||
|
||||
if required_blocks > window_size_blocks {
|
||||
// Number of times we will need to repeat blocks to cover the required allocation
|
||||
let repeats = (required_blocks + window_size_blocks - 1) / window_size_blocks;
|
||||
// Number of times we will need to repeat blocks to cover the total allocation
|
||||
let repeats = (total_slots + window_size - 1) / window_size;
|
||||
(window_size_blocks, repeats)
|
||||
} else {
|
||||
(required_blocks, 1)
|
||||
}
|
||||
// Nothing to do
|
||||
_ => (required_blocks, 1),
|
||||
};
|
||||
|
||||
// Scoped to drop the lock early
|
||||
let allocated_blocks = {
|
||||
let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
|
||||
let clipped_required_blocks = clipped_required_blocks as usize;
|
||||
|
||||
if clipped_required_blocks > free_blocks.len() {
|
||||
// Not enough blocks to cover this allocation
|
||||
// Early return
|
||||
return Err(AllocationError::NotEnoughPages);
|
||||
}
|
||||
|
||||
// Take the blocks
|
||||
let n_free_blocks = free_blocks.len();
|
||||
free_blocks.split_off(n_free_blocks - clipped_required_blocks)
|
||||
};
|
||||
|
||||
let repeats = repeats as usize;
|
||||
let required_blocks = required_blocks as usize;
|
||||
let clipped_required_blocks = clipped_required_blocks as usize;
|
||||
|
||||
let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
|
||||
|
||||
if clipped_required_blocks > free_blocks.len() {
|
||||
Err(AllocationError::NotEnoughPages)
|
||||
} else {
|
||||
let n_free_blocks = free_blocks.len();
|
||||
let allocated_blocks = free_blocks.split_off(n_free_blocks - clipped_required_blocks);
|
||||
let total_slots = total_slots as usize;
|
||||
let total_required_blocks = total_required_blocks as usize;
|
||||
|
||||
let allocated_blocks = if repeats != 1 {
|
||||
let mut allocated_blocks = allocated_blocks.repeat(repeats);
|
||||
allocated_blocks.truncate(required_blocks);
|
||||
allocated_blocks.truncate(total_required_blocks);
|
||||
allocated_blocks
|
||||
} else {
|
||||
allocated_blocks
|
||||
|
@ -155,12 +167,10 @@ impl BlockAllocator {
|
|||
let mut allocated_slots =
|
||||
Vec::with_capacity(allocated_blocks.len() * self.block_size as usize * repeats);
|
||||
|
||||
let required_slots = (prompt_tokens + decode_tokens) as usize;
|
||||
|
||||
'slots: for block_id in allocated_blocks.iter() {
|
||||
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
|
||||
allocated_slots.push(s);
|
||||
if allocated_slots.len() > required_slots {
|
||||
if allocated_slots.len() > total_slots {
|
||||
break 'slots;
|
||||
}
|
||||
}
|
||||
|
@ -169,12 +179,11 @@ impl BlockAllocator {
|
|||
Ok(BlockAllocation {
|
||||
allocated_blocks,
|
||||
allocated_slots,
|
||||
required_blocks,
|
||||
required_slots,
|
||||
required_blocks: total_required_blocks,
|
||||
required_slots: total_slots,
|
||||
block_allocator: self.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn free(&self, blocks: Vec<u32>) {
|
||||
self.free_blocks
|
||||
|
|
|
@ -361,6 +361,7 @@ async fn decode(
|
|||
}
|
||||
|
||||
/// Filter a `batch` and remove all requests not present in `entries`
|
||||
/// Ask the server to generate the full texts for entries in `terminated_entries`
|
||||
#[instrument(skip_all)]
|
||||
async fn filter_batch(
|
||||
client: &mut ShardedClient,
|
||||
|
@ -408,7 +409,10 @@ async fn filter_batch(
|
|||
}
|
||||
}
|
||||
|
||||
///
|
||||
/// Send `InferStreamResponse::Intermediate` and the final `InferStreamResponse::End` messages
|
||||
/// to terminated requests
|
||||
/// It modifies the last `InferStreamResponse::Intermediate` to add the final full text in
|
||||
/// `terminated_generations`
|
||||
#[instrument(skip_all)]
|
||||
fn send_terminated_generations(
|
||||
terminated_generations: Vec<TerminatedGeneration>,
|
||||
|
@ -530,7 +534,7 @@ fn send_stream_responses(
|
|||
}
|
||||
|
||||
/// Check if block allocations need to be extended
|
||||
/// If we don't have enough blocks, request will be filtered with be added to an IntMap of
|
||||
/// If we don't have enough blocks, request will be filtered and added to an IntMap of
|
||||
/// terminated entries.
|
||||
/// If at least one entry allocation was extended, we return true to force an update
|
||||
#[instrument(skip_all)]
|
||||
|
@ -592,6 +596,7 @@ fn filter_send_update_allocations(
|
|||
}
|
||||
|
||||
/// Map `Generation` to `<(bool, Vec<(u64, InferStreamResponse)>)>`
|
||||
/// `bool` is `true` if the generation is finished
|
||||
fn map_generation(generation: Generation, entry: &Entry) -> (bool, Vec<InferStreamResponse>) {
|
||||
let mut finished = false;
|
||||
let mut stream_responses = Vec::with_capacity(16);
|
||||
|
|
Loading…
Reference in New Issue