remove slots from grpc

OlivierDehaene 2024-06-12 11:50:31 +02:00
parent c2fb459bc1
commit 9ac7b7bc52
8 changed files with 52 additions and 75 deletions


@@ -156,7 +156,6 @@ async fn prefill(
            }),
            top_n_tokens: top_n_tokens.unwrap_or(0),
            blocks: vec![],
-            slots: vec![],
        })
        .collect();


@@ -132,8 +132,6 @@ message Request {
    uint32 top_n_tokens = 7;
    /// Paged attention blocks
    repeated uint32 blocks = 9;
-    /// Paged attention slots
-    repeated uint32 slots = 10;
}

message Batch {
@@ -208,8 +206,6 @@ message KeptRequest {
    uint64 id = 1;
    /// Paged attention blocks
    repeated uint32 blocks = 2;
-    /// Paged attention slots
-    repeated uint32 slots = 3;
}

/// kept_requests + terminated_request_ids might not cover all requests from the
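With `slots` dropped from both `Request` and `KeptRequest`, the slot list is no longer sent over gRPC; whoever needs it derives it from `blocks`, since block `b` always owns the contiguous slot range `[b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)`. A minimal Python sketch of that invariant (the helper name and the `BLOCK_SIZE` value are illustrative, not part of the commit):

    BLOCK_SIZE = 16  # illustrative value; the server defines the real constant

    def slots_from_blocks(blocks):
        # Every block id expands to BLOCK_SIZE consecutive slot ids.
        return [b * BLOCK_SIZE + offset for b in blocks for offset in range(BLOCK_SIZE)]

    assert slots_from_blocks([0]) == list(range(16))
    assert slots_from_blocks([2, 5])[:3] == [32, 33, 34]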


@@ -157,7 +157,6 @@ impl Client {
                truncate,
                // Blocks and slots will be set on the server side if we use paged attention
                blocks: vec![],
-                slots: vec![],
                // Set sampling parameters to also take these ops into account in the max memory
                parameters: Some(NextTokenChooserParameters {
                    temperature: 0.9,


@@ -250,7 +250,6 @@ impl Health for ShardedClient {
            top_n_tokens: 0,
            // Block 0 is reserved for health checks
            blocks: vec![0],
-            slots: (0..16).collect(),
        };
        let batch = Batch {
            id: u64::MAX,


@@ -1,11 +1,12 @@
+use std::cmp::min;
 use std::fmt::Formatter;
 use std::sync::{Arc, Mutex, TryLockError};

 use thiserror::Error;

 #[derive(Clone)]
 pub(crate) struct BlockAllocation {
+    block_size: usize,
     allocated_blocks: Vec<u32>,
-    allocated_slots: Vec<u32>,
     required_blocks: usize,
     required_slots: usize,
     block_allocator: BlockAllocator,
@@ -13,25 +14,20 @@ pub(crate) struct BlockAllocation {
 impl BlockAllocation {
     pub(crate) fn len(&self) -> usize {
-        self.allocated_slots.len()
+        self.allocated_blocks.len() * self.block_size
     }

     pub(crate) fn blocks(&self) -> &[u32] {
         &self.allocated_blocks
     }

-    pub(crate) fn slots(&self) -> &[u32] {
-        &self.allocated_slots
-    }
-
     /// Extend an allocation by adding a new block
     /// If the allocation length > window size, repeats blocks and slots to cover the
     /// whole `required_blocks` and `required_slots`
     pub(crate) fn extend(&mut self) -> Result<(), AllocationError> {
-        let (block, slots) = self.block_allocator.allocate_block()?;
+        let block = self.block_allocator.allocate_block()?;
         // Add block and slots to current allocation
         self.allocated_blocks.push(block);
-        self.allocated_slots.extend(slots);

         if let Some(window_size) = self.block_allocator.window_size {
             // if we have more slots than the window size,
@@ -41,8 +37,6 @@ impl BlockAllocation {
             let repeats = (self.required_slots + window_size - 1) / window_size;
             self.allocated_blocks = self.allocated_blocks.repeat(repeats);
             self.allocated_blocks.truncate(self.required_blocks);
-            self.allocated_slots = self.allocated_slots.repeat(repeats);
-            self.allocated_slots.truncate(self.required_slots);
         }
     }
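Note on the windowing branch above: once a sliding window is set, the allocator tiles the block list `ceil(required_slots / window_size)` times and then truncates to `required_blocks`, so the same physical blocks are reused cyclically past the window. A rough Python rendering of the same arithmetic (names mirror the Rust fields; the values are illustrative):

    def window_blocks(allocated_blocks, required_blocks, required_slots, window_size):
        # Tile the block list often enough to cover every required slot...
        repeats = (required_slots + window_size - 1) // window_size
        tiled = allocated_blocks * repeats
        # ...then keep only as many entries as the request actually needs.
        return tiled[:required_blocks]

    # With a 4-slot window and 2 allocated blocks, the same blocks are
    # reused for the rest of the request.
    assert window_blocks([7, 8], 6, 12, window_size=4) == [7, 8, 7, 8, 7, 8]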
@@ -62,7 +56,6 @@ impl std::fmt::Debug for BlockAllocation {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         f.debug_struct("BlockAllocation")
             .field("allocated_blocks", &self.allocated_blocks.len())
-            .field("allocated_slots", &self.allocated_slots.len())
             .field("required_blocks", &self.required_blocks)
             .field("required_slots", &self.required_slots)
             .field("block_allocator", &self.block_allocator)
@@ -94,30 +87,29 @@ impl BlockAllocator {
         }
     }

-    fn allocate_block(&self) -> Result<(u32, Vec<u32>), AllocationError> {
+    fn allocate_block(&self) -> Result<u32, AllocationError> {
         let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");

         if free_blocks.is_empty() {
             return Err(AllocationError::NotEnoughPages);
         }

-        let block_id = free_blocks.pop().unwrap();
-        let slots = ((block_id * self.block_size)..((block_id + 1) * self.block_size)).collect();
-        Ok((block_id, slots))
+        Ok(free_blocks.pop().unwrap())
     }

     /// For prompt tokens, we allocate enough blocks to cover all tokens
-    /// For decode tokens, we allocate block by block
+    /// For decode tokens, we allocate min(decode_blocks, 16) blocks
     ///
-    /// If prompt tokens + min(decode_tokens, block_size) > window size, we repeat blocks and slots
+    /// If allocation > window size, we repeat blocks and slots
     pub(crate) fn block_allocation(
         &self,
         prompt_tokens: u32,
         decode_tokens: u32,
     ) -> Result<BlockAllocation, AllocationError> {
         let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size;
-        // prompt blocks + a single block for decode
-        let required_blocks = required_prompt_blocks + 1;
+        // prompt blocks + 16 blocks for decode
+        let decode_blocks = (decode_tokens + self.block_size - 1) / self.block_size;
+        let required_blocks = required_prompt_blocks + min(decode_blocks, 16);
         let required_slots = required_blocks * self.block_size;

         // Slots and blocks required for the whole request
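The sizing above is plain ceiling division: `ceil(prompt_tokens / block_size)` blocks for the prompt, plus up to 16 blocks reserved ahead for decode (previously a single block). A quick Python check of the arithmetic, assuming a block size of 16 for the example:

    def required_blocks(prompt_tokens, decode_tokens, block_size=16):
        # Ceiling division for the prompt tokens...
        prompt_blocks = (prompt_tokens + block_size - 1) // block_size
        # ...plus at most 16 blocks pre-allocated for decode tokens.
        decode_blocks = (decode_tokens + block_size - 1) // block_size
        return prompt_blocks + min(decode_blocks, 16)

    assert required_blocks(prompt_tokens=100, decode_tokens=20) == 7 + 2
    assert required_blocks(prompt_tokens=100, decode_tokens=1024) == 7 + 16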
@@ -164,21 +156,9 @@ impl BlockAllocator {
             allocated_blocks
         };

-        let mut allocated_slots =
-            Vec::with_capacity(allocated_blocks.len() * self.block_size as usize * repeats);
-
-        'slots: for block_id in allocated_blocks.iter() {
-            for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
-                allocated_slots.push(s);
-                if allocated_slots.len() > total_slots {
-                    break 'slots;
-                }
-            }
-        }
-
         Ok(BlockAllocation {
+            block_size: self.block_size as usize,
             allocated_blocks,
-            allocated_slots,
             required_blocks: total_required_blocks,
             required_slots: total_slots,
             block_allocator: self.clone(),


@@ -224,6 +224,11 @@ impl State {
             }
         }

+        // An explicit max_size of 0 means the batch cannot accept any new request
+        if max_size == Some(0) {
+            return None;
+        }
+
         // Pad prefill_token_budget to be a multiple of block size
         let prefill_token_budget =
             ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
@@ -312,14 +317,10 @@ impl State {
                     // Update entry
                     entry.temp_span = Some(entry_batch_span);

-                    let (blocks, slots) = match &block_allocation {
-                        None => (Vec::new(), Vec::new()),
-                        Some(block_allocation) => (
-                            block_allocation.blocks().to_vec(),
-                            block_allocation.slots().to_vec(),
-                        ),
-                    };
+                    let blocks = block_allocation
+                        .as_ref()
+                        .map(|block_allocation| block_allocation.blocks().to_vec())
+                        .unwrap_or_default();

                     entry.block_allocation = block_allocation;

                     batch_requests.push(Request {
@@ -338,7 +339,6 @@ impl State {
                         )),
                         top_n_tokens: entry.request.top_n_tokens,
                         blocks,
-                        slots,
                     });

                     // Set batch_time
                     entry.batch_time = Some(Instant::now());


@@ -164,7 +164,7 @@ pub(crate) async fn batching_task(
             };
             let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
-            let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
+            let max_size = max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));

             // Try to get a new batch
             if let Some((mut new_entries, new_batch, span)) = queue
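The `saturating_sub` fix guards the case where the running batch is already at `max_batch_size`: a plain `usize` subtraction would panic on underflow in debug builds and wrap around in release. Combined with the new `max_size == Some(0)` early return in the queue, a full batch now simply yields no new entries. The Python analogue clamps at zero:

    def remaining_batch_slots(max_batch_size, batch_size):
        # Clamp at zero instead of underflowing, like Rust's saturating_sub.
        return max(0, max_batch_size - batch_size)

    assert remaining_batch_slots(8, 5) == 3
    assert remaining_batch_slots(4, 4) == 0  # full batch: the queue returns None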
@@ -382,16 +382,15 @@ async fn filter_batch(
         let updated_requests = entries
             .iter()
             .map(|(request_id, entry)| {
-                let (blocks, slots) = entry
+                let blocks = entry
                     .block_allocation
                     .as_ref()
-                    .map(|alloc| (alloc.blocks().to_vec(), alloc.slots().to_vec()))
+                    .map(|alloc| alloc.blocks().to_vec())
                     .unwrap_or_default();

                 KeptRequest {
                     id: *request_id,
                     blocks,
-                    slots,
                 }
             })
             .collect();
@@ -991,10 +990,10 @@ mod tests {
             content: "You are a friendly chatbot who always responds in the style of a pirate"
                 .to_string(),
         }]
         .iter()
         .chain(&example_chat)
         .cloned()
         .collect::<Vec<_>>();

         let test_default_templates = vec![
             ChatTemplateTestItem {


@@ -133,9 +133,12 @@ class FlashCausalLMBatch(Batch):
             batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
             max_truncation = max(max_truncation, r.truncate)

+        logger.error(batch_inputs)
         batch_tokenized_inputs = tokenizer(
             batch_inputs, truncation=True, max_length=max_truncation
         )["input_ids"]
+        logger.error(batch_tokenized_inputs)
         return batch_tokenized_inputs

     @classmethod
@classmethod @classmethod
@@ -179,7 +182,7 @@ class FlashCausalLMBatch(Batch):
         max_blocks = 0
         block_tables = []
-        flat_slots = []
+        flat_blocks = []

         # Parse batch
         for i, (r, tokenized_input) in enumerate(
@@ -231,24 +234,18 @@ class FlashCausalLMBatch(Batch):
                 request_blocks = [
                     b for b in range(num_blocks, num_blocks + needed_blocks)
                 ]
-                request_slots = [
-                    s
-                    for b in request_blocks
-                    for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
-                ]
             else:
                 request_blocks = r.blocks
-                request_slots = r.slots

             block_tables.append(request_blocks)
             num_blocks += len(request_blocks)

             request_slot_indices = torch.arange(
-                len(flat_slots),
-                len(flat_slots) + input_length,
+                len(flat_blocks) * BLOCK_SIZE,
+                (len(flat_blocks) * BLOCK_SIZE) + input_length,
                 dtype=torch.int64,
             )

-            flat_slots.extend(request_slots)
+            flat_blocks.extend(request_blocks)
             slot_indices.append(request_slot_indices)

             # Create tensor to slice into the kv tensor in prefill
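Since slots are now implicitly block-major, a request's first slot index is just `len(flat_blocks) * BLOCK_SIZE` at the moment it is appended, replacing the old `len(flat_slots)` bookkeeping. A small sketch of that accounting (the block ids and lengths are made up for illustration):

    BLOCK_SIZE = 16  # illustrative

    flat_blocks, slot_indices = [], []
    for request_blocks, input_length in [([0, 1], 20), ([4], 3)]:
        # The request's tokens start where its first block lands in the
        # flattened, block-major slot layout.
        start = len(flat_blocks) * BLOCK_SIZE
        slot_indices.append(list(range(start, start + input_length)))
        flat_blocks.extend(request_blocks)

    assert slot_indices[0][0] == 0        # first request starts at slot 0
    assert slot_indices[1][0] == 2 * 16   # second starts after two blocks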
@@ -347,7 +344,13 @@ class FlashCausalLMBatch(Batch):
                 top_n_tokens, device=device, dtype=torch.int64
             )

-        slots = torch.tensor(flat_slots, dtype=torch.int64, device=device)
+        flat_blocks_tensor = torch.tensor(flat_blocks, dtype=torch.int64, device=device)
+        slots = (
+            (flat_blocks_tensor * BLOCK_SIZE).repeat(BLOCK_SIZE, 1).T
+            + torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int64)
+        ).flatten()

         block_tables_tensor = torch.zeros(
             (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
         )
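The tensor expression rebuilds the flat slot list on-device without a Python loop: after scaling the block ids by `BLOCK_SIZE`, `repeat(BLOCK_SIZE, 1).T` lays them out as an `(n_blocks, BLOCK_SIZE)` grid, and adding `arange(BLOCK_SIZE)` broadcasts the per-slot offsets across each row before flattening. A quick equivalence check against the old eagerly materialized list:

    import torch

    BLOCK_SIZE = 16
    flat_blocks = [3, 0, 7]

    flat_blocks_tensor = torch.tensor(flat_blocks, dtype=torch.int64)
    slots = (
        (flat_blocks_tensor * BLOCK_SIZE).repeat(BLOCK_SIZE, 1).T
        + torch.arange(0, BLOCK_SIZE, dtype=torch.int64)
    ).flatten()

    # Identical to expanding each block id into its BLOCK_SIZE slots in order.
    expected = [b * BLOCK_SIZE + i for b in flat_blocks for i in range(BLOCK_SIZE)]
    assert slots.tolist() == expected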
@@ -444,8 +447,8 @@ class FlashCausalLMBatch(Batch):
         max_seqlen = 0

         requests = []
+        flat_blocks = []
         block_tables = []
-        flat_slots = []
         all_input_ids = []
         input_lengths = []
@@ -483,16 +486,13 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens.append(self.top_n_tokens[idx])

             request_block_table = request.blocks
-            num_blocks += len(request_block_table)
             block_tables.append(request_block_table)
+            flat_blocks.extend(request_block_table)

-            # List of slots allocated for this request
-            request_slots = request.slots

             # Index
-            slot_indices.append(len(flat_slots) + request_input_length - 1)
-            flat_slots.extend(request_slots)
+            slot_indices.append((num_blocks * BLOCK_SIZE) + request_input_length - 1)

+            num_blocks += len(request_block_table)
             max_blocks = max(max_blocks, len(request_block_table))

             # Index into tensors
@@ -514,11 +514,16 @@ class FlashCausalLMBatch(Batch):
             block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)

         # Allocate on GPU
-        slots = torch.tensor(flat_slots, dtype=torch.int64, device=device)
         slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)

         # Move to GPU
         block_tables_tensor = block_tables_tensor.to(device)
+        flat_blocks_tensor = torch.tensor(flat_blocks, dtype=torch.int64, device=device)
+        slots = (
+            (flat_blocks_tensor * BLOCK_SIZE).repeat(BLOCK_SIZE, 1).T
+            + torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int64)
+        ).flatten()

         filtered_batch = type(self)(
             batch_id=self.batch_id,