Fixing more correctly the invalid drop of the batch. (#2498)
This commit is contained in:
parent
aaea212d0f
commit
c1fe28d694
|
@ -122,7 +122,7 @@ impl Backend for BackendV3 {
|
|||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn batching_task(
|
||||
mut client: ShardedClient,
|
||||
_waiting_served_ratio: f32,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
|
@ -170,8 +170,7 @@ pub(crate) async fn batching_task(
|
|||
// Minimum batch size
|
||||
// TODO: temporarily disable to avoid incorrect deallocation +
|
||||
// reallocation when using prefix caching.
|
||||
// Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
||||
None
|
||||
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
||||
};
|
||||
|
||||
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||
|
|
|
@ -252,17 +252,14 @@ impl State {
|
|||
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
|
||||
next_batch_span.follows_from(Span::current());
|
||||
|
||||
let mut batch_requests = Vec::with_capacity(self.entries.len());
|
||||
let mut batch_entries =
|
||||
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
|
||||
|
||||
let mut batch = Vec::with_capacity(self.entries.len());
|
||||
let mut max_input_length = 0;
|
||||
let mut prefill_tokens: u32 = 0;
|
||||
let mut decode_tokens: u32 = 0;
|
||||
let mut max_blocks = 0;
|
||||
|
||||
// Pop entries starting from the front of the queue
|
||||
'entry_loop: while let Some((id, mut entry)) = self.entries.pop_front() {
|
||||
'entry_loop: while let Some((id, entry)) = self.entries.pop_front() {
|
||||
// Filter entries where the response receiver was dropped (== entries where the request
|
||||
// was dropped by the client)
|
||||
if entry.response_tx.is_closed() {
|
||||
|
@ -276,7 +273,7 @@ impl State {
|
|||
// We pad to max input length in the Python shards
|
||||
// We need to take these padding tokens into the equation
|
||||
max_input_length = max_input_length.max(entry.request.input_length);
|
||||
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length;
|
||||
prefill_tokens = (batch.len() + 1) as u32 * max_input_length;
|
||||
|
||||
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
|
||||
let total_tokens = prefill_tokens + decode_tokens + self.speculate;
|
||||
|
@ -290,7 +287,7 @@ impl State {
|
|||
}
|
||||
None
|
||||
}
|
||||
Some(block_allocator) => {
|
||||
Some(_block_allocator) => {
|
||||
prefill_tokens += entry.request.input_length;
|
||||
let max_new_tokens = match self.window_size {
|
||||
None => entry.request.stopping_parameters.max_new_tokens,
|
||||
|
@ -324,13 +321,49 @@ impl State {
|
|||
entry.request.input_ids.clone()
|
||||
};
|
||||
|
||||
Some((tokens, input_ids))
|
||||
}
|
||||
};
|
||||
batch.push((id, entry, block_allocation));
|
||||
if Some(batch.len()) == max_size {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Empty batch
|
||||
if batch.is_empty() {
|
||||
tracing::debug!("Filterered out all entries");
|
||||
return None;
|
||||
}
|
||||
|
||||
// XXX We haven't allocated yet, so we're allowed to ditch the results.
|
||||
// Check if our batch is big enough
|
||||
if let Some(min_size) = min_size {
|
||||
// Batch is too small
|
||||
if batch.len() < min_size {
|
||||
// Add back entries to the queue in the correct order
|
||||
for (id, entry, _) in batch.into_iter().rev() {
|
||||
self.entries.push_front((id, entry));
|
||||
}
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
let mut batch_requests = Vec::with_capacity(self.entries.len());
|
||||
let mut batch_entries =
|
||||
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
|
||||
|
||||
for (id, mut entry, block_allocation) in batch {
|
||||
let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) =
|
||||
(block_allocation, &self.block_allocator)
|
||||
{
|
||||
match block_allocator.allocate(tokens, input_ids).await {
|
||||
None => {
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
tracing::debug!("Over budget: not enough free blocks");
|
||||
self.entries.push_front((id, entry));
|
||||
break 'entry_loop;
|
||||
break;
|
||||
}
|
||||
Some(block_allocation) => {
|
||||
tracing::debug!("Allocation: {block_allocation:?}");
|
||||
|
@ -338,9 +371,9 @@ impl State {
|
|||
Some(block_allocation)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
tracing::debug!("Accepting entry");
|
||||
// Create a new span to link the batch back to this entry
|
||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||
|
@ -400,32 +433,6 @@ impl State {
|
|||
entry.batch_time = Some(Instant::now());
|
||||
// Insert in batch_entries IntMap
|
||||
batch_entries.insert(id, entry);
|
||||
|
||||
// Check if max_size
|
||||
if Some(batch_requests.len()) == max_size {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Empty batch
|
||||
if batch_requests.is_empty() {
|
||||
tracing::debug!("Filterered out all entries");
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check if our batch is big enough
|
||||
if let Some(min_size) = min_size {
|
||||
// Batch is too small
|
||||
if batch_requests.len() < min_size {
|
||||
// Add back entries to the queue in the correct order
|
||||
for r in batch_requests.into_iter().rev() {
|
||||
let id = r.id;
|
||||
let entry = batch_entries.remove(&id).unwrap();
|
||||
self.entries.push_front((id, entry));
|
||||
}
|
||||
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
// Final batch size
|
||||
|
|
|
@ -89,6 +89,8 @@ impl Allocator for RadixAllocator {
|
|||
|
||||
let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size;
|
||||
|
||||
tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");
|
||||
|
||||
match self.alloc_or_reclaim(suffix_blocks as usize) {
|
||||
Some(suffix_blocks) => blocks.extend(suffix_blocks),
|
||||
None => {
|
||||
|
|
Loading…
Reference in New Issue