working example
parent 1cc86930a6
commit 35f27cbcc1
@@ -1,10 +1,13 @@
-use std::cmp::min;
+use std::cmp::{max, min};
+use thiserror::Error;
 use tokio::sync::{mpsc, oneshot};
 
 #[derive(Debug, Clone)]
 pub(crate) struct BlockAllocation {
     pub blocks: Vec<u32>,
     pub slots: Vec<u32>,
+    prompt_tokens: u32,
+    decode_tokens: u32,
     block_allocator: BlockAllocator,
 }
 
@@ -12,6 +15,14 @@ impl BlockAllocation {
     pub(crate) fn len(&self) -> usize {
         self.slots.len()
     }
+
+    pub(crate) async fn extend(&mut self, current_length: u32) -> Result<(), AllocationError> {
+        let remaining_tokens = max(self.prompt_tokens + self.decode_tokens - current_length, 1);
+        self.block_allocator
+            .clone()
+            .extend(self, remaining_tokens)
+            .await
+    }
 }
 
 impl Drop for BlockAllocation {
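
The new `extend` method asks the shared allocator for however many tokens the request is still expected to need. Cloning the handle first (presumably cheap, since it wraps a channel sender) lets the call pass `self` mutably to the allocator without also borrowing it through `self.block_allocator`. A minimal sketch of the remaining-token arithmetic, with concrete values assumed for illustration:

use std::cmp::max;

fn main() {
    // Assumed values: recorded when the allocation was created.
    let prompt_tokens: u32 = 100;
    let decode_tokens: u32 = 50;
    // Tokens the request has actually produced so far.
    let current_length: u32 = 120;

    // Same expression as in `BlockAllocation::extend`: the outstanding
    // budget, floored at 1 so at least one token is always requested.
    let remaining_tokens = max(prompt_tokens + decode_tokens - current_length, 1);
    assert_eq!(remaining_tokens, 30);
}
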
@@ -48,11 +59,16 @@ impl BlockAllocator {
         }
     }
 
-    pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
+    pub(crate) async fn allocate(
+        &self,
+        prompt_tokens: u32,
+        decode_tokens: u32,
+    ) -> Result<BlockAllocation, AllocationError> {
         let (response_sender, response_receiver) = oneshot::channel();
         self.block_allocator
             .send(BlockAllocatorCommand::Allocate {
-                tokens,
+                prompt_tokens,
+                decode_tokens,
                 response_sender,
             })
             .unwrap();
@@ -63,10 +79,32 @@ impl BlockAllocator {
             .map(|(blocks, slots)| BlockAllocation {
                 blocks,
                 slots,
+                prompt_tokens,
+                decode_tokens,
                 block_allocator: self.clone(),
             })
     }
 
+    pub(crate) async fn extend(
+        &self,
+        block_allocation: &mut BlockAllocation,
+        tokens: u32,
+    ) -> Result<(), AllocationError> {
+        let (response_sender, response_receiver) = oneshot::channel();
+        self.block_allocator
+            .send(BlockAllocatorCommand::Allocate {
+                prompt_tokens: 0,
+                decode_tokens: tokens,
+                response_sender,
+            })
+            .unwrap();
+
+        let (blocks, slots) = response_receiver.await.unwrap()?;
+        block_allocation.blocks.extend(blocks);
+        block_allocation.slots.extend(slots);
+        Ok(())
+    }
+
     pub(crate) fn free(&self, blocks: Vec<u32>) {
         self.block_allocator
             .send(BlockAllocatorCommand::Free { blocks })
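
Taken together, allocation now returns a `Result` and the prompt and decode budgets travel separately, with extension reusing the same `Allocate` command (prompt budget zero). A hedged usage sketch; the `reserve` wrapper, its arguments, and the extension size of 16 are invented for illustration, not part of the patch:

async fn reserve(
    allocator: &BlockAllocator,
    prompt_tokens: u32,
    decode_tokens: u32,
) -> Result<BlockAllocation, AllocationError> {
    // `?` propagates AllocationError::NotEnoughPages instead of the old
    // match on `Option`.
    let mut allocation = allocator.allocate(prompt_tokens, decode_tokens).await?;

    // If the request later outgrows its slots, ask for 16 more tokens'
    // worth of pages over the same command channel.
    allocator.extend(&mut allocation, 16).await?;
    Ok(allocation)
}
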
@@ -86,10 +124,12 @@ async fn block_allocator_task(
         match cmd {
             BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
             BlockAllocatorCommand::Allocate {
-                tokens,
+                prompt_tokens,
+                decode_tokens,
                 response_sender,
             } => {
-                // let tokens = 16;
+                let decode_tokens = min(decode_tokens, block_size);
+                let tokens = prompt_tokens + decode_tokens;
 
                 // Apply window size
                 let (required_blocks, repeats) = {
@@ -106,9 +146,8 @@ async fn block_allocator_task(
                     (required_blocks, repeats)
                 };
 
-                let tokens = tokens as usize;
                 let allocation = if required_blocks > free_blocks.len() as u32 {
-                    None
+                    Err(AllocationError::NotEnoughPages)
                 } else {
                     let blocks =
                         free_blocks.split_off(free_blocks.len() - required_blocks as usize);
@@ -116,15 +155,12 @@ async fn block_allocator_task(
                         (required_blocks * block_size * repeats as u32) as usize,
                     );
 
-                    'slots: for block_id in blocks.repeat(repeats).iter() {
+                    for block_id in blocks.repeat(repeats).iter() {
                         for s in (block_id * block_size)..((block_id + 1) * block_size) {
                             slots.push(s);
-                            if slots.len() == tokens {
-                                break 'slots;
-                            }
                         }
                     }
-                    Some((blocks, slots))
+                    Ok((blocks, slots))
                 };
                 response_sender.send(allocation).unwrap();
             }
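
With the early `break 'slots` gone, every allocated block now contributes its full, contiguous range of slots (which is also why the unused `tokens as usize` cast could be dropped above). A standalone sketch of the loop, assuming `block_size = 4`, `blocks = [7, 8]`, and `repeats = 1`; all values invented for illustration:

fn main() {
    let block_size: u32 = 4;
    let blocks: Vec<u32> = vec![7, 8];
    let repeats: usize = 1;

    let mut slots =
        Vec::with_capacity(blocks.len() * block_size as usize * repeats);
    for block_id in blocks.repeat(repeats).iter() {
        // Block n owns the slot range [n * block_size, (n + 1) * block_size).
        for s in (block_id * block_size)..((block_id + 1) * block_size) {
            slots.push(s);
        }
    }
    assert_eq!(slots, vec![28, 29, 30, 31, 32, 33, 34, 35]);
}
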
@@ -138,7 +174,15 @@ enum BlockAllocatorCommand {
         blocks: Vec<u32>,
     },
     Allocate {
-        tokens: u32,
-        response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
+        prompt_tokens: u32,
+        decode_tokens: u32,
+        #[allow(clippy::type_complexity)]
+        response_sender: oneshot::Sender<Result<(Vec<u32>, Vec<u32>), AllocationError>>,
     },
 }
+
+#[derive(Error, Debug)]
+pub enum AllocationError {
+    #[error("Not enough pages")]
+    NotEnoughPages,
+}
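
For reference, a self-contained sketch of what the `thiserror` derive provides here, as I understand its behavior (not shown in the patch): the `#[error(...)]` string becomes the `Display` output, which is what ends up in logs and error responses.

use thiserror::Error;

#[derive(Error, Debug)]
pub enum AllocationError {
    #[error("Not enough pages")]
    NotEnoughPages,
}

fn main() {
    // Display is derived from the #[error] attribute.
    assert_eq!(AllocationError::NotEnoughPages.to_string(), "Not enough pages");
}
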
@@ -295,20 +295,20 @@ impl State {
                     break;
                 }
 
-                let tokens = entry.request.input_length
-                    + entry.request.stopping_parameters.max_new_tokens
-                    + self.speculate
-                    - 1;
-
-                match block_allocator.allocate(tokens).await {
-                    None => {
+                let decode_tokens =
+                    entry.request.stopping_parameters.max_new_tokens + self.speculate - 1;
+                match block_allocator
+                    .allocate(entry.request.input_length, decode_tokens)
+                    .await
+                {
+                    Err(_) => {
                         // Entry is over budget
                         // Add it back to the front
                         tracing::debug!("Over budget: not enough free blocks");
                         self.entries.push_front((id, entry));
                         break 'entry_loop;
                     }
-                    Some(block_allocation) => {
+                    Ok(block_allocation) => {
                         tracing::debug!("Allocation: {block_allocation:?}");
                         max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
                         Some(block_allocation)
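
The total budget requested is unchanged; it is only split so the allocator can treat the decode part specially (it is capped at `block_size` in `block_allocator_task`). A quick check with numbers assumed for illustration:

fn main() {
    // Assumed values for illustration.
    let input_length: u32 = 512;
    let max_new_tokens: u32 = 128;
    let speculate: u32 = 2;

    // New: the decode budget is passed on its own.
    let decode_tokens = max_new_tokens + speculate - 1;
    // Old: one combined figure.
    let old_tokens = input_length + max_new_tokens + speculate - 1;

    assert_eq!(input_length + decode_tokens, old_tokens);
}
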
@@ -247,7 +247,7 @@ async fn prefill(
     filter_send_generations(generations, entries);
 
     // Filter next batch and remove requests that were stopped
-    let next_batch = filter_batch(client, next_batch, entries).await;
+    let next_batch = filter_batch(client, next_batch, entries, false).await;
 
     metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
     metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
@@ -288,10 +288,10 @@ async fn decode(
     // Send generated tokens and filter stopped entries
     filter_send_generations(generations, entries);
 
-    filter_update_allocations(entries).await;
+    let updated = filter_update_allocations(entries).await;
 
     // Filter next batch and remove requests that were stopped
-    let next_batch = filter_batch(client, next_batch, entries).await;
+    let next_batch = filter_batch(client, next_batch, entries, updated).await;
 
     if let Some(concat_duration) = timings.concat {
         metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
@@ -322,11 +322,12 @@ async fn filter_batch(
     client: &mut ShardedClient,
     next_batch: Option<CachedBatch>,
     entries: &IntMap<u64, Entry>,
+    force_update: bool,
 ) -> Option<CachedBatch> {
     let batch = next_batch?;
 
     // No need to filter
-    if batch.size as usize == entries.len() {
+    if batch.size as usize == entries.len() && !force_update {
         return Some(batch);
     }
 
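
A tiny sketch of the new short-circuit condition (the function name is invented for illustration): when `decode` reports that an allocation was extended, the filter runs even though no entry was removed, so the shards receive the updated blocks and slots.

fn skip_filter(batch_size: usize, entries_len: usize, force_update: bool) -> bool {
    // Previously only the size comparison guarded the early return.
    batch_size == entries_len && !force_update
}

fn main() {
    assert!(skip_filter(8, 8, false)); // unchanged batch: nothing to push
    assert!(!skip_filter(8, 8, true)); // extended allocations: must update
}
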
@@ -348,6 +349,7 @@ async fn filter_batch(
                 .as_ref()
                 .map(|alloc| (alloc.blocks.clone(), alloc.slots.clone()))
                 .unwrap_or((Vec::new(), Vec::new()));
 
             UpdatedRequest {
                 id: *request_id,
                 blocks,
@@ -393,34 +395,58 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
 /// Check if block allocations need to be extended
 /// If we don't have enough blocks, request will be filtered with an OutOfPages error
 #[instrument(skip_all)]
-async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) {
-    entries.retain(|request_id, entry| {
-        if entry.block_allocation.is_none() {
-            return true;
-        }
-
-        // We can unwrap since we already validated above that block_allocation is not None
-        let mut block_allocation = entry.block_allocation.as_ref().unwrap();
-
-        // Nothing to update
-        if entry.current_length <= block_allocation.len() as u32 {
-            return true;
-        }
-
-        // Create and enter a span to link this function back to the entry
-        let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
-        let err = InferError::OutOfPages;
-        metrics::increment_counter!("tgi_request_failure", "err" => "out_of_pages");
-        tracing::error!("{err}");
-
-        // unwrap_or is valid here as we don't care if the receiver is gone.
-        entry
-            .response_tx
-            .send(Err(err))
-            .unwrap_or(());
-
-        false
-    });
+async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) -> bool {
+    let ids: Vec<u64> = entries
+        .iter()
+        .filter_map(|(id, entry)| {
+            entry
+                .block_allocation
+                .as_ref()
+                .map(|block_allocation| {
+                    if entry.current_length > block_allocation.len() as u32 {
+                        // We need to re-allocate
+                        Some(*id)
+                    } else {
+                        None
+                    }
+                })
+                .unwrap_or(None)
+        })
+        .collect();
+
+    for id in ids.iter() {
+        // Get entry
+        // We can `expect` here as the request id should always be in the entries
+        let extension = {
+            let entry = entries
+                .get_mut(id)
+                .expect("ID not found in entries. This is a bug.");
+            entry
+                .block_allocation
+                .as_mut()
+                .unwrap()
+                .extend(entry.current_length)
+                .await
+        };
+
+        if extension.is_err() {
+            let entry = entries
+                .remove(id)
+                .expect("ID not found in entries. This is a bug.");
+
+            // Create and enter a span to link this function back to the entry
+            let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
+            let err = InferError::OutOfPages;
+            metrics::increment_counter!("tgi_request_failure", "err" => "out_of_pages");
+            tracing::error!("{err}");
+
+            // unwrap_or is valid here as we don't care if the receiver is gone.
+            entry.response_tx.send(Err(err)).unwrap_or(());
+        }
+    }
+
+    // If ids is not empty, we need to update
+    !ids.is_empty()
 }
 
 /// Send responses through the `entry` response channel
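
The rewrite drops `entries.retain(...)` because `retain` takes a synchronous closure, while extending an allocation is an async call. The replacement collects the ids first, then awaits per id with no live borrow of the map. A minimal standalone sketch of that pattern; the types and the failing stand-in are assumptions, not from the patch:

use std::collections::HashMap;

// Stand-in for `BlockAllocation::extend`; pretend it always fails here.
async fn try_extend(_id: u64) -> Result<(), ()> {
    Err(())
}

#[tokio::main]
async fn main() {
    let mut entries: HashMap<u64, &str> = HashMap::from([(1, "a"), (2, "b")]);

    // Phase 1: collect ids while borrowing `entries` immutably.
    let ids: Vec<u64> = entries.keys().copied().collect();

    // Phase 2: `.await` is allowed here, and removing entries is safe
    // because no iterator over `entries` is alive.
    for id in ids {
        if try_extend(id).await.is_err() {
            entries.remove(&id);
        }
    }
    assert!(entries.is_empty());
}
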
@@ -402,9 +402,6 @@ class FlashCausalLMBatch(Batch):
     ) -> Optional["FlashCausalLMBatch"]:
         if len(updated_requests) == 0:
             raise ValueError("Batch must have at least one request")
-        # We assume that if len(requests) == len(self) then the requests are the same
-        if len(updated_requests) == len(self):
-            return self
 
         device = self.input_ids.device
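
Removing the `len(updated_requests) == len(self)` shortcut appears to pair with the router-side `force_update` flag above: after an allocation is extended, the router re-sends the batch even though its size is unchanged, so the server must rebuild it with the new blocks and slots rather than returning `self` untouched.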