Fix the invalid drop of the batch more correctly. (#2498)
This commit is contained in:
parent aaea212d0f
commit c1fe28d694
@@ -122,7 +122,7 @@ impl Backend for BackendV3 {
 #[allow(clippy::too_many_arguments)]
 pub(crate) async fn batching_task(
     mut client: ShardedClient,
-    _waiting_served_ratio: f32,
+    waiting_served_ratio: f32,
     max_batch_prefill_tokens: u32,
     max_batch_total_tokens: u32,
     max_waiting_tokens: usize,
@@ -170,8 +170,7 @@ pub(crate) async fn batching_task(
             // Minimum batch size
             // TODO: temporarily disable to avoid incorrect deallocation +
             // reallocation when using prefix caching.
-            // Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
-            None
+            Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
         };
 
         let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
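This hunk re-enables the minimum-batch-size heuristic that had been commented out: a new prefill batch is only formed once enough requests are waiting relative to the size of the running batch. A minimal sketch of the two derived values, using a hypothetical helper and illustrative numbers (neither is part of the router):

    // Hypothetical helper mirroring the re-enabled computation above.
    fn batch_limits(
        batch_size: usize,
        waiting_served_ratio: f32,
        max_batch_total_tokens: u32,
        batch_max_tokens: u32,
    ) -> (Option<usize>, u32) {
        // Minimum number of waiting requests before a new prefill batch is built
        let min_size = Some((batch_size as f32 * waiting_served_ratio).floor() as usize);
        // Token budget left for the new batch; saturating_sub avoids underflow
        // when the running batch already consumes the whole budget
        let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
        (min_size, token_budget)
    }

    fn main() {
        // With 10 running requests and a ratio of 1.2, at least
        // floor(10 * 1.2) = 12 waiting requests are required.
        let (min_size, token_budget) = batch_limits(10, 1.2, 32_768, 20_000);
        assert_eq!(min_size, Some(12));
        assert_eq!(token_budget, 12_768);
    }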
@@ -252,17 +252,14 @@ impl State {
         let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
         next_batch_span.follows_from(Span::current());
 
-        let mut batch_requests = Vec::with_capacity(self.entries.len());
-        let mut batch_entries =
-            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
-
+        let mut batch = Vec::with_capacity(self.entries.len());
         let mut max_input_length = 0;
         let mut prefill_tokens: u32 = 0;
         let mut decode_tokens: u32 = 0;
         let mut max_blocks = 0;
 
         // Pop entries starting from the front of the queue
-        'entry_loop: while let Some((id, mut entry)) = self.entries.pop_front() {
+        'entry_loop: while let Some((id, entry)) = self.entries.pop_front() {
             // Filter entries where the response receiver was dropped (== entries where the request
             // was dropped by the client)
             if entry.response_tx.is_closed() {
@@ -276,7 +273,7 @@ impl State {
                     // We pad to max input length in the Python shards
                     // We need to take these padding tokens into the equation
                     max_input_length = max_input_length.max(entry.request.input_length);
-                    prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length;
+                    prefill_tokens = (batch.len() + 1) as u32 * max_input_length;
 
                     decode_tokens += entry.request.stopping_parameters.max_new_tokens;
                     let total_tokens = prefill_tokens + decode_tokens + self.speculate;
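Because the Python shards pad every sequence to the batch's maximum input length, prefill cost is counted as (entries so far + 1) * max_input_length rather than as the sum of individual lengths. A small self-contained check of that accounting, with made-up input lengths:

    // Padded prefill accounting, as in the line above: every entry pays for
    // the batch-wide maximum input length, padding included.
    fn main() {
        let input_lengths = [10u32, 50, 20]; // made-up lengths
        let mut max_input_length = 0u32;
        let mut prefill_tokens = 0u32;
        for (i, len) in input_lengths.iter().enumerate() {
            max_input_length = max_input_length.max(*len);
            // i entries already in the batch, plus this one
            prefill_tokens = (i as u32 + 1) * max_input_length;
        }
        // 3 sequences padded to length 50 => 150 prefill tokens
        assert_eq!(max_input_length, 50);
        assert_eq!(prefill_tokens, 150);
    }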
@@ -290,7 +287,7 @@ impl State {
                     }
                     None
                 }
-                Some(block_allocator) => {
+                Some(_block_allocator) => {
                     prefill_tokens += entry.request.input_length;
                     let max_new_tokens = match self.window_size {
                         None => entry.request.stopping_parameters.max_new_tokens,
@@ -324,23 +321,59 @@ impl State {
                         entry.request.input_ids.clone()
                     };
 
-                    match block_allocator.allocate(tokens, input_ids).await {
-                        None => {
-                            // Entry is over budget
-                            // Add it back to the front
-                            tracing::debug!("Over budget: not enough free blocks");
-                            self.entries.push_front((id, entry));
-                            break 'entry_loop;
-                        }
-                        Some(block_allocation) => {
-                            tracing::debug!("Allocation: {block_allocation:?}");
-                            max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
-                            Some(block_allocation)
-                        }
-                    }
+                    Some((tokens, input_ids))
                 }
             };
+            batch.push((id, entry, block_allocation));
+            if Some(batch.len()) == max_size {
+                break;
+            }
+        }
+
+        // Empty batch
+        if batch.is_empty() {
+            tracing::debug!("Filterered out all entries");
+            return None;
+        }
+
+        // XXX We haven't allocated yet, so we're allowed to ditch the results.
+        // Check if our batch is big enough
+        if let Some(min_size) = min_size {
+            // Batch is too small
+            if batch.len() < min_size {
+                // Add back entries to the queue in the correct order
+                for (id, entry, _) in batch.into_iter().rev() {
+                    self.entries.push_front((id, entry));
+                }
+                return None;
+            }
+        }
+
+        let mut batch_requests = Vec::with_capacity(self.entries.len());
+        let mut batch_entries =
+            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
+
+        for (id, mut entry, block_allocation) in batch {
+            let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) =
+                (block_allocation, &self.block_allocator)
+            {
+                match block_allocator.allocate(tokens, input_ids).await {
+                    None => {
+                        // Entry is over budget
+                        // Add it back to the front
+                        tracing::debug!("Over budget: not enough free blocks");
+                        self.entries.push_front((id, entry));
+                        break;
+                    }
+                    Some(block_allocation) => {
+                        tracing::debug!("Allocation: {block_allocation:?}");
+                        max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
+                        Some(block_allocation)
+                    }
+                }
+            } else {
+                None
+            };
             tracing::debug!("Accepting entry");
             // Create a new span to link the batch back to this entry
             let entry_batch_span = info_span!(parent: &entry.span, "infer");
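This hunk is the heart of the fix: entries are first collected without allocating, the min/max size checks run while nothing has been allocated yet (the "XXX We haven't allocated yet" comment above), and only the surviving batch is handed to the block allocator. Rejecting an undersized batch can then no longer leave allocated blocks behind. Below is a compact, synchronous sketch of this collect-then-allocate pattern; Entry, Allocator, and blocks_per_entry are simplified stand-ins, not the router's real types:

    use std::collections::VecDeque;

    struct Entry;

    struct BlockAllocation {
        blocks: Vec<u32>,
    }

    // Stand-in for the real block allocator: allocation fails once the pool
    // runs out of free blocks.
    struct Allocator {
        free_blocks: usize,
    }

    impl Allocator {
        fn allocate(&mut self, blocks_needed: usize) -> Option<BlockAllocation> {
            if blocks_needed <= self.free_blocks {
                self.free_blocks -= blocks_needed;
                Some(BlockAllocation {
                    blocks: (0..blocks_needed as u32).collect(),
                })
            } else {
                None
            }
        }
    }

    fn next_batch(
        entries: &mut VecDeque<(u64, Entry)>,
        allocator: &mut Allocator,
        min_size: Option<usize>,
        max_size: Option<usize>,
        blocks_per_entry: usize,
    ) -> Option<Vec<(u64, Entry, BlockAllocation)>> {
        // Phase 1: drain candidates without touching the allocator, so that
        // rejecting the batch below cannot leak or double-free blocks.
        let mut batch = Vec::new();
        while let Some((id, entry)) = entries.pop_front() {
            batch.push((id, entry));
            if Some(batch.len()) == max_size {
                break;
            }
        }

        if batch.is_empty() {
            return None;
        }
        // The size check happens *before* any allocation: nothing to roll back.
        if let Some(min_size) = min_size {
            if batch.len() < min_size {
                // Re-queue in reverse to restore the original order.
                for (id, entry) in batch.into_iter().rev() {
                    entries.push_front((id, entry));
                }
                return None;
            }
        }

        // Phase 2: allocate only for the accepted batch. On failure the entry
        // goes back to the front of the queue and the batch is cut short.
        let mut allocated = Vec::with_capacity(batch.len());
        for (id, entry) in batch {
            match allocator.allocate(blocks_per_entry) {
                Some(allocation) => allocated.push((id, entry, allocation)),
                None => {
                    entries.push_front((id, entry));
                    break;
                }
            }
        }
        Some(allocated)
    }

    fn main() {
        let mut entries: VecDeque<(u64, Entry)> = (0..4).map(|id| (id, Entry)).collect();
        let mut allocator = Allocator { free_blocks: 6 };
        // Four waiting entries, two blocks each, six free blocks: three fit,
        // the fourth is pushed back to the front of the queue.
        let batch = next_batch(&mut entries, &mut allocator, Some(2), None, 2).unwrap();
        assert_eq!(batch.len(), 3);
        assert_eq!(batch[0].2.blocks.len(), 2);
        assert_eq!(entries.len(), 1);
    }

The break in phase 2 mirrors the diff: once one allocation fails, the remaining entries stay queued for the next scheduling round instead of being dropped.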
@@ -400,32 +433,6 @@ impl State {
             entry.batch_time = Some(Instant::now());
             // Insert in batch_entries IntMap
             batch_entries.insert(id, entry);
-
-            // Check if max_size
-            if Some(batch_requests.len()) == max_size {
-                break;
-            }
-        }
-
-        // Empty batch
-        if batch_requests.is_empty() {
-            tracing::debug!("Filterered out all entries");
-            return None;
-        }
-
-        // Check if our batch is big enough
-        if let Some(min_size) = min_size {
-            // Batch is too small
-            if batch_requests.len() < min_size {
-                // Add back entries to the queue in the correct order
-                for r in batch_requests.into_iter().rev() {
-                    let id = r.id;
-                    let entry = batch_entries.remove(&id).unwrap();
-                    self.entries.push_front((id, entry));
-                }
-
-                return None;
-            }
         }
 
         // Final batch size
@@ -89,6 +89,8 @@ impl Allocator for RadixAllocator {
 
         let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size;
 
+        tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");
+
         match self.alloc_or_reclaim(suffix_blocks as usize) {
             Some(suffix_blocks) => blocks.extend(suffix_blocks),
             None => {
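The added tracing::info! line logs the prefix/suffix split each time an allocation is attempted. For context, the suffix_blocks line above is a ceiling division: a partial trailing block still occupies a whole block. A quick check, using an assumed block_size of 16 (the value is illustrative):

    fn blocks_needed(suffix_len: u32, block_size: u32) -> u32 {
        // Ceiling division: round up so a partial block still counts.
        (suffix_len + block_size - 1) / block_size
    }

    fn main() {
        assert_eq!(blocks_needed(1, 16), 1);
        assert_eq!(blocks_needed(16, 16), 1);
        assert_eq!(blocks_needed(17, 16), 2);
    }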