diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs
index ec6fd318..f4b0f774 100644
--- a/router/src/infer/v3/block_allocator.rs
+++ b/router/src/infer/v3/block_allocator.rs
@@ -2,6 +2,7 @@ use std::{
     cmp::min,
     collections::{hash_map::Entry, BTreeSet, HashMap},
     hash::{DefaultHasher, Hash, Hasher},
+    sync::Arc,
 };
 
 use tokio::sync::{mpsc, oneshot};
@@ -9,12 +10,14 @@ use tokio::sync::{mpsc, oneshot};
 pub(crate) struct BlockAllocation {
     pub blocks: Vec<u32>,
     pub slots: Vec<u32>,
+    pub allocation_id: u64,
     block_allocator: BlockAllocator,
 }
 
 impl Drop for BlockAllocation {
     fn drop(&mut self) {
-        self.block_allocator.free(self.blocks.clone())
+        self.block_allocator
+            .free(self.blocks.clone(), self.allocation_id)
     }
 }
 
@@ -46,11 +49,16 @@ impl BlockAllocator {
         }
     }
 
-    pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
+    pub(crate) async fn allocate(
+        &self,
+        tokens: u32,
+        prefill_tokens: Option<Arc<Vec<u32>>>,
+    ) -> Option<BlockAllocation> {
         let (response_sender, response_receiver) = oneshot::channel();
         self.block_allocator
             .send(BlockAllocatorCommand::Allocate {
                 tokens,
+                prefill_tokens,
                 response_sender,
             })
             .unwrap();
@@ -58,16 +66,20 @@ impl BlockAllocator {
         response_receiver
             .await
             .unwrap()
-            .map(|(blocks, slots)| BlockAllocation {
+            .map(|(blocks, slots, allocation_id)| BlockAllocation {
                 blocks,
                 slots,
+                allocation_id,
                 block_allocator: self.clone(),
             })
     }
 
-    pub(crate) fn free(&self, blocks: Vec<u32>) {
+    pub(crate) fn free(&self, blocks: Vec<u32>, allocation_id: u64) {
         self.block_allocator
-            .send(BlockAllocatorCommand::Free { blocks })
+            .send(BlockAllocatorCommand::Free {
+                allocation_id,
+                blocks,
+            })
             .unwrap();
     }
 }
@@ -81,21 +93,32 @@ async fn block_allocator_task(
     let mut allocator = SimpleAllocator::new(blocks, block_size, window_size);
     while let Some(cmd) = receiver.recv().await {
         match cmd {
-            BlockAllocatorCommand::Free { blocks } => allocator.free(blocks),
+            BlockAllocatorCommand::Free {
+                blocks,
+                allocation_id,
+            } => allocator.free(blocks, allocation_id),
             BlockAllocatorCommand::Allocate {
                 tokens,
+                prefill_tokens,
                 response_sender,
             } => {
-                response_sender.send(allocator.allocate(tokens)).unwrap();
+                let prefill_tokens_slice = prefill_tokens.as_ref().map(|p| p.as_slice());
+                response_sender
+                    .send(allocator.allocate(tokens, prefill_tokens_slice))
+                    .unwrap();
             }
         }
     }
 }
 
 pub trait Allocator {
-    fn allocate(&mut self, tokens: u32) -> Option<(Vec<u32>, Vec<u32>)>;
+    fn allocate(
+        &mut self,
+        tokens: u32,
+        prefill_tokens: Option<&[u32]>,
+    ) -> Option<(Vec<u32>, Vec<u32>, u64)>;
 
-    fn free(&mut self, blocks: Vec<u32>);
+    fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
 }
 
 pub struct SimpleAllocator {
@@ -116,7 +139,11 @@ impl SimpleAllocator {
 }
 
 impl Allocator for SimpleAllocator {
-    fn allocate(&mut self, tokens: u32) -> Option<(Vec<u32>, Vec<u32>)> {
+    fn allocate(
+        &mut self,
+        tokens: u32,
+        _prefill_tokens: Option<&[u32]>,
+    ) -> Option<(Vec<u32>, Vec<u32>, u64)> {
         // Apply window size
         let (required_blocks, repeats) = {
             let (tokens, repeats) = match self.window_size {
@@ -150,11 +177,11 @@ impl Allocator for SimpleAllocator {
                     }
                 }
             }
-            Some((blocks, slots))
+            Some((blocks, slots, 0))
         }
     }
 
-    fn free(&mut self, blocks: Vec<u32>) {
+    fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {
         self.free_blocks.extend(blocks)
     }
 }
@@ -163,20 +190,15 @@ impl Allocator for SimpleAllocator {
 enum BlockAllocatorCommand {
     Free {
         blocks: Vec<u32>,
+        allocation_id: u64,
     },
     Allocate {
         tokens: u32,
-        response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
+        prefill_tokens: Option<Arc<Vec<u32>>>,
+        response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>, u64)>>,
     },
 }
 
-#[derive(Debug, Clone, Eq, PartialEq)]
-pub struct BlockAllocationWithCache {
-    pub blocks: Vec<u32>,
-    pub slots: Vec<u32>,
-    pub allocation: u64,
-}
-
 #[derive(Debug)]
 struct Allocation {
     cache_prefixes: Vec<u64>,
@@ -235,87 +257,6 @@ impl PrefixCacheAllocator {
         }
     }
 
-    pub fn alloc(
-        &mut self,
-        n_tokens: usize,
-        prefill_tokens: &[u32],
-    ) -> Option<BlockAllocationWithCache> {
-        let mut hasher = DefaultHasher::new();
-        let mut prefix_cache_blocks = Vec::new();
-
-        // Find hashes for all block_sized prefill chunks.
-        let mut prefix_hashes = Vec::new();
-        for prefill_chunk in prefill_tokens.chunks(self.block_size) {
-            if prefill_chunk.len() < self.block_size {
-                break;
-            }
-
-            prefill_chunk.hash(&mut hasher);
-            prefix_hashes.push(hasher.finish());
-        }
-
-        let mut n_from_cache = 0;
-        for prefix_hash in prefix_hashes.iter() {
-            let block_id = match self.cache_blocks.get(prefix_hash) {
-                Some(state) => state.block_id,
-                None => break,
-            };
-
-            // We have to acquire the prefixes blocks, even if the allocation fails
-            // later, otherwise the allocation below could garbage collect the
-            // prefix blocks.
-            self.incref_prefix(*prefix_hash);
-            prefix_cache_blocks.push(block_id);
-            n_from_cache += 1;
-        }
-
-        let new_prefixes = prefix_hashes.split_off(n_from_cache);
-        let cache_prefixes = prefix_hashes;
-
-        // Get tokens for the remaining prefill and decode.
-        let blocks = match self.alloc_or_reclaim(n_tokens - (n_from_cache * self.block_size)) {
-            Some(blocks) => blocks,
-            None => {
-                // If the allocation fails, we have relinquish our use of the
-                // prefix cache blocks. Maybe we can do this using `Drop`?
-                for prefix_hash in cache_prefixes {
-                    self.decref_prefix(prefix_hash);
-                }
-
-                return None;
-            }
-        };
-
-        prefix_cache_blocks.extend(blocks);
-
-        let mut slots = Vec::with_capacity(n_tokens);
-        for block_id in prefix_cache_blocks.iter() {
-            for s in
-                (*block_id * self.block_size as u32)..((*block_id + 1) * self.block_size as u32)
-            {
-                slots.push(s);
-                if slots.len() == n_tokens {
-                    break;
-                }
-            }
-        }
-
-        let allocation = Allocation {
-            cache_prefixes,
-            new_prefixes,
-        };
-
-        let allocation_id = self.time;
-        self.time += 1;
-        self.allocations.insert(allocation_id, allocation);
-
-        Some(BlockAllocationWithCache {
-            blocks: prefix_cache_blocks,
-            slots,
-            allocation: allocation_id,
-        })
-    }
-
     fn free_prefix_block(&mut self, prefix_hash: u64) {
         let state = self
             .cache_blocks
@@ -368,8 +309,92 @@ impl PrefixCacheAllocator {
                 .split_off(self.free_blocks.len() - n_blocks_needed),
         )
     }
+}
 
-    pub fn free(&mut self, blocks: &[u32], allocation: u64) {
+impl Allocator for PrefixCacheAllocator {
+    fn allocate(
+        &mut self,
+        n_tokens: u32,
+        prefill_tokens: Option<&[u32]>,
+    ) -> Option<(Vec<u32>, Vec<u32>, u64)> {
+        let mut hasher = DefaultHasher::new();
+        let mut prefix_cache_blocks = Vec::new();
+
+        // Find hashes for all block_sized prefill chunks.
+        let mut prefix_hashes = Vec::new();
+        let mut n_from_cache = 0;
+
+        // If we don't have a fast tokenizer, can't do prefix caching.
+        if let Some(prefill_tokens) = prefill_tokens {
+            for prefill_chunk in prefill_tokens.chunks(self.block_size) {
+                if prefill_chunk.len() < self.block_size {
+                    break;
+                }
+
+                prefill_chunk.hash(&mut hasher);
+                prefix_hashes.push(hasher.finish());
+            }
+
+            for prefix_hash in prefix_hashes.iter() {
+                let block_id = match self.cache_blocks.get(prefix_hash) {
+                    Some(state) => state.block_id,
+                    None => break,
+                };
+
+                // We have to acquire the prefixes blocks, even if the allocation fails
+                // later, otherwise the allocation below could garbage collect the
+                // prefix blocks.
+                self.incref_prefix(*prefix_hash);
+                prefix_cache_blocks.push(block_id);
+                n_from_cache += 1;
+            }
+        }
+
+        let new_prefixes = prefix_hashes.split_off(n_from_cache);
+        let cache_prefixes = prefix_hashes;
+
+        // Get tokens for the remaining prefill and decode.
+        let blocks =
+            match self.alloc_or_reclaim(n_tokens as usize - (n_from_cache * self.block_size)) {
+                Some(blocks) => blocks,
+                None => {
+                    // If the allocation fails, we have relinquish our use of the
+                    // prefix cache blocks. Maybe we can do this using `Drop`?
+                    for prefix_hash in cache_prefixes {
+                        self.decref_prefix(prefix_hash);
+                    }
+
+                    return None;
+                }
+            };
+
+        prefix_cache_blocks.extend(blocks);
+
+        let mut slots = Vec::with_capacity(n_tokens as usize);
+        for block_id in prefix_cache_blocks.iter() {
+            for s in
+                (*block_id * self.block_size as u32)..((*block_id + 1) * self.block_size as u32)
+            {
+                slots.push(s);
+                if slots.len() == n_tokens as usize {
+                    break;
+                }
+            }
+        }
+
+        let allocation = Allocation {
+            cache_prefixes,
+            new_prefixes,
+        };
+
+        let allocation_id = self.time;
+        self.time += 1;
+        self.allocations.insert(allocation_id, allocation);
+
+        Some((prefix_cache_blocks, slots, allocation_id))
+    }
+
+    fn free(&mut self, blocks: Vec<u32>, allocation: u64) {
         let allocation = match self.allocations.remove(&allocation) {
             Some(allocation) => allocation,
             None => unreachable!("Tried to free an unknown allocation."),
@@ -433,8 +458,6 @@ impl PrefixCacheAllocator {
 
 #[cfg(test)]
 mod tests {
-    use crate::infer::v3::block_allocator::BlockAllocationWithCache;
-
     use super::PrefixCacheAllocator;
 
     #[test]
diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs
index 894d9cab..10877504 100644
--- a/router/src/infer/v3/queue.rs
+++ b/router/src/infer/v3/queue.rs
@@ -298,7 +298,10 @@ impl State {
                         + self.speculate
                         - 1;
 
-                    match block_allocator.allocate(tokens).await {
+                    match block_allocator
+                        .allocate(tokens, entry.request.input_ids.clone())
+                        .await
+                    {
                         None => {
                             // Entry is over budget
                             // Add it back to the front
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 07ad14c9..d942dd8f 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -11,6 +11,7 @@ use rand::{thread_rng, Rng};
 use serde_json::Value;
 use std::io::Cursor;
 use std::iter;
+use std::sync::Arc;
 use text_generation_client::{Chunk, Image, InputChunk};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
@@ -122,7 +123,7 @@ impl Validation {
         inputs: String,
         truncate: Option<usize>,
         max_new_tokens: Option<u32>,
-    ) -> Result<(Vec<InputChunk>, usize, u32), ValidationError> {
+    ) -> Result<(Vec<InputChunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
         // If we have a fast tokenizer
         if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
             // Create response channel
@@ -157,8 +158,10 @@ impl Validation {
                 ));
             }
 
+            let input_ids = encoding.get_ids()[..input_length].to_owned();
+
             metrics::histogram!("tgi_request_input_length").record(input_length as f64);
-            Ok((inputs, input_length, max_new_tokens))
+            Ok((inputs, Some(input_ids), input_length, max_new_tokens))
         }
         // Return inputs without validation
         else {
@@ -183,6 +186,7 @@ impl Validation {
 
             Ok((
                 vec![Chunk::Text(inputs).into()],
+                None,
                 input_length,
                 max_new_tokens,
             ))
@@ -319,7 +323,7 @@ impl Validation {
             .unwrap_or(Ok(None))?;
 
         // Validate inputs
-        let (inputs, input_length, max_new_tokens) = self
+        let (inputs, input_ids, input_length, max_new_tokens) = self
             .validate_input(request.inputs, truncate, max_new_tokens)
             .await?;
 
@@ -388,6 +392,7 @@ impl Validation {
 
         Ok(ValidGenerateRequest {
             inputs,
+            input_ids: input_ids.map(Arc::new),
             decoder_input_details,
             input_length: input_length as u32,
             truncate: truncate.unwrap_or(self.max_input_length) as u32,
@@ -671,6 +676,7 @@ pub(crate) struct ValidStoppingParameters {
 #[derive(Debug, Clone)]
 pub(crate) struct ValidGenerateRequest {
     pub inputs: Vec<InputChunk>,
+    pub input_ids: Option<Arc<Vec<u32>>>,
     pub input_length: u32,
     pub truncate: u32,
     pub decoder_input_details: bool,
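
Note for reviewers (not part of the patch): the core of the new `PrefixCacheAllocator::allocate` is the chunk-hashing loop. The standalone sketch below, with an assumed `block_size` of 4 and a hypothetical `prefix_hashes` helper, illustrates the property the cache relies on: the same hasher is fed every full block in order, so each hash identifies the whole prefix up to that block, and a lookup can therefore stop at the first miss (which is why `allocate` breaks out of its cache-lookup loop on the first unknown hash).

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Hypothetical helper mirroring the hashing loop in the diff: the hasher is
// reused across chunks, so each emitted hash covers the entire prefix up to
// and including that block, not just the block itself.
fn prefix_hashes(prefill_tokens: &[u32], block_size: usize) -> Vec<u64> {
    let mut hasher = DefaultHasher::new();
    let mut hashes = Vec::new();
    for chunk in prefill_tokens.chunks(block_size) {
        if chunk.len() < block_size {
            // A partial trailing block is never cached.
            break;
        }
        chunk.hash(&mut hasher);
        hashes.push(hasher.finish());
    }
    hashes
}

fn main() {
    // Two prompts that share their first block produce the same leading hash,
    // which is what lets the allocator reuse that block's KV cache; the hashes
    // diverge from the first differing block onward.
    let a = prefix_hashes(&[1, 2, 3, 4, 5, 6, 7, 8], 4);
    let b = prefix_hashes(&[1, 2, 3, 4, 9, 9, 9, 9], 4);
    assert_eq!(a[0], b[0]);
    assert_ne!(a[1], b[1]);
}
```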