Sync allocator interfaces

Daniël de Kok 2024-07-16 14:42:32 +00:00
parent 48b21eab7a
commit 27ef5aa029
3 changed files with 140 additions and 108 deletions

View File

@@ -2,6 +2,7 @@ use std::{
     cmp::min,
     collections::{hash_map::Entry, BTreeSet, HashMap},
     hash::{DefaultHasher, Hash, Hasher},
+    sync::Arc,
 };

 use tokio::sync::{mpsc, oneshot};
@@ -9,12 +10,14 @@ use tokio::sync::{mpsc, oneshot};
 pub(crate) struct BlockAllocation {
     pub blocks: Vec<u32>,
     pub slots: Vec<u32>,
+    pub allocation_id: u64,
     block_allocator: BlockAllocator,
 }

 impl Drop for BlockAllocation {
     fn drop(&mut self) {
-        self.block_allocator.free(self.blocks.clone())
+        self.block_allocator
+            .free(self.blocks.clone(), self.allocation_id)
     }
 }
@@ -46,11 +49,16 @@ impl BlockAllocator {
         }
     }

-    pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
+    pub(crate) async fn allocate(
+        &self,
+        tokens: u32,
+        prefill_tokens: Option<Arc<Vec<u32>>>,
+    ) -> Option<BlockAllocation> {
         let (response_sender, response_receiver) = oneshot::channel();
         self.block_allocator
             .send(BlockAllocatorCommand::Allocate {
                 tokens,
+                prefill_tokens,
                 response_sender,
             })
             .unwrap();
@@ -58,16 +66,20 @@ impl BlockAllocator {
         response_receiver
             .await
             .unwrap()
-            .map(|(blocks, slots)| BlockAllocation {
+            .map(|(blocks, slots, allocation_id)| BlockAllocation {
                 blocks,
                 slots,
+                allocation_id,
                 block_allocator: self.clone(),
             })
     }

-    pub(crate) fn free(&self, blocks: Vec<u32>) {
+    pub(crate) fn free(&self, blocks: Vec<u32>, allocation_id: u64) {
         self.block_allocator
-            .send(BlockAllocatorCommand::Free { blocks })
+            .send(BlockAllocatorCommand::Free {
+                allocation_id,
+                blocks,
+            })
             .unwrap();
     }
 }
@@ -81,21 +93,32 @@ async fn block_allocator_task(
     let mut allocator = SimpleAllocator::new(blocks, block_size, window_size);
     while let Some(cmd) = receiver.recv().await {
         match cmd {
-            BlockAllocatorCommand::Free { blocks } => allocator.free(blocks),
+            BlockAllocatorCommand::Free {
+                blocks,
+                allocation_id,
+            } => allocator.free(blocks, allocation_id),
             BlockAllocatorCommand::Allocate {
                 tokens,
+                prefill_tokens,
                 response_sender,
             } => {
-                response_sender.send(allocator.allocate(tokens)).unwrap();
+                let prefill_tokens_slice = prefill_tokens.as_ref().map(|p| p.as_slice());
+                response_sender
+                    .send(allocator.allocate(tokens, prefill_tokens_slice))
+                    .unwrap();
             }
         }
     }
 }

 pub trait Allocator {
-    fn allocate(&mut self, tokens: u32) -> Option<(Vec<u32>, Vec<u32>)>;
+    fn allocate(
+        &mut self,
+        tokens: u32,
+        prefill_tokens: Option<&[u32]>,
+    ) -> Option<(Vec<u32>, Vec<u32>, u64)>;

-    fn free(&mut self, blocks: Vec<u32>);
+    fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
 }

 pub struct SimpleAllocator {
@@ -116,7 +139,11 @@ impl SimpleAllocator {
 }

 impl Allocator for SimpleAllocator {
-    fn allocate(&mut self, tokens: u32) -> Option<(Vec<u32>, Vec<u32>)> {
+    fn allocate(
+        &mut self,
+        tokens: u32,
+        _prefill_tokens: Option<&[u32]>,
+    ) -> Option<(Vec<u32>, Vec<u32>, u64)> {
         // Apply window size
         let (required_blocks, repeats) = {
             let (tokens, repeats) = match self.window_size {
@@ -150,11 +177,11 @@ impl Allocator for SimpleAllocator {
                     }
                 }
             }
-            Some((blocks, slots))
+            Some((blocks, slots, 0))
         }
     }

-    fn free(&mut self, blocks: Vec<u32>) {
+    fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {
         self.free_blocks.extend(blocks)
     }
 }
@@ -163,20 +190,15 @@ impl Allocator for SimpleAllocator {
 enum BlockAllocatorCommand {
     Free {
         blocks: Vec<u32>,
+        allocation_id: u64,
     },
     Allocate {
         tokens: u32,
-        response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
+        prefill_tokens: Option<Arc<Vec<u32>>>,
+        response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>, u64)>>,
     },
 }

-#[derive(Debug, Clone, Eq, PartialEq)]
-pub struct BlockAllocationWithCache {
-    pub blocks: Vec<u32>,
-    pub slots: Vec<u32>,
-    pub allocation: u64,
-}
-
 #[derive(Debug)]
 struct Allocation {
     cache_prefixes: Vec<u64>,
@@ -235,87 +257,6 @@ impl PrefixCacheAllocator {
         }
     }

-    pub fn alloc(
-        &mut self,
-        n_tokens: usize,
-        prefill_tokens: &[u32],
-    ) -> Option<BlockAllocationWithCache> {
-        let mut hasher = DefaultHasher::new();
-        let mut prefix_cache_blocks = Vec::new();
-
-        // Find hashes for all block_sized prefill chunks.
-        let mut prefix_hashes = Vec::new();
-        for prefill_chunk in prefill_tokens.chunks(self.block_size) {
-            if prefill_chunk.len() < self.block_size {
-                break;
-            }
-            prefill_chunk.hash(&mut hasher);
-            prefix_hashes.push(hasher.finish());
-        }
-
-        let mut n_from_cache = 0;
-        for prefix_hash in prefix_hashes.iter() {
-            let block_id = match self.cache_blocks.get(prefix_hash) {
-                Some(state) => state.block_id,
-                None => break,
-            };
-
-            // We have to acquire the prefixes blocks, even if the allocation fails
-            // later, otherwise the allocation below could garbage collect the
-            // prefix blocks.
-            self.incref_prefix(*prefix_hash);
-            prefix_cache_blocks.push(block_id);
-            n_from_cache += 1;
-        }
-
-        let new_prefixes = prefix_hashes.split_off(n_from_cache);
-        let cache_prefixes = prefix_hashes;
-
-        // Get tokens for the remaining prefill and decode.
-        let blocks = match self.alloc_or_reclaim(n_tokens - (n_from_cache * self.block_size)) {
-            Some(blocks) => blocks,
-            None => {
-                // If the allocation fails, we have relinquish our use of the
-                // prefix cache blocks. Maybe we can do this using `Drop`?
-                for prefix_hash in cache_prefixes {
-                    self.decref_prefix(prefix_hash);
-                }
-                return None;
-            }
-        };
-
-        prefix_cache_blocks.extend(blocks);
-
-        let mut slots = Vec::with_capacity(n_tokens);
-        for block_id in prefix_cache_blocks.iter() {
-            for s in
-                (*block_id * self.block_size as u32)..((*block_id + 1) * self.block_size as u32)
-            {
-                slots.push(s);
-                if slots.len() == n_tokens {
-                    break;
-                }
-            }
-        }
-
-        let allocation = Allocation {
-            cache_prefixes,
-            new_prefixes,
-        };
-
-        let allocation_id = self.time;
-        self.time += 1;
-        self.allocations.insert(allocation_id, allocation);
-
-        Some(BlockAllocationWithCache {
-            blocks: prefix_cache_blocks,
-            slots,
-            allocation: allocation_id,
-        })
-    }
-
     fn free_prefix_block(&mut self, prefix_hash: u64) {
         let state = self
             .cache_blocks
@@ -368,8 +309,92 @@ impl PrefixCacheAllocator {
                 .split_off(self.free_blocks.len() - n_blocks_needed),
         )
     }
+}

-    pub fn free(&mut self, blocks: &[u32], allocation: u64) {
+impl Allocator for PrefixCacheAllocator {
+    fn allocate(
+        &mut self,
+        n_tokens: u32,
+        prefill_tokens: Option<&[u32]>,
+    ) -> Option<(Vec<u32>, Vec<u32>, u64)> {
+        let mut hasher = DefaultHasher::new();
+        let mut prefix_cache_blocks = Vec::new();
+
+        // Find hashes for all block_sized prefill chunks.
+        let mut prefix_hashes = Vec::new();
+        let mut n_from_cache = 0;
+
+        // If we don't have a fast tokenizer, can't do prefix caching.
+        if let Some(prefill_tokens) = prefill_tokens {
+            for prefill_chunk in prefill_tokens.chunks(self.block_size) {
+                if prefill_chunk.len() < self.block_size {
+                    break;
+                }
+                prefill_chunk.hash(&mut hasher);
+                prefix_hashes.push(hasher.finish());
+            }
+
+            for prefix_hash in prefix_hashes.iter() {
+                let block_id = match self.cache_blocks.get(prefix_hash) {
+                    Some(state) => state.block_id,
+                    None => break,
+                };
+
+                // We have to acquire the prefixes blocks, even if the allocation fails
+                // later, otherwise the allocation below could garbage collect the
+                // prefix blocks.
+                self.incref_prefix(*prefix_hash);
+                prefix_cache_blocks.push(block_id);
+                n_from_cache += 1;
+            }
+        }
+
+        let new_prefixes = prefix_hashes.split_off(n_from_cache);
+        let cache_prefixes = prefix_hashes;
+
+        // Get tokens for the remaining prefill and decode.
+        let blocks =
+            match self.alloc_or_reclaim(n_tokens as usize - (n_from_cache * self.block_size)) {
+                Some(blocks) => blocks,
+                None => {
+                    // If the allocation fails, we have relinquish our use of the
+                    // prefix cache blocks. Maybe we can do this using `Drop`?
+                    for prefix_hash in cache_prefixes {
+                        self.decref_prefix(prefix_hash);
+                    }
+                    return None;
+                }
+            };
+
+        prefix_cache_blocks.extend(blocks);
+
+        let mut slots = Vec::with_capacity(n_tokens as usize);
+        for block_id in prefix_cache_blocks.iter() {
+            for s in
+                (*block_id * self.block_size as u32)..((*block_id + 1) * self.block_size as u32)
+            {
+                slots.push(s);
+                if slots.len() == n_tokens as usize {
+                    break;
+                }
+            }
+        }
+
+        let allocation = Allocation {
+            cache_prefixes,
+            new_prefixes,
+        };
+
+        let allocation_id = self.time;
+        self.time += 1;
+        self.allocations.insert(allocation_id, allocation);
+
+        Some((prefix_cache_blocks, slots, allocation_id))
+    }
+
+    fn free(&mut self, blocks: Vec<u32>, allocation: u64) {
         let allocation = match self.allocations.remove(&allocation) {
             Some(allocation) => allocation,
             None => unreachable!("Tried to free an unknown allocation."),
@@ -433,8 +458,6 @@ impl PrefixCacheAllocator {

 #[cfg(test)]
 mod tests {
-    use crate::infer::v3::block_allocator::BlockAllocationWithCache;
-
     use super::PrefixCacheAllocator;

     #[test]

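The core of the new `PrefixCacheAllocator::allocate` above is the block-sized chunk hashing: one `DefaultHasher` is fed chunk after chunk without ever being reset, so each chunk's hash also covers every token that came before it, and a trailing partial chunk is never hashed because it cannot be shared. The standalone sketch below is not part of the commit; `prefix_hashes` is an invented helper that mirrors that loop under the same assumption of block-aligned sharing.

// Standalone sketch (not from the commit) of the prefix-chunk hashing scheme.
use std::hash::{DefaultHasher, Hash, Hasher};

fn prefix_hashes(input_ids: &[u32], block_size: usize) -> Vec<u64> {
    let mut hasher = DefaultHasher::new();
    let mut hashes = Vec::new();
    for chunk in input_ids.chunks(block_size) {
        if chunk.len() < block_size {
            break; // a partial final block cannot be shared, so it is not hashed
        }
        // The hasher is cumulative: this hash commits to every earlier chunk too.
        chunk.hash(&mut hasher);
        hashes.push(hasher.finish());
    }
    hashes
}

fn main() {
    let ids: Vec<u32> = (0..10).collect();
    // block_size 4 -> two full chunks are hashed, the trailing 2 ids are ignored.
    assert_eq!(prefix_hashes(&ids, 4).len(), 2);
}

Because the hasher is cumulative, a cache hit on chunk N implies the whole prefix up to N matched, which is what allows the allocator to reuse those cached blocks verbatim.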
View File

@@ -298,7 +298,10 @@ impl State {
                         + self.speculate
                         - 1;

-                    match block_allocator.allocate(tokens).await {
+                    match block_allocator
+                        .allocate(tokens, entry.request.input_ids.clone())
+                        .await
+                    {
                         None => {
                             // Entry is over budget
                             // Add it back to the front

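At this call site the queue now threads the request's prefill ids into the allocator on every scheduling attempt. They are stored as `Option<Arc<Vec<u32>>>`, so the `.clone()` above only bumps a reference count, and the allocator task later borrows a slice out of the `Arc` (the `prefill_tokens.as_ref().map(|p| p.as_slice())` line in the first file) before calling the synchronous allocator. A small sketch of those two steps, not taken from the commit:

use std::sync::Arc;

fn main() {
    // What validation now attaches to a request (None when there is no fast tokenizer).
    let input_ids: Option<Arc<Vec<u32>>> = Some(Arc::new(vec![101, 2023, 2003, 102]));

    // What the queue sends along with the allocation request: a refcount bump, not a copy.
    let for_allocator = input_ids.clone();
    assert_eq!(Arc::strong_count(input_ids.as_ref().unwrap()), 2);

    // What the allocator task hands to the synchronous allocator: a borrowed slice.
    let prefill_tokens: Option<&[u32]> = for_allocator.as_ref().map(|ids| ids.as_slice());
    assert_eq!(prefill_tokens.map(|ids| ids.len()), Some(4));
}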
View File

@@ -11,6 +11,7 @@ use rand::{thread_rng, Rng};
 use serde_json::Value;
 use std::io::Cursor;
 use std::iter;
+use std::sync::Arc;
 use text_generation_client::{Chunk, Image, InputChunk};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
@@ -122,7 +123,7 @@ impl Validation {
         inputs: String,
         truncate: Option<usize>,
         max_new_tokens: Option<u32>,
-    ) -> Result<(Vec<InputChunk>, usize, u32), ValidationError> {
+    ) -> Result<(Vec<InputChunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
         // If we have a fast tokenizer
         if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
             // Create response channel
@@ -157,8 +158,10 @@ impl Validation {
                 ));
             }

+            let input_ids = encoding.get_ids()[..input_length].to_owned();
+
             metrics::histogram!("tgi_request_input_length").record(input_length as f64);
-            Ok((inputs, input_length, max_new_tokens))
+            Ok((inputs, Some(input_ids), input_length, max_new_tokens))
         }
         // Return inputs without validation
         else {
@@ -183,6 +186,7 @@ impl Validation {

             Ok((
                 vec![Chunk::Text(inputs).into()],
+                None,
                 input_length,
                 max_new_tokens,
             ))
@@ -319,7 +323,7 @@ impl Validation {
             .unwrap_or(Ok(None))?;

         // Validate inputs
-        let (inputs, input_length, max_new_tokens) = self
+        let (inputs, input_ids, input_length, max_new_tokens) = self
             .validate_input(request.inputs, truncate, max_new_tokens)
             .await?;
@@ -388,6 +392,7 @@ impl Validation {

         Ok(ValidGenerateRequest {
             inputs,
+            input_ids: input_ids.map(Arc::new),
             decoder_input_details,
             input_length: input_length as u32,
             truncate: truncate.unwrap_or(self.max_input_length) as u32,
@@ -671,6 +676,7 @@ pub(crate) struct ValidStoppingParameters {
 #[derive(Debug, Clone)]
 pub(crate) struct ValidGenerateRequest {
     pub inputs: Vec<InputChunk>,
+    pub input_ids: Option<Arc<Vec<u32>>>,
     pub input_length: u32,
     pub truncate: u32,
     pub decoder_input_details: bool,
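
Taken together, the three files thread optional prefill token ids from validation, through the queue, into whichever allocator backs the block allocator task, all behind the single `Allocator` trait. The toy below only illustrates that contract: the trait definition matches the diff, while `CountingAllocator` and its fields are invented for the example. `allocate` returns blocks, slots and an allocation id; `free` hands both the blocks and the id back (`SimpleAllocator` ignores the id, while the prefix-cache allocator uses it to look up which cached prefixes to release).

// Illustrative toy only; the trait mirrors the diff, the implementation is made up.
pub trait Allocator {
    fn allocate(
        &mut self,
        tokens: u32,
        prefill_tokens: Option<&[u32]>,
    ) -> Option<(Vec<u32>, Vec<u32>, u64)>;

    fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
}

struct CountingAllocator {
    block_size: u32,
    free_blocks: Vec<u32>,
    next_id: u64,
}

impl Allocator for CountingAllocator {
    fn allocate(
        &mut self,
        tokens: u32,
        _prefill_tokens: Option<&[u32]>,
    ) -> Option<(Vec<u32>, Vec<u32>, u64)> {
        // Round the token count up to whole blocks.
        let needed = ((tokens + self.block_size - 1) / self.block_size) as usize;
        if needed > self.free_blocks.len() {
            return None;
        }
        let blocks = self.free_blocks.split_off(self.free_blocks.len() - needed);
        // One slot per token, laid out contiguously inside each block.
        let block_size = self.block_size;
        let slots: Vec<u32> = blocks
            .iter()
            .flat_map(|b| (b * block_size)..((b + 1) * block_size))
            .take(tokens as usize)
            .collect();
        // Tag the allocation with a monotonically increasing id.
        let id = self.next_id;
        self.next_id += 1;
        Some((blocks, slots, id))
    }

    fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {
        self.free_blocks.extend(blocks);
    }
}

fn main() {
    let mut allocator = CountingAllocator {
        block_size: 4,
        free_blocks: (0..8).collect(),
        next_id: 0,
    };
    let prefill: Vec<u32> = vec![1, 2, 3];
    let (blocks, slots, id) = allocator.allocate(6, Some(prefill.as_slice())).unwrap();
    assert_eq!((blocks.len(), slots.len(), id), (2, 6, 0));
    allocator.free(blocks, id);
}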