remove slots from grpc
This commit is contained in:
parent
c2fb459bc1
commit
9ac7b7bc52
|
@ -156,7 +156,6 @@ async fn prefill(
|
||||||
}),
|
}),
|
||||||
top_n_tokens: top_n_tokens.unwrap_or(0),
|
top_n_tokens: top_n_tokens.unwrap_or(0),
|
||||||
blocks: vec![],
|
blocks: vec![],
|
||||||
slots: vec![],
|
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
|
|
@ -132,8 +132,6 @@ message Request {
|
||||||
uint32 top_n_tokens = 7;
|
uint32 top_n_tokens = 7;
|
||||||
/// Paged attention blocks
|
/// Paged attention blocks
|
||||||
repeated uint32 blocks = 9;
|
repeated uint32 blocks = 9;
|
||||||
/// Paged attention slots
|
|
||||||
repeated uint32 slots = 10;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
message Batch {
|
message Batch {
|
||||||
|
@ -208,8 +206,6 @@ message KeptRequest {
|
||||||
uint64 id = 1;
|
uint64 id = 1;
|
||||||
/// Paged attention blocks
|
/// Paged attention blocks
|
||||||
repeated uint32 blocks = 2;
|
repeated uint32 blocks = 2;
|
||||||
/// Paged attention slots
|
|
||||||
repeated uint32 slots = 3;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// kept_requests + terminated_request_ids might not cover all requests from the
|
/// kept_requests + terminated_request_ids might not cover all requests from the
|
||||||
|
|
|
@ -157,7 +157,6 @@ impl Client {
|
||||||
truncate,
|
truncate,
|
||||||
// Blocks and slots will be set on the server side if we use paged attention
|
// Blocks and slots will be set on the server side if we use paged attention
|
||||||
blocks: vec![],
|
blocks: vec![],
|
||||||
slots: vec![],
|
|
||||||
// Set sampling parameters to also take these ops into account in the max memory
|
// Set sampling parameters to also take these ops into account in the max memory
|
||||||
parameters: Some(NextTokenChooserParameters {
|
parameters: Some(NextTokenChooserParameters {
|
||||||
temperature: 0.9,
|
temperature: 0.9,
|
||||||
|
|
|
@ -250,7 +250,6 @@ impl Health for ShardedClient {
|
||||||
top_n_tokens: 0,
|
top_n_tokens: 0,
|
||||||
// Block 0 is reserved for health checks
|
// Block 0 is reserved for health checks
|
||||||
blocks: vec![0],
|
blocks: vec![0],
|
||||||
slots: (0..16).collect(),
|
|
||||||
};
|
};
|
||||||
let batch = Batch {
|
let batch = Batch {
|
||||||
id: u64::MAX,
|
id: u64::MAX,
|
||||||
|
|
|
@ -1,11 +1,12 @@
|
||||||
|
use std::cmp::min;
|
||||||
use std::fmt::Formatter;
|
use std::fmt::Formatter;
|
||||||
use std::sync::{Arc, Mutex, TryLockError};
|
use std::sync::{Arc, Mutex, TryLockError};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub(crate) struct BlockAllocation {
|
pub(crate) struct BlockAllocation {
|
||||||
|
block_size: usize,
|
||||||
allocated_blocks: Vec<u32>,
|
allocated_blocks: Vec<u32>,
|
||||||
allocated_slots: Vec<u32>,
|
|
||||||
required_blocks: usize,
|
required_blocks: usize,
|
||||||
required_slots: usize,
|
required_slots: usize,
|
||||||
block_allocator: BlockAllocator,
|
block_allocator: BlockAllocator,
|
||||||
|
@ -13,25 +14,20 @@ pub(crate) struct BlockAllocation {
|
||||||
|
|
||||||
impl BlockAllocation {
|
impl BlockAllocation {
|
||||||
pub(crate) fn len(&self) -> usize {
|
pub(crate) fn len(&self) -> usize {
|
||||||
self.allocated_slots.len()
|
self.allocated_blocks.len() * self.block_size
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn blocks(&self) -> &[u32] {
|
pub(crate) fn blocks(&self) -> &[u32] {
|
||||||
&self.allocated_blocks
|
&self.allocated_blocks
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn slots(&self) -> &[u32] {
|
|
||||||
&self.allocated_slots
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Extend an allocation by adding a new block
|
/// Extend an allocation by adding a new block
|
||||||
/// If the allocation length > window size, repeats blocks and slots to cover the
|
/// If the allocation length > window size, repeats blocks and slots to cover the
|
||||||
/// whole `required_blocks` and `required_slots`
|
/// whole `required_blocks` and `required_slots`
|
||||||
pub(crate) fn extend(&mut self) -> Result<(), AllocationError> {
|
pub(crate) fn extend(&mut self) -> Result<(), AllocationError> {
|
||||||
let (block, slots) = self.block_allocator.allocate_block()?;
|
let block = self.block_allocator.allocate_block()?;
|
||||||
// Add block and slots to current allocation
|
// Add block and slots to current allocation
|
||||||
self.allocated_blocks.push(block);
|
self.allocated_blocks.push(block);
|
||||||
self.allocated_slots.extend(slots);
|
|
||||||
|
|
||||||
if let Some(window_size) = self.block_allocator.window_size {
|
if let Some(window_size) = self.block_allocator.window_size {
|
||||||
// if we have more slots than the window size,
|
// if we have more slots than the window size,
|
||||||
|
@ -41,8 +37,6 @@ impl BlockAllocation {
|
||||||
let repeats = (self.required_slots + window_size - 1) / window_size;
|
let repeats = (self.required_slots + window_size - 1) / window_size;
|
||||||
self.allocated_blocks = self.allocated_blocks.repeat(repeats);
|
self.allocated_blocks = self.allocated_blocks.repeat(repeats);
|
||||||
self.allocated_blocks.truncate(self.required_blocks);
|
self.allocated_blocks.truncate(self.required_blocks);
|
||||||
self.allocated_slots = self.allocated_slots.repeat(repeats);
|
|
||||||
self.allocated_slots.truncate(self.required_slots);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -62,7 +56,6 @@ impl std::fmt::Debug for BlockAllocation {
|
||||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||||
f.debug_struct("BlockAllocation")
|
f.debug_struct("BlockAllocation")
|
||||||
.field("allocated_blocks", &self.allocated_blocks.len())
|
.field("allocated_blocks", &self.allocated_blocks.len())
|
||||||
.field("allocated_slots", &self.allocated_slots.len())
|
|
||||||
.field("required_blocks", &self.required_blocks)
|
.field("required_blocks", &self.required_blocks)
|
||||||
.field("required_slots", &self.required_slots)
|
.field("required_slots", &self.required_slots)
|
||||||
.field("block_allocator", &self.block_allocator)
|
.field("block_allocator", &self.block_allocator)
|
||||||
|
@ -94,30 +87,29 @@ impl BlockAllocator {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn allocate_block(&self) -> Result<(u32, Vec<u32>), AllocationError> {
|
fn allocate_block(&self) -> Result<u32, AllocationError> {
|
||||||
let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
|
let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
|
||||||
|
|
||||||
if free_blocks.is_empty() {
|
if free_blocks.is_empty() {
|
||||||
return Err(AllocationError::NotEnoughPages);
|
return Err(AllocationError::NotEnoughPages);
|
||||||
}
|
}
|
||||||
|
|
||||||
let block_id = free_blocks.pop().unwrap();
|
Ok(free_blocks.pop().unwrap())
|
||||||
let slots = ((block_id * self.block_size)..((block_id + 1) * self.block_size)).collect();
|
|
||||||
Ok((block_id, slots))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// For prompt tokens, we allocate enough blocks to cover all tokens
|
/// For prompt tokens, we allocate enough blocks to cover all tokens
|
||||||
/// For decode tokens, we allocate block by block
|
/// For decode tokens, we allocate min(decode_blocks, 16) blocks
|
||||||
///
|
///
|
||||||
/// If prompt tokens + min(decode_tokens, block_size) > window size, we repeat blocks and slots
|
/// If allocation > window size, we repeat blocks and slots
|
||||||
pub(crate) fn block_allocation(
|
pub(crate) fn block_allocation(
|
||||||
&self,
|
&self,
|
||||||
prompt_tokens: u32,
|
prompt_tokens: u32,
|
||||||
decode_tokens: u32,
|
decode_tokens: u32,
|
||||||
) -> Result<BlockAllocation, AllocationError> {
|
) -> Result<BlockAllocation, AllocationError> {
|
||||||
let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size;
|
let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size;
|
||||||
// prompt blocks + a single block for decode
|
// prompt blocks + 16 blocks for decode
|
||||||
let required_blocks = required_prompt_blocks + 1;
|
let decode_blocks = (decode_tokens + self.block_size - 1) / self.block_size;
|
||||||
|
let required_blocks = required_prompt_blocks + min(decode_blocks, 16);
|
||||||
let required_slots = required_blocks * self.block_size;
|
let required_slots = required_blocks * self.block_size;
|
||||||
|
|
||||||
// Slots and blocks required for the whole request
|
// Slots and blocks required for the whole request
|
||||||
|
@ -164,21 +156,9 @@ impl BlockAllocator {
|
||||||
allocated_blocks
|
allocated_blocks
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut allocated_slots =
|
|
||||||
Vec::with_capacity(allocated_blocks.len() * self.block_size as usize * repeats);
|
|
||||||
|
|
||||||
'slots: for block_id in allocated_blocks.iter() {
|
|
||||||
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
|
|
||||||
allocated_slots.push(s);
|
|
||||||
if allocated_slots.len() > total_slots {
|
|
||||||
break 'slots;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(BlockAllocation {
|
Ok(BlockAllocation {
|
||||||
|
block_size: self.block_size as usize,
|
||||||
allocated_blocks,
|
allocated_blocks,
|
||||||
allocated_slots,
|
|
||||||
required_blocks: total_required_blocks,
|
required_blocks: total_required_blocks,
|
||||||
required_slots: total_slots,
|
required_slots: total_slots,
|
||||||
block_allocator: self.clone(),
|
block_allocator: self.clone(),
|
||||||
|
|
|
@ -224,6 +224,11 @@ impl State {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if max_size == 0
|
||||||
|
if max_size == Some(0) {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
// Pad prefill_token_budget to be a multiple of block size
|
// Pad prefill_token_budget to be a multiple of block size
|
||||||
let prefill_token_budget =
|
let prefill_token_budget =
|
||||||
((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
|
((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
|
||||||
|
@ -312,14 +317,10 @@ impl State {
|
||||||
// Update entry
|
// Update entry
|
||||||
entry.temp_span = Some(entry_batch_span);
|
entry.temp_span = Some(entry_batch_span);
|
||||||
|
|
||||||
let (blocks, slots) = match &block_allocation {
|
let blocks = block_allocation
|
||||||
None => (Vec::new(), Vec::new()),
|
.as_ref()
|
||||||
Some(block_allocation) => (
|
.map(|block_allocation| block_allocation.blocks().to_vec())
|
||||||
block_allocation.blocks().to_vec(),
|
.unwrap_or_default();
|
||||||
block_allocation.slots().to_vec(),
|
|
||||||
),
|
|
||||||
};
|
|
||||||
|
|
||||||
entry.block_allocation = block_allocation;
|
entry.block_allocation = block_allocation;
|
||||||
|
|
||||||
batch_requests.push(Request {
|
batch_requests.push(Request {
|
||||||
|
@ -338,7 +339,6 @@ impl State {
|
||||||
)),
|
)),
|
||||||
top_n_tokens: entry.request.top_n_tokens,
|
top_n_tokens: entry.request.top_n_tokens,
|
||||||
blocks,
|
blocks,
|
||||||
slots,
|
|
||||||
});
|
});
|
||||||
// Set batch_time
|
// Set batch_time
|
||||||
entry.batch_time = Some(Instant::now());
|
entry.batch_time = Some(Instant::now());
|
||||||
|
|
|
@ -164,7 +164,7 @@ pub(crate) async fn batching_task(
|
||||||
};
|
};
|
||||||
|
|
||||||
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||||
let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
|
let max_size = max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
|
||||||
|
|
||||||
// Try to get a new batch
|
// Try to get a new batch
|
||||||
if let Some((mut new_entries, new_batch, span)) = queue
|
if let Some((mut new_entries, new_batch, span)) = queue
|
||||||
|
@ -382,16 +382,15 @@ async fn filter_batch(
|
||||||
let updated_requests = entries
|
let updated_requests = entries
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(request_id, entry)| {
|
.map(|(request_id, entry)| {
|
||||||
let (blocks, slots) = entry
|
let blocks = entry
|
||||||
.block_allocation
|
.block_allocation
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map(|alloc| (alloc.blocks().to_vec(), alloc.slots().to_vec()))
|
.map(|alloc| alloc.blocks().to_vec())
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
|
||||||
KeptRequest {
|
KeptRequest {
|
||||||
id: *request_id,
|
id: *request_id,
|
||||||
blocks,
|
blocks,
|
||||||
slots,
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
|
@ -133,9 +133,12 @@ class FlashCausalLMBatch(Batch):
|
||||||
batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
|
batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
|
||||||
max_truncation = max(max_truncation, r.truncate)
|
max_truncation = max(max_truncation, r.truncate)
|
||||||
|
|
||||||
|
logger.error(batch_inputs)
|
||||||
|
|
||||||
batch_tokenized_inputs = tokenizer(
|
batch_tokenized_inputs = tokenizer(
|
||||||
batch_inputs, truncation=True, max_length=max_truncation
|
batch_inputs, truncation=True, max_length=max_truncation
|
||||||
)["input_ids"]
|
)["input_ids"]
|
||||||
|
logger.error(batch_tokenized_inputs)
|
||||||
return batch_tokenized_inputs
|
return batch_tokenized_inputs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -179,7 +182,7 @@ class FlashCausalLMBatch(Batch):
|
||||||
max_blocks = 0
|
max_blocks = 0
|
||||||
|
|
||||||
block_tables = []
|
block_tables = []
|
||||||
flat_slots = []
|
flat_blocks = []
|
||||||
|
|
||||||
# Parse batch
|
# Parse batch
|
||||||
for i, (r, tokenized_input) in enumerate(
|
for i, (r, tokenized_input) in enumerate(
|
||||||
|
@ -231,24 +234,18 @@ class FlashCausalLMBatch(Batch):
|
||||||
request_blocks = [
|
request_blocks = [
|
||||||
b for b in range(num_blocks, num_blocks + needed_blocks)
|
b for b in range(num_blocks, num_blocks + needed_blocks)
|
||||||
]
|
]
|
||||||
request_slots = [
|
|
||||||
s
|
|
||||||
for b in request_blocks
|
|
||||||
for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
|
|
||||||
]
|
|
||||||
else:
|
else:
|
||||||
request_blocks = r.blocks
|
request_blocks = r.blocks
|
||||||
request_slots = r.slots
|
|
||||||
|
|
||||||
block_tables.append(request_blocks)
|
block_tables.append(request_blocks)
|
||||||
num_blocks += len(request_blocks)
|
num_blocks += len(request_blocks)
|
||||||
|
|
||||||
request_slot_indices = torch.arange(
|
request_slot_indices = torch.arange(
|
||||||
len(flat_slots),
|
len(flat_blocks) * BLOCK_SIZE,
|
||||||
len(flat_slots) + input_length,
|
(len(flat_blocks) * BLOCK_SIZE) + input_length,
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
)
|
)
|
||||||
flat_slots.extend(request_slots)
|
flat_blocks.extend(request_blocks)
|
||||||
slot_indices.append(request_slot_indices)
|
slot_indices.append(request_slot_indices)
|
||||||
|
|
||||||
# Create tensor to slice into the kv tensor in prefill
|
# Create tensor to slice into the kv tensor in prefill
|
||||||
|
@ -347,7 +344,13 @@ class FlashCausalLMBatch(Batch):
|
||||||
top_n_tokens, device=device, dtype=torch.int64
|
top_n_tokens, device=device, dtype=torch.int64
|
||||||
)
|
)
|
||||||
|
|
||||||
slots = torch.tensor(flat_slots, dtype=torch.int64, device=device)
|
flat_blocks_tensor = torch.tensor(flat_blocks, dtype=torch.int64, device=device)
|
||||||
|
|
||||||
|
slots = (
|
||||||
|
(flat_blocks_tensor * BLOCK_SIZE).repeat(BLOCK_SIZE, 1).T
|
||||||
|
+ torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int64)
|
||||||
|
).flatten()
|
||||||
|
|
||||||
block_tables_tensor = torch.zeros(
|
block_tables_tensor = torch.zeros(
|
||||||
(len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
|
(len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
|
||||||
)
|
)
|
||||||
|
@ -444,8 +447,8 @@ class FlashCausalLMBatch(Batch):
|
||||||
max_seqlen = 0
|
max_seqlen = 0
|
||||||
|
|
||||||
requests = []
|
requests = []
|
||||||
|
flat_blocks = []
|
||||||
block_tables = []
|
block_tables = []
|
||||||
flat_slots = []
|
|
||||||
all_input_ids = []
|
all_input_ids = []
|
||||||
|
|
||||||
input_lengths = []
|
input_lengths = []
|
||||||
|
@ -483,16 +486,13 @@ class FlashCausalLMBatch(Batch):
|
||||||
top_n_tokens.append(self.top_n_tokens[idx])
|
top_n_tokens.append(self.top_n_tokens[idx])
|
||||||
|
|
||||||
request_block_table = request.blocks
|
request_block_table = request.blocks
|
||||||
num_blocks += len(request_block_table)
|
|
||||||
block_tables.append(request_block_table)
|
block_tables.append(request_block_table)
|
||||||
|
flat_blocks.extend(request_block_table)
|
||||||
# List of slots allocated for this request
|
|
||||||
request_slots = request.slots
|
|
||||||
|
|
||||||
# Index
|
# Index
|
||||||
slot_indices.append(len(flat_slots) + request_input_length - 1)
|
slot_indices.append((num_blocks * BLOCK_SIZE) + request_input_length - 1)
|
||||||
flat_slots.extend(request_slots)
|
|
||||||
|
|
||||||
|
num_blocks += len(request_block_table)
|
||||||
max_blocks = max(max_blocks, len(request_block_table))
|
max_blocks = max(max_blocks, len(request_block_table))
|
||||||
|
|
||||||
# Index into tensors
|
# Index into tensors
|
||||||
|
@ -514,11 +514,16 @@ class FlashCausalLMBatch(Batch):
|
||||||
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
|
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
|
||||||
|
|
||||||
# Allocate on GPU
|
# Allocate on GPU
|
||||||
slots = torch.tensor(flat_slots, dtype=torch.int64, device=device)
|
|
||||||
slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)
|
slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)
|
||||||
|
|
||||||
# Move to GPU
|
# Move to GPU
|
||||||
block_tables_tensor = block_tables_tensor.to(device)
|
block_tables_tensor = block_tables_tensor.to(device)
|
||||||
|
flat_blocks_tensor = torch.tensor(flat_blocks, dtype=torch.int64, device=device)
|
||||||
|
|
||||||
|
slots = (
|
||||||
|
(flat_blocks_tensor * BLOCK_SIZE).repeat(BLOCK_SIZE, 1).T
|
||||||
|
+ torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int64)
|
||||||
|
).flatten()
|
||||||
|
|
||||||
filtered_batch = type(self)(
|
filtered_batch = type(self)(
|
||||||
batch_id=self.batch_id,
|
batch_id=self.batch_id,
|
||||||
|
|
Loading…
Reference in New Issue