remove slots from grpc

OlivierDehaene 2024-06-12 11:50:31 +02:00
parent c2fb459bc1
commit 9ac7b7bc52
8 changed files with 52 additions and 75 deletions
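The change itself is simple: a paged-attention slot is fully determined by its block id and the block size (slot = block_id * block_size + offset), so the per-request slot lists sent over gRPC were redundant and each side can recompute them from `blocks` alone. A minimal standalone sketch of that derivation (the `derive_slots` helper and the block size of 16 are illustrative, not names from the codebase):

```python
BLOCK_SIZE = 16  # illustrative; the real value comes from the serving config

def derive_slots(blocks: list[int], block_size: int = BLOCK_SIZE) -> list[int]:
    # Each block id owns the contiguous slot range [id * block_size, (id + 1) * block_size)
    return [b * block_size + offset for b in blocks for offset in range(block_size)]

# Block 3 owns slots 48..63 when the block size is 16
assert derive_slots([3]) == list(range(48, 64))
```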

View File

@@ -156,7 +156,6 @@ async fn prefill(
}),
top_n_tokens: top_n_tokens.unwrap_or(0),
blocks: vec![],
- slots: vec![],
})
.collect();

View File

@@ -132,8 +132,6 @@ message Request {
uint32 top_n_tokens = 7;
/// Paged attention blocks
repeated uint32 blocks = 9;
- /// Paged attention slots
- repeated uint32 slots = 10;
}
message Batch {
@@ -208,8 +206,6 @@ message KeptRequest {
uint64 id = 1;
/// Paged attention blocks
repeated uint32 blocks = 2;
- /// Paged attention slots
- repeated uint32 slots = 3;
}
/// kept_requests + terminated_request_ids might not cover all requests from the

View File

@@ -157,7 +157,6 @@ impl Client {
truncate,
// Blocks and slots will be set on the server side if we use paged attention
blocks: vec![],
- slots: vec![],
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,

View File

@@ -250,7 +250,6 @@ impl Health for ShardedClient {
top_n_tokens: 0,
// Block 0 is reserved for health checks
blocks: vec![0],
- slots: (0..16).collect(),
};
let batch = Batch {
id: u64::MAX,
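The health check used to hardcode the 16 slots of block 0; it now sends only `blocks: vec![0]` and lets the shard expand them. Under the same kind of illustrative helper as above, the deleted `(0..16).collect()` is exactly what gets recomputed:

```python
def derive_slots(blocks, block_size=16):
    # illustrative stand-in for the server-side expansion of `blocks`
    return [b * block_size + o for b in blocks for o in range(block_size)]

# Matches the removed hardcoded `slots: (0..16).collect()` for block 0
assert derive_slots([0]) == list(range(16))
```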

View File

@@ -1,11 +1,12 @@
+ use std::cmp::min;
use std::fmt::Formatter;
use std::sync::{Arc, Mutex, TryLockError};
use thiserror::Error;
#[derive(Clone)]
pub(crate) struct BlockAllocation {
+ block_size: usize,
allocated_blocks: Vec<u32>,
- allocated_slots: Vec<u32>,
required_blocks: usize,
required_slots: usize,
block_allocator: BlockAllocator,
@@ -13,25 +14,20 @@ pub(crate) struct BlockAllocation {
impl BlockAllocation {
pub(crate) fn len(&self) -> usize {
- self.allocated_slots.len()
+ self.allocated_blocks.len() * self.block_size
}
pub(crate) fn blocks(&self) -> &[u32] {
&self.allocated_blocks
}
- pub(crate) fn slots(&self) -> &[u32] {
- &self.allocated_slots
- }
/// Extend an allocation by adding a new block
/// If the allocation length > window size, repeats blocks and slots to cover the
/// whole `required_blocks` and `required_slots`
pub(crate) fn extend(&mut self) -> Result<(), AllocationError> {
- let (block, slots) = self.block_allocator.allocate_block()?;
+ let block = self.block_allocator.allocate_block()?;
// Add block and slots to current allocation
self.allocated_blocks.push(block);
- self.allocated_slots.extend(slots);
if let Some(window_size) = self.block_allocator.window_size {
// if we have more slots than the window size,
@@ -41,8 +37,6 @@ impl BlockAllocation {
let repeats = (self.required_slots + window_size - 1) / window_size;
self.allocated_blocks = self.allocated_blocks.repeat(repeats);
self.allocated_blocks.truncate(self.required_blocks);
- self.allocated_slots = self.allocated_slots.repeat(repeats);
- self.allocated_slots.truncate(self.required_slots);
}
}
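The repeat-and-truncate branch handles sliding-window attention: once an allocation is longer than the window, later positions can reuse the pages of earlier ones, so the block list is tiled ceil(required_slots / window_size) times and cut back to `required_blocks`. A standalone rehearsal of that arithmetic in Python (sizes made up for the example):

```python
def tile_blocks(blocks, required_blocks, required_slots, window_size):
    # ceil(required_slots / window_size) copies are enough to cover the request
    repeats = (required_slots + window_size - 1) // window_size
    return (blocks * repeats)[:required_blocks]

# 4 physical blocks of 16 slots each, a 64-slot window, and a request needing
# 7 blocks (112 slots): the tiled list wraps back onto the same pages.
assert tile_blocks([10, 11, 12, 13], 7, 112, 64) == [10, 11, 12, 13, 10, 11, 12]
```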
@@ -62,7 +56,6 @@ impl std::fmt::Debug for BlockAllocation {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("BlockAllocation")
.field("allocated_blocks", &self.allocated_blocks.len())
.field("allocated_slots", &self.allocated_slots.len())
.field("required_blocks", &self.required_blocks)
.field("required_slots", &self.required_slots)
.field("block_allocator", &self.block_allocator)
@@ -94,30 +87,29 @@ impl BlockAllocator {
}
}
- fn allocate_block(&self) -> Result<(u32, Vec<u32>), AllocationError> {
+ fn allocate_block(&self) -> Result<u32, AllocationError> {
let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
if free_blocks.is_empty() {
return Err(AllocationError::NotEnoughPages);
}
- let block_id = free_blocks.pop().unwrap();
- let slots = ((block_id * self.block_size)..((block_id + 1) * self.block_size)).collect();
- Ok((block_id, slots))
+ Ok(free_blocks.pop().unwrap())
}
/// For prompt tokens, we allocate enough blocks to cover all tokens
- /// For decode tokens, we allocate block by block
+ /// For decode tokens, we allocate min(decode_blocks, 16) blocks
///
- /// If prompt tokens + min(decode_tokens, block_size) > window size, we repeat blocks and slots
+ /// If allocation > window size, we repeat blocks and slots
pub(crate) fn block_allocation(
&self,
prompt_tokens: u32,
decode_tokens: u32,
) -> Result<BlockAllocation, AllocationError> {
let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size;
- // prompt blocks + a single block for decode
- let required_blocks = required_prompt_blocks + 1;
+ // prompt blocks + 16 blocks for decode
+ let decode_blocks = (decode_tokens + self.block_size - 1) / self.block_size;
+ let required_blocks = required_prompt_blocks + min(decode_blocks, 16);
let required_slots = required_blocks * self.block_size;
// Slots and blocks required for the whole request
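The new sizing rule reserves ceil(prompt_tokens / block_size) blocks for the prompt plus up to 16 blocks for decode, where the old code reserved a single decode block. A worked version of the same ceil-division arithmetic (numbers chosen for illustration):

```python
def required_blocks(prompt_tokens, decode_tokens, block_size):
    prompt_blocks = (prompt_tokens + block_size - 1) // block_size  # ceil division
    decode_blocks = (decode_tokens + block_size - 1) // block_size
    return prompt_blocks + min(decode_blocks, 16)

# 1000 prompt tokens in blocks of 16 -> 63 blocks; 40 decode tokens -> 3 blocks
assert required_blocks(1000, 40, 16) == 63 + 3
```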
@@ -164,21 +156,9 @@ impl BlockAllocator {
allocated_blocks
};
- let mut allocated_slots =
- Vec::with_capacity(allocated_blocks.len() * self.block_size as usize * repeats);
- 'slots: for block_id in allocated_blocks.iter() {
- for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
- allocated_slots.push(s);
- if allocated_slots.len() > total_slots {
- break 'slots;
- }
- }
- }
Ok(BlockAllocation {
+ block_size: self.block_size as usize,
allocated_blocks,
- allocated_slots,
required_blocks: total_required_blocks,
required_slots: total_slots,
block_allocator: self.clone(),

View File

@@ -224,6 +224,11 @@ impl State {
}
}
+ // Check if max_size == 0
+ if max_size == Some(0) {
+ return None;
+ }
// Pad prefill_token_budget to be a multiple of block size
let prefill_token_budget =
((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
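Rounding the prefill budget up to a block multiple uses the standard round-up identity ((x + b - 1) / b) * b; a quick check that it only moves unaligned values:

```python
def pad_to_block(prefill_token_budget, block_size):
    return ((prefill_token_budget + block_size - 1) // block_size) * block_size

assert pad_to_block(100, 16) == 112  # rounded up to the next block boundary
assert pad_to_block(112, 16) == 112  # already aligned, unchanged
```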
@@ -312,14 +317,10 @@ impl State {
// Update entry
entry.temp_span = Some(entry_batch_span);
- let (blocks, slots) = match &block_allocation {
- None => (Vec::new(), Vec::new()),
- Some(block_allocation) => (
- block_allocation.blocks().to_vec(),
- block_allocation.slots().to_vec(),
- ),
- };
+ let blocks = block_allocation
+ .as_ref()
+ .map(|block_allocation| block_allocation.blocks().to_vec())
+ .unwrap_or_default();
entry.block_allocation = block_allocation;
batch_requests.push(Request {
@@ -338,7 +339,6 @@ impl State {
)),
top_n_tokens: entry.request.top_n_tokens,
blocks,
- slots,
});
// Set batch_time
entry.batch_time = Some(Instant::now());

View File

@@ -164,7 +164,7 @@ pub(crate) async fn batching_task(
};
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
- let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
+ let max_size = max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
// Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue
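`saturating_sub` covers the moment where the running batch already holds `max_batch_size` (or, transiently, more) entries: a plain `-` on a `usize` would panic in debug builds or wrap in release. The new `max_size == Some(0)` early return in the queue then turns the clamped zero into "no new batch". A Python equivalent of the clamped computation:

```python
def remaining_batch_slots(max_batch_size, batch_size):
    # Mirrors `max_batch_size.map(|m| m.saturating_sub(batch_size))`: clamp at zero
    return None if max_batch_size is None else max(0, max_batch_size - batch_size)

assert remaining_batch_slots(32, 30) == 2
assert remaining_batch_slots(32, 32) == 0       # full batch: the queue returns None
assert remaining_batch_slots(None, 30) is None  # no limit configured
```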
@@ -382,16 +382,15 @@ async fn filter_batch(
let updated_requests = entries
.iter()
.map(|(request_id, entry)| {
- let (blocks, slots) = entry
+ let blocks = entry
.block_allocation
.as_ref()
- .map(|alloc| (alloc.blocks().to_vec(), alloc.slots().to_vec()))
+ .map(|alloc| alloc.blocks().to_vec())
.unwrap_or_default();
KeptRequest {
id: *request_id,
blocks,
- slots,
}
})
.collect();

View File

@@ -133,9 +133,12 @@ class FlashCausalLMBatch(Batch):
batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
max_truncation = max(max_truncation, r.truncate)
+ logger.error(batch_inputs)
batch_tokenized_inputs = tokenizer(
batch_inputs, truncation=True, max_length=max_truncation
)["input_ids"]
+ logger.error(batch_tokenized_inputs)
return batch_tokenized_inputs
@classmethod
@@ -179,7 +182,7 @@ class FlashCausalLMBatch(Batch):
max_blocks = 0
block_tables = []
- flat_slots = []
+ flat_blocks = []
# Parse batch
for i, (r, tokenized_input) in enumerate(
@@ -231,24 +234,18 @@ class FlashCausalLMBatch(Batch):
request_blocks = [
b for b in range(num_blocks, num_blocks + needed_blocks)
]
- request_slots = [
- s
- for b in request_blocks
- for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
- ]
else:
request_blocks = r.blocks
- request_slots = r.slots
block_tables.append(request_blocks)
num_blocks += len(request_blocks)
request_slot_indices = torch.arange(
- len(flat_slots),
- len(flat_slots) + input_length,
+ len(flat_blocks) * BLOCK_SIZE,
+ (len(flat_blocks) * BLOCK_SIZE) + input_length,
dtype=torch.int64,
)
- flat_slots.extend(request_slots)
+ flat_blocks.extend(request_blocks)
slot_indices.append(request_slot_indices)
# Create tensor to slice into the kv tensor in prefill
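Since each request's blocks are appended to `flat_blocks` in order, block `i` of that list will own rows `[i * BLOCK_SIZE, (i + 1) * BLOCK_SIZE)` of the `slots` tensor built later, so a request arriving when `flat_blocks` has `n` entries gets slot indices counted from `n * BLOCK_SIZE`. A small standalone rehearsal of that bookkeeping (a block size of 16 assumed for the numbers):

```python
BLOCK_SIZE = 16  # illustrative

flat_blocks, slot_indices = [], []
for request_blocks, input_length in [([7, 8], 20), ([3], 5)]:
    start = len(flat_blocks) * BLOCK_SIZE
    slot_indices.append(list(range(start, start + input_length)))
    flat_blocks.extend(request_blocks)

# The second request starts right after the first request's 2 blocks (32 slots)
assert slot_indices[0][:2] == [0, 1]
assert slot_indices[1] == list(range(32, 37))
```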
@@ -347,7 +344,13 @@ class FlashCausalLMBatch(Batch):
top_n_tokens, device=device, dtype=torch.int64
)
- slots = torch.tensor(flat_slots, dtype=torch.int64, device=device)
+ flat_blocks_tensor = torch.tensor(flat_blocks, dtype=torch.int64, device=device)
+ slots = (
+ (flat_blocks_tensor * BLOCK_SIZE).repeat(BLOCK_SIZE, 1).T
+ + torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int64)
+ ).flatten()
block_tables_tensor = torch.zeros(
(len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
)
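This is the heart of the commit: rather than receiving `flat_slots`, the server broadcasts them back into existence. Multiplying each block id by `BLOCK_SIZE` gives the block's first slot; `.repeat(BLOCK_SIZE, 1).T` tiles that column into one row per block; adding `arange(BLOCK_SIZE)` fills in the offsets; `.flatten()` restores the flat layout. The same expression checked standalone (a small block size keeps the output readable):

```python
import torch

BLOCK_SIZE = 4  # small illustrative value; the server uses the real page size

flat_blocks_tensor = torch.tensor([2, 5], dtype=torch.int64)
slots = (
    (flat_blocks_tensor * BLOCK_SIZE).repeat(BLOCK_SIZE, 1).T
    + torch.arange(0, BLOCK_SIZE, dtype=torch.int64)
).flatten()

# Block 2 -> slots 8..11, block 5 -> slots 20..23
assert slots.tolist() == [8, 9, 10, 11, 20, 21, 22, 23]
```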
@@ -444,8 +447,8 @@ class FlashCausalLMBatch(Batch):
max_seqlen = 0
requests = []
+ flat_blocks = []
block_tables = []
- flat_slots = []
all_input_ids = []
input_lengths = []
@@ -483,16 +486,13 @@ class FlashCausalLMBatch(Batch):
top_n_tokens.append(self.top_n_tokens[idx])
request_block_table = request.blocks
- num_blocks += len(request_block_table)
block_tables.append(request_block_table)
- # List of slots allocated for this request
- request_slots = request.slots
+ flat_blocks.extend(request_block_table)
# Index
- slot_indices.append(len(flat_slots) + request_input_length - 1)
- flat_slots.extend(request_slots)
+ slot_indices.append((num_blocks * BLOCK_SIZE) + request_input_length - 1)
+ num_blocks += len(request_block_table)
max_blocks = max(max_blocks, len(request_block_table))
# Index into tensors
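On the filter path only the last token's slot index is needed, and it falls out of the same layout: a kept request whose blocks start `num_blocks` blocks into the new flat list has its last token at `num_blocks * BLOCK_SIZE + input_length - 1`. That is also why `num_blocks +=` now runs after the index is taken. A compact sanity check (block size of 16 assumed):

```python
BLOCK_SIZE = 16  # illustrative

num_blocks, slot_indices = 0, []
for request_blocks, request_input_length in [([4, 9], 30), ([1], 12)]:
    slot_indices.append(num_blocks * BLOCK_SIZE + request_input_length - 1)
    num_blocks += len(request_blocks)

# First request: token 30 within slots 0..31 -> index 29
# Second request: starts at slot 32, token 12 -> index 43
assert slot_indices == [29, 43]
```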
@@ -514,11 +514,16 @@ class FlashCausalLMBatch(Batch):
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
# Allocate on GPU
- slots = torch.tensor(flat_slots, dtype=torch.int64, device=device)
slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)
# Move to GPU
block_tables_tensor = block_tables_tensor.to(device)
+ flat_blocks_tensor = torch.tensor(flat_blocks, dtype=torch.int64, device=device)
+ slots = (
+ (flat_blocks_tensor * BLOCK_SIZE).repeat(BLOCK_SIZE, 1).T
+ + torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int64)
+ ).flatten()
filtered_batch = type(self)(
batch_id=self.batch_id,