working example

This commit is contained in:
OlivierDehaene 2024-06-05 18:47:16 +02:00
parent 1cc86930a6
commit 35f27cbcc1
4 changed files with 122 additions and 55 deletions

View File

@ -1,10 +1,13 @@
use std::cmp::min;
use std::cmp::{max, min};
use thiserror::Error;
use tokio::sync::{mpsc, oneshot};
#[derive(Debug, Clone)]
pub(crate) struct BlockAllocation {
pub blocks: Vec<u32>,
pub slots: Vec<u32>,
prompt_tokens: u32,
decode_tokens: u32,
block_allocator: BlockAllocator,
}
@ -12,6 +15,14 @@ impl BlockAllocation {
pub(crate) fn len(&self) -> usize {
self.slots.len()
}
pub(crate) async fn extend(&mut self, current_length: u32) -> Result<(), AllocationError> {
let remaining_tokens = max(self.prompt_tokens + self.decode_tokens - current_length, 1);
self.block_allocator
.clone()
.extend(self, remaining_tokens)
.await
}
}
impl Drop for BlockAllocation {
@ -48,11 +59,16 @@ impl BlockAllocator {
}
}
pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
pub(crate) async fn allocate(
&self,
prompt_tokens: u32,
decode_tokens: u32,
) -> Result<BlockAllocation, AllocationError> {
let (response_sender, response_receiver) = oneshot::channel();
self.block_allocator
.send(BlockAllocatorCommand::Allocate {
tokens,
prompt_tokens,
decode_tokens,
response_sender,
})
.unwrap();
@ -63,10 +79,32 @@ impl BlockAllocator {
.map(|(blocks, slots)| BlockAllocation {
blocks,
slots,
prompt_tokens,
decode_tokens,
block_allocator: self.clone(),
})
}
pub(crate) async fn extend(
&self,
block_allocation: &mut BlockAllocation,
tokens: u32,
) -> Result<(), AllocationError> {
let (response_sender, response_receiver) = oneshot::channel();
self.block_allocator
.send(BlockAllocatorCommand::Allocate {
prompt_tokens: 0,
decode_tokens: tokens,
response_sender,
})
.unwrap();
let (blocks, slots) = response_receiver.await.unwrap()?;
block_allocation.blocks.extend(blocks);
block_allocation.slots.extend(slots);
Ok(())
}
pub(crate) fn free(&self, blocks: Vec<u32>) {
self.block_allocator
.send(BlockAllocatorCommand::Free { blocks })
@ -86,10 +124,12 @@ async fn block_allocator_task(
match cmd {
BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
BlockAllocatorCommand::Allocate {
tokens,
prompt_tokens,
decode_tokens,
response_sender,
} => {
// let tokens = 16;
let decode_tokens = min(decode_tokens, block_size);
let tokens = prompt_tokens + decode_tokens;
// Apply window size
let (required_blocks, repeats) = {
@ -106,9 +146,8 @@ async fn block_allocator_task(
(required_blocks, repeats)
};
let tokens = tokens as usize;
let allocation = if required_blocks > free_blocks.len() as u32 {
None
Err(AllocationError::NotEnoughPages)
} else {
let blocks =
free_blocks.split_off(free_blocks.len() - required_blocks as usize);
@ -116,15 +155,12 @@ async fn block_allocator_task(
(required_blocks * block_size * repeats as u32) as usize,
);
'slots: for block_id in blocks.repeat(repeats).iter() {
for block_id in blocks.repeat(repeats).iter() {
for s in (block_id * block_size)..((block_id + 1) * block_size) {
slots.push(s);
if slots.len() == tokens {
break 'slots;
}
}
}
Some((blocks, slots))
Ok((blocks, slots))
};
response_sender.send(allocation).unwrap();
}
@ -138,7 +174,15 @@ enum BlockAllocatorCommand {
blocks: Vec<u32>,
},
Allocate {
tokens: u32,
response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
prompt_tokens: u32,
decode_tokens: u32,
#[allow(clippy::type_complexity)]
response_sender: oneshot::Sender<Result<(Vec<u32>, Vec<u32>), AllocationError>>,
},
}
#[derive(Error, Debug)]
pub enum AllocationError {
#[error("Not enough pages")]
NotEnoughPages,
}

View File

@ -295,20 +295,20 @@ impl State {
break;
}
let tokens = entry.request.input_length
+ entry.request.stopping_parameters.max_new_tokens
+ self.speculate
- 1;
match block_allocator.allocate(tokens).await {
None => {
let decode_tokens =
entry.request.stopping_parameters.max_new_tokens + self.speculate - 1;
match block_allocator
.allocate(entry.request.input_length, decode_tokens)
.await
{
Err(_) => {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: not enough free blocks");
self.entries.push_front((id, entry));
break 'entry_loop;
}
Some(block_allocation) => {
Ok(block_allocation) => {
tracing::debug!("Allocation: {block_allocation:?}");
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
Some(block_allocation)

View File

@ -247,7 +247,7 @@ async fn prefill(
filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
let next_batch = filter_batch(client, next_batch, entries, false).await;
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
@ -288,10 +288,10 @@ async fn decode(
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
filter_update_allocations(entries).await;
let updated = filter_update_allocations(entries).await;
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
let next_batch = filter_batch(client, next_batch, entries, updated).await;
if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
@ -322,11 +322,12 @@ async fn filter_batch(
client: &mut ShardedClient,
next_batch: Option<CachedBatch>,
entries: &IntMap<u64, Entry>,
force_update: bool,
) -> Option<CachedBatch> {
let batch = next_batch?;
// No need to filter
if batch.size as usize == entries.len() {
if batch.size as usize == entries.len() && !force_update {
return Some(batch);
}
@ -348,6 +349,7 @@ async fn filter_batch(
.as_ref()
.map(|alloc| (alloc.blocks.clone(), alloc.slots.clone()))
.unwrap_or((Vec::new(), Vec::new()));
UpdatedRequest {
id: *request_id,
blocks,
@ -393,34 +395,58 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
/// Check if block allocations need to be extended
/// If we don't have enough blocks, request will be filtered with an OutOfPages error
#[instrument(skip_all)]
async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) {
entries.retain(|request_id, entry| {
if entry.block_allocation.is_none() {
return true;
async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) -> bool {
let ids: Vec<u64> = entries
.iter()
.filter_map(|(id, entry)| {
entry
.block_allocation
.as_ref()
.map(|block_allocation| {
if entry.current_length > block_allocation.len() as u32 {
// We need to re-allocate
Some(*id)
} else {
None
}
})
.unwrap_or(None)
})
.collect();
for id in ids.iter() {
// Get entry
// We can `expect` here as the request id should always be in the entries
let extension = {
let entry = entries
.get_mut(id)
.expect("ID not found in entries. This is a bug.");
entry
.block_allocation
.as_mut()
.unwrap()
.extend(entry.current_length)
.await
};
if extension.is_err() {
let entry = entries
.remove(id)
.expect("ID not found in entries. This is a bug.");
// Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::OutOfPages;
metrics::increment_counter!("tgi_request_failure", "err" => "out_of_pages");
tracing::error!("{err}");
// unwrap_or is valid here as we don't care if the receiver is gone.
entry.response_tx.send(Err(err)).unwrap_or(());
}
}
// We can unwrap since we already validated above that block_allocation is not None
let mut block_allocation = entry.block_allocation.as_ref().unwrap();
// Nothing to update
if entry.current_length <= block_allocation.len() as u32 {
return true;
}
// Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::OutOfPages;
metrics::increment_counter!("tgi_request_failure", "err" => "out_of_pages");
tracing::error!("{err}");
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Err(err))
.unwrap_or(());
false
});
// If ids is not empty, we need to update
!ids.is_empty()
}
/// Send responses through the `entry` response channel

View File

@ -402,9 +402,6 @@ class FlashCausalLMBatch(Batch):
) -> Optional["FlashCausalLMBatch"]:
if len(updated_requests) == 0:
raise ValueError("Batch must have at least one request")
# We assume that if len(requests) == len(self) then the requests are the same
if len(updated_requests) == len(self):
return self
device = self.input_ids.device