remove slots from grpc

This commit is contained in:
parent c2fb459bc1
commit 9ac7b7bc52
@@ -156,7 +156,6 @@ async fn prefill(
                }),
                top_n_tokens: top_n_tokens.unwrap_or(0),
                blocks: vec![],
-                slots: vec![],
            })
            .collect();

@@ -132,8 +132,6 @@ message Request {
    uint32 top_n_tokens = 7;
    /// Paged attention blocks
    repeated uint32 blocks = 9;
-    /// Paged attention slots
-    repeated uint32 slots = 10;
}

message Batch {
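The rationale for dropping `slots` from the protocol: with paged attention, block `b` always covers the contiguous slot range `[b * block_size, (b + 1) * block_size)`, so the slot list carries no information beyond the block list and can be recomputed wherever it is needed. A minimal Python sketch, not from the repo, with `BLOCK_SIZE` assumed to be 16 as elsewhere in this diff:

    BLOCK_SIZE = 16  # assumed block size

    def slots_for_blocks(blocks: list[int]) -> list[int]:
        # Each block id expands to a contiguous run of BLOCK_SIZE slot ids
        return [s for b in blocks for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)]

    assert slots_for_blocks([0, 2]) == list(range(0, 16)) + list(range(32, 48))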
@@ -208,8 +206,6 @@ message KeptRequest {
    uint64 id = 1;
    /// Paged attention blocks
    repeated uint32 blocks = 2;
-    /// Paged attention slots
-    repeated uint32 slots = 3;
}

/// kept_requests + terminated_request_ids might not cover all requests from the

@@ -157,7 +157,6 @@ impl Client {
                truncate,
                // Blocks and slots will be set on the server side if we use paged attention
                blocks: vec![],
-                slots: vec![],
                // Set sampling parameters to also take these ops into account in the max memory
                parameters: Some(NextTokenChooserParameters {
                    temperature: 0.9,

@@ -250,7 +250,6 @@ impl Health for ShardedClient {
            top_n_tokens: 0,
            // Block 0 is reserved for health checks
            blocks: vec![0],
-            slots: (0..16).collect(),
        };
        let batch = Batch {
            id: u64::MAX,

@@ -1,11 +1,12 @@
+use std::cmp::min;
 use std::fmt::Formatter;
 use std::sync::{Arc, Mutex, TryLockError};
 use thiserror::Error;

 #[derive(Clone)]
 pub(crate) struct BlockAllocation {
+    block_size: usize,
     allocated_blocks: Vec<u32>,
-    allocated_slots: Vec<u32>,
     required_blocks: usize,
     required_slots: usize,
     block_allocator: BlockAllocator,
@@ -13,25 +14,20 @@ pub(crate) struct BlockAllocation {

 impl BlockAllocation {
     pub(crate) fn len(&self) -> usize {
-        self.allocated_slots.len()
+        self.allocated_blocks.len() * self.block_size
     }

     pub(crate) fn blocks(&self) -> &[u32] {
         &self.allocated_blocks
     }

-    pub(crate) fn slots(&self) -> &[u32] {
-        &self.allocated_slots
-    }
-
     /// Extend an allocation by adding a new block
     /// If the allocation length > window size, repeats blocks and slots to cover the
     /// whole `required_blocks` and `required_slots`
     pub(crate) fn extend(&mut self) -> Result<(), AllocationError> {
-        let (block, slots) = self.block_allocator.allocate_block()?;
+        let block = self.block_allocator.allocate_block()?;
         // Add block and slots to current allocation
         self.allocated_blocks.push(block);
-        self.allocated_slots.extend(slots);

         if let Some(window_size) = self.block_allocator.window_size {
            // if we have more slots than the window size,
@@ -41,8 +37,6 @@ impl BlockAllocation {
             let repeats = (self.required_slots + window_size - 1) / window_size;
             self.allocated_blocks = self.allocated_blocks.repeat(repeats);
             self.allocated_blocks.truncate(self.required_blocks);
-            self.allocated_slots = self.allocated_slots.repeat(repeats);
-            self.allocated_slots.truncate(self.required_slots);
         }
     }

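For sliding-window models, an allocation longer than the window reuses physical blocks: the code above tiles the block list and truncates it to the required length, and with `allocated_slots` gone only the blocks need tiling. A hedged Python sketch of the same idea, with illustrative names:

    def tile_blocks(blocks: list[int], required_blocks: int, required_slots: int,
                    window_slots: int) -> list[int]:
        # ceil(required_slots / window_slots) copies suffice to cover the request
        repeats = (required_slots + window_slots - 1) // window_slots
        return (blocks * repeats)[:required_blocks]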
@@ -62,7 +56,6 @@ impl std::fmt::Debug for BlockAllocation {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         f.debug_struct("BlockAllocation")
             .field("allocated_blocks", &self.allocated_blocks.len())
-            .field("allocated_slots", &self.allocated_slots.len())
             .field("required_blocks", &self.required_blocks)
             .field("required_slots", &self.required_slots)
             .field("block_allocator", &self.block_allocator)
@@ -94,30 +87,29 @@ impl BlockAllocator {
         }
     }

-    fn allocate_block(&self) -> Result<(u32, Vec<u32>), AllocationError> {
+    fn allocate_block(&self) -> Result<u32, AllocationError> {
         let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");

         if free_blocks.is_empty() {
             return Err(AllocationError::NotEnoughPages);
         }

-        let block_id = free_blocks.pop().unwrap();
-        let slots = ((block_id * self.block_size)..((block_id + 1) * self.block_size)).collect();
-        Ok((block_id, slots))
+        Ok(free_blocks.pop().unwrap())
     }

     /// For prompt tokens, we allocate enough blocks to cover all tokens
-    /// For decode tokens, we allocate block by block
+    /// For decode tokens, we allocate min(decode_blocks, 16) blocks
     ///
-    /// If prompt tokens + min(decode_tokens, block_size) > window size, we repeat blocks and slots
+    /// If allocation > window size, we repeat blocks and slots
     pub(crate) fn block_allocation(
         &self,
         prompt_tokens: u32,
         decode_tokens: u32,
     ) -> Result<BlockAllocation, AllocationError> {
         let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size;
-        // prompt blocks + a single block for decode
-        let required_blocks = required_prompt_blocks + 1;
+        // prompt blocks + 16 blocks for decode
+        let decode_blocks = (decode_tokens + self.block_size - 1) / self.block_size;
+        let required_blocks = required_prompt_blocks + min(decode_blocks, 16);
         let required_slots = required_blocks * self.block_size;

         // Slots and blocks required for the whole request
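The sizing rule changes from "prompt blocks plus a single decode block" to "prompt blocks plus up to 16 decode blocks", so fewer `extend()` calls are needed during decode. A sketch of the arithmetic in Python, illustrative only:

    def required_blocks(prompt_tokens: int, decode_tokens: int, block_size: int) -> int:
        # ceil-div for both token counts; decode allocation is capped at 16 blocks
        prompt_blocks = (prompt_tokens + block_size - 1) // block_size
        decode_blocks = (decode_tokens + block_size - 1) // block_size
        return prompt_blocks + min(decode_blocks, 16)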
@@ -164,21 +156,9 @@ impl BlockAllocator {
             allocated_blocks
         };

-        let mut allocated_slots =
-            Vec::with_capacity(allocated_blocks.len() * self.block_size as usize * repeats);
-
-        'slots: for block_id in allocated_blocks.iter() {
-            for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
-                allocated_slots.push(s);
-                if allocated_slots.len() > total_slots {
-                    break 'slots;
-                }
-            }
-        }
-
         Ok(BlockAllocation {
+            block_size: self.block_size as usize,
             allocated_blocks,
-            allocated_slots,
             required_blocks: total_required_blocks,
             required_slots: total_slots,
             block_allocator: self.clone(),

@@ -224,6 +224,11 @@ impl State {
             }
         }

+        // Check if max_size == 0
+        if max_size == Some(0) {
+            return None;
+        }
+
         // Pad prefill_token_budget to be a multiple of block size
         let prefill_token_budget =
             ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;

@@ -312,14 +317,10 @@ impl State {
             // Update entry
             entry.temp_span = Some(entry_batch_span);

-            let (blocks, slots) = match &block_allocation {
-                None => (Vec::new(), Vec::new()),
-                Some(block_allocation) => (
-                    block_allocation.blocks().to_vec(),
-                    block_allocation.slots().to_vec(),
-                ),
-            };
-
+            let blocks = block_allocation
+                .as_ref()
+                .map(|block_allocation| block_allocation.blocks().to_vec())
+                .unwrap_or_default();
             entry.block_allocation = block_allocation;

             batch_requests.push(Request {
@@ -338,7 +339,6 @@ impl State {
                 )),
                 top_n_tokens: entry.request.top_n_tokens,
                 blocks,
-                slots,
             });
             // Set batch_time
             entry.batch_time = Some(Instant::now());

@@ -164,7 +164,7 @@ pub(crate) async fn batching_task(
             };

             let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
-            let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
+            let max_size = max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));

             // Try to get a new batch
             if let Some((mut new_entries, new_batch, span)) = queue
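`saturating_sub` guards the case where the running batch already holds `max_batch_size` or more requests: plain `usize` subtraction would underflow there (a panic in debug builds), while the saturating version clamps to 0, which the new `max_size == Some(0)` early return in the queue then handles. A Python analogue of the clamp, for illustration:

    def saturating_sub(a: int, b: int) -> int:
        # unsigned subtraction clamped at zero instead of wrapping or panicking
        return max(0, a - b)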
@@ -382,16 +382,15 @@ async fn filter_batch(
     let updated_requests = entries
         .iter()
         .map(|(request_id, entry)| {
-            let (blocks, slots) = entry
+            let blocks = entry
                 .block_allocation
                 .as_ref()
-                .map(|alloc| (alloc.blocks().to_vec(), alloc.slots().to_vec()))
+                .map(|alloc| alloc.blocks().to_vec())
                 .unwrap_or_default();

             KeptRequest {
                 id: *request_id,
                 blocks,
-                slots,
             }
         })
         .collect();

@@ -991,10 +990,10 @@ mod tests {
             content: "You are a friendly chatbot who always responds in the style of a pirate"
                 .to_string(),
         }]
-        .iter()
-        .chain(&example_chat)
-        .cloned()
-        .collect::<Vec<_>>();
+            .iter()
+            .chain(&example_chat)
+            .cloned()
+            .collect::<Vec<_>>();

         let test_default_templates = vec![
             ChatTemplateTestItem {

@@ -133,9 +133,12 @@ class FlashCausalLMBatch(Batch):
             batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
             max_truncation = max(max_truncation, r.truncate)

+        logger.error(batch_inputs)
+
         batch_tokenized_inputs = tokenizer(
             batch_inputs, truncation=True, max_length=max_truncation
         )["input_ids"]
+        logger.error(batch_tokenized_inputs)
         return batch_tokenized_inputs

     @classmethod
@@ -179,7 +182,7 @@ class FlashCausalLMBatch(Batch):
         max_blocks = 0

         block_tables = []
-        flat_slots = []
+        flat_blocks = []

         # Parse batch
         for i, (r, tokenized_input) in enumerate(
@@ -231,24 +234,18 @@ class FlashCausalLMBatch(Batch):
                 request_blocks = [
                     b for b in range(num_blocks, num_blocks + needed_blocks)
                 ]
-                request_slots = [
-                    s
-                    for b in request_blocks
-                    for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
-                ]
             else:
                 request_blocks = r.blocks
-                request_slots = r.slots

             block_tables.append(request_blocks)
             num_blocks += len(request_blocks)

             request_slot_indices = torch.arange(
-                len(flat_slots),
-                len(flat_slots) + input_length,
+                len(flat_blocks) * BLOCK_SIZE,
+                (len(flat_blocks) * BLOCK_SIZE) + input_length,
                 dtype=torch.int64,
             )
-            flat_slots.extend(request_slots)
+            flat_blocks.extend(request_blocks)
             slot_indices.append(request_slot_indices)

             # Create tensor to slice into the kv tensor in prefill
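Because every request's slots are laid out block-aligned, the running slot offset for a request is simply `len(flat_blocks) * BLOCK_SIZE`, so no `flat_slots` list has to be materialized. An illustrative sketch with made-up block tables and input lengths:

    BLOCK_SIZE = 16  # assumed
    flat_blocks: list[int] = []
    for request_blocks, input_length in [([0, 1], 20), ([2], 5)]:
        offset = len(flat_blocks) * BLOCK_SIZE  # first slot of this request
        request_slot_indices = list(range(offset, offset + input_length))
        flat_blocks.extend(request_blocks)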
@@ -347,7 +344,13 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens, device=device, dtype=torch.int64
         )

-        slots = torch.tensor(flat_slots, dtype=torch.int64, device=device)
+        flat_blocks_tensor = torch.tensor(flat_blocks, dtype=torch.int64, device=device)
+
+        slots = (
+            (flat_blocks_tensor * BLOCK_SIZE).repeat(BLOCK_SIZE, 1).T
+            + torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int64)
+        ).flatten()

         block_tables_tensor = torch.zeros(
             (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
         )
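The tensor expression replacing the removed per-slot loop expands each block id into its `BLOCK_SIZE` consecutive slot ids by broadcasting: the `(N,)` vector of block base slots becomes an `(N, BLOCK_SIZE)` grid, then flattens in block order. A quick equivalence check with illustrative values:

    import torch

    BLOCK_SIZE = 16  # assumed block size
    flat_blocks = torch.tensor([0, 2, 5])

    slots = (
        (flat_blocks * BLOCK_SIZE).repeat(BLOCK_SIZE, 1).T  # shape (3, 16): base slot per block
        + torch.arange(0, BLOCK_SIZE)                       # offsets 0..15 within each block
    ).flatten()

    # Same result as the removed per-slot loop
    expected = [s for b in [0, 2, 5] for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)]
    assert slots.tolist() == expected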
@@ -444,8 +447,8 @@ class FlashCausalLMBatch(Batch):
         max_seqlen = 0

         requests = []
+        flat_blocks = []
         block_tables = []
-        flat_slots = []
         all_input_ids = []

         input_lengths = []
@@ -483,16 +486,13 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens.append(self.top_n_tokens[idx])

             request_block_table = request.blocks
-            num_blocks += len(request_block_table)
             block_tables.append(request_block_table)
+            flat_blocks.extend(request_block_table)

-            # List of slots allocated for this request
-            request_slots = request.slots
-
             # Index
-            slot_indices.append(len(flat_slots) + request_input_length - 1)
-            flat_slots.extend(request_slots)
+            slot_indices.append((num_blocks * BLOCK_SIZE) + request_input_length - 1)

+            num_blocks += len(request_block_table)
             max_blocks = max(max_blocks, len(request_block_table))

             # Index into tensors
@@ -514,11 +514,16 @@ class FlashCausalLMBatch(Batch):
             block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)

-        # Allocate on GPU
-        slots = torch.tensor(flat_slots, dtype=torch.int64, device=device)
         slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)

         # Move to GPU
         block_tables_tensor = block_tables_tensor.to(device)
+        flat_blocks_tensor = torch.tensor(flat_blocks, dtype=torch.int64, device=device)
+
+        slots = (
+            (flat_blocks_tensor * BLOCK_SIZE).repeat(BLOCK_SIZE, 1).T
+            + torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int64)
+        ).flatten()

         filtered_batch = type(self)(
             batch_id=self.batch_id,