From 7735b385dc21d896955d4af88fbadfc88a70abd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Fri, 9 Aug 2024 11:47:14 +0000 Subject: [PATCH] Prefix caching WIP --- Cargo.lock | 1 + backends/client/src/v3/client.rs | 1 + backends/client/src/v3/sharded_client.rs | 1 + backends/v3/Cargo.toml | 1 + backends/v3/src/backend.rs | 10 + backends/v3/src/block_allocator.rs | 182 +++-- backends/v3/src/client/grpc_client.rs | 1 + backends/v3/src/client/sharded_client.rs | 1 + backends/v3/src/lib.rs | 1 + backends/v3/src/queue.rs | 53 +- backends/v3/src/radix.rs | 755 ++++++++++++++++++ benchmark/src/generation.rs | 1 + proto/v3/generate.proto | 356 ++++----- router/src/validation.rs | 19 +- .../layers/attention/__init__.py | 7 +- .../layers/attention/cuda.py | 25 +- .../layers/attention/flash_infer.py | 76 ++ .../custom_modeling/flash_cohere_modeling.py | 2 + .../custom_modeling/flash_dbrx_modeling.py | 2 + .../flash_deepseek_v2_modeling.py | 2 + .../custom_modeling/flash_gemma2_modeling.py | 2 + .../custom_modeling/flash_gemma_modeling.py | 2 + .../custom_modeling/flash_gpt2_modeling.py | 2 + .../custom_modeling/flash_llama_modeling.py | 3 + .../custom_modeling/flash_mistral_modeling.py | 2 + .../custom_modeling/flash_mixtral_modeling.py | 2 + .../custom_modeling/flash_neox_modeling.py | 2 + .../custom_modeling/flash_phi_modeling.py | 2 + .../custom_modeling/flash_qwen2_modeling.py | 2 + .../custom_modeling/flash_rw_modeling.py | 4 + .../flash_santacoder_modeling.py | 2 + .../flash_starcoder2_modeling.py | 2 + .../models/flash_causal_lm.py | 218 ++++- .../text_generation_server/models/globals.py | 16 +- 34 files changed, 1451 insertions(+), 307 deletions(-) create mode 100644 backends/v3/src/radix.rs diff --git a/Cargo.lock b/Cargo.lock index 92367d1e..3a5845a7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4045,6 +4045,7 @@ dependencies = [ "reqwest", "serde", "serde_json", + "slotmap", "text-generation-router", "thiserror", "tokenizers", diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs index a996b14f..b321278c 100644 --- a/backends/client/src/v3/client.rs +++ b/backends/client/src/v3/client.rs @@ -156,6 +156,7 @@ impl Client { // Blocks and slots will be set on the server side if we use paged attention blocks: vec![], slots: vec![], + prefix_len: 0, // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs index ae8a899b..1cc173e3 100644 --- a/backends/client/src/v3/sharded_client.rs +++ b/backends/client/src/v3/sharded_client.rs @@ -244,6 +244,7 @@ impl Health for ShardedClient { // Block 0 is reserved for health checks blocks: vec![0], slots: (0..16).collect(), + prefix_len: 0, adapter_id: None, }; let batch = Batch { diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml index 5d9a140b..129ceb9c 100644 --- a/backends/v3/Cargo.toml +++ b/backends/v3/Cargo.toml @@ -33,6 +33,7 @@ rand = "0.8.5" reqwest = { version = "0.11.20", features = [] } serde = "1.0.188" serde_json = "1.0.107" +slotmap = "1.0.7" thiserror = "1.0.48" tokenizers = { workspace = true} tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index 68ddf00b..cbcbff72 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -35,15 +35,24 @@ impl BackendV3 { window_size: Option, speculate: u32, ) -> Self { + let prefix_caching = if let Ok(prefix_caching) = std::env::var("USE_PREFIX_CACHING") { + matches!(prefix_caching.as_str(), "true" | "1") + } else { + false + }; let attention = if let Ok(attention) = std::env::var("ATTENTION") { attention .parse() .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")) + } else if prefix_caching { + Attention::FlashInfer } else { Attention::Paged }; let block_size = if attention == Attention::FlashDecoding { 256 + } else if attention == Attention::FlashInfer { + 1 } else { 16 }; @@ -51,6 +60,7 @@ impl BackendV3 { let queue = Queue::new( requires_padding, block_size, + prefix_caching, window_size, speculate, max_batch_total_tokens, diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index 7467fd85..05c2bd30 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -1,16 +1,26 @@ -use std::cmp::min; +use std::{cmp::min, sync::Arc}; use tokio::sync::{mpsc, oneshot}; +use crate::radix::RadixAllocator; + #[derive(Debug, Clone)] pub(crate) struct BlockAllocation { + pub allocation_id: u64, pub blocks: Vec, pub slots: Vec, - block_allocator: BlockAllocator, + + /// Prefix that was cached and for which the KV does not have to + /// be recomputed. + pub prefix_len: u32, + + pub(crate) block_allocator: Option, } impl Drop for BlockAllocation { fn drop(&mut self) { - self.block_allocator.free(self.blocks.clone()) + if let Some(block_allocator) = self.block_allocator.as_mut() { + block_allocator.free(self.blocks.clone(), self.allocation_id) + } } } @@ -24,6 +34,7 @@ impl BlockAllocator { pub(crate) fn new( max_batch_total_tokens: u32, block_size: u32, + prefix_caching: bool, window_size: Option, ) -> Self { // Create channel @@ -33,6 +44,7 @@ impl BlockAllocator { tokio::spawn(block_allocator_task( max_batch_total_tokens / block_size, block_size, + prefix_caching, window_size, receiver, )); @@ -42,28 +54,32 @@ impl BlockAllocator { } } - pub(crate) async fn allocate(&self, tokens: u32) -> Option { + pub(crate) async fn allocate( + &self, + tokens: u32, + prefill_tokens: Option>>, + ) -> Option { let (response_sender, response_receiver) = oneshot::channel(); self.block_allocator .send(BlockAllocatorCommand::Allocate { tokens, + prefill_tokens, response_sender, }) .unwrap(); - response_receiver - .await - .unwrap() - .map(|(blocks, slots)| BlockAllocation { - blocks, - slots, - block_allocator: self.clone(), - }) + response_receiver.await.unwrap().map(|mut allocation| { + allocation.block_allocator = Some(self.clone()); + allocation + }) } - pub(crate) fn free(&self, blocks: Vec) { + pub(crate) fn free(&self, blocks: Vec, allocation_id: u64) { self.block_allocator - .send(BlockAllocatorCommand::Free { blocks }) + .send(BlockAllocatorCommand::Free { + allocation_id, + blocks, + }) .unwrap(); } } @@ -71,54 +87,29 @@ impl BlockAllocator { async fn block_allocator_task( blocks: u32, block_size: u32, + prefix_caching: bool, window_size: Option, mut receiver: mpsc::UnboundedReceiver, ) { - // Block 0 is reserved for health checks - let mut free_blocks: Vec = (1..blocks).collect(); + let mut allocator: Box = if prefix_caching { + Box::new(RadixAllocator::new(block_size, blocks, window_size)) + } else { + Box::new(SimpleAllocator::new(blocks, block_size, window_size)) + }; while let Some(cmd) = receiver.recv().await { match cmd { - BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks), + BlockAllocatorCommand::Free { + blocks, + allocation_id, + } => allocator.free(blocks, allocation_id), BlockAllocatorCommand::Allocate { tokens, + prefill_tokens, response_sender, } => { - // Apply window size - let (required_blocks, repeats) = { - let (tokens, repeats) = match window_size { - None => (tokens, 1), - Some(window_size) => { - let repeats = (tokens + window_size - 1) / window_size; - let tokens = min(tokens, window_size); - (tokens, repeats as usize) - } - }; - // Pad to a multiple of block size - let required_blocks = (tokens + block_size - 1) / block_size; - (required_blocks, repeats) - }; - - let tokens = tokens as usize; - let allocation = if required_blocks > free_blocks.len() as u32 { - None - } else { - let blocks = - free_blocks.split_off(free_blocks.len() - required_blocks as usize); - let mut slots = Vec::with_capacity( - (required_blocks * block_size * repeats as u32) as usize, - ); - - 'slots: for block_id in blocks.repeat(repeats).iter() { - for s in (block_id * block_size)..((block_id + 1) * block_size) { - slots.push(s); - if slots.len() == tokens { - break 'slots; - } - } - } - Some((blocks, slots)) - }; - response_sender.send(allocation).unwrap(); + response_sender + .send(allocator.allocate(tokens, prefill_tokens)) + .unwrap(); } } } @@ -128,9 +119,92 @@ async fn block_allocator_task( enum BlockAllocatorCommand { Free { blocks: Vec, + allocation_id: u64, }, Allocate { tokens: u32, - response_sender: oneshot::Sender, Vec)>>, + prefill_tokens: Option>>, + response_sender: oneshot::Sender>, }, } + +pub(crate) trait Allocator { + fn allocate( + &mut self, + tokens: u32, + prefill_tokens: Option>>, + ) -> Option; + + fn free(&mut self, blocks: Vec, allocation_id: u64); +} + +pub struct SimpleAllocator { + free_blocks: Vec, + block_size: u32, + window_size: Option, +} + +impl SimpleAllocator { + fn new(blocks: u32, block_size: u32, window_size: Option) -> Self { + SimpleAllocator { + block_size, + // Block 0 is reserved for health checks + free_blocks: (1..blocks).collect(), + window_size, + } + } +} + +impl Allocator for SimpleAllocator { + fn allocate( + &mut self, + tokens: u32, + _prefill_tokens: Option>>, + ) -> Option { + // Apply window size + let (required_blocks, repeats) = { + let (tokens, repeats) = match self.window_size { + None => (tokens, 1), + Some(window_size) => { + let repeats = (tokens + window_size - 1) / window_size; + let tokens = min(tokens, window_size); + (tokens, repeats as usize) + } + }; + // Pad to a multiple of block size + let required_blocks = (tokens + self.block_size - 1) / self.block_size; + (required_blocks, repeats) + }; + + let tokens = tokens as usize; + if required_blocks > self.free_blocks.len() as u32 { + None + } else { + let blocks = self + .free_blocks + .split_off(self.free_blocks.len() - required_blocks as usize); + let mut slots = + Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize); + + 'slots: for block_id in blocks.repeat(repeats).iter() { + for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { + slots.push(s); + if slots.len() == tokens { + break 'slots; + } + } + } + Some(BlockAllocation { + allocation_id: 0, + blocks, + slots, + prefix_len: 0, + block_allocator: None, + }) + } + } + + fn free(&mut self, blocks: Vec, _allocation_id: u64) { + self.free_blocks.extend(blocks) + } +} diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs index c407687b..6282759e 100644 --- a/backends/v3/src/client/grpc_client.rs +++ b/backends/v3/src/client/grpc_client.rs @@ -157,6 +157,7 @@ impl Client { // Blocks and slots will be set on the server side if we use paged attention blocks: vec![], slots: vec![], + prefix_len: 0, // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs index afb13cdc..2f78da03 100644 --- a/backends/v3/src/client/sharded_client.rs +++ b/backends/v3/src/client/sharded_client.rs @@ -245,6 +245,7 @@ impl Health for ShardedClient { // Block 0 is reserved for health checks blocks: vec![0], slots: (0..16).collect(), + prefix_len: 0, adapter_id: None, }; let batch = Batch { diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs index a6f89169..c8fc55f8 100644 --- a/backends/v3/src/lib.rs +++ b/backends/v3/src/lib.rs @@ -2,6 +2,7 @@ mod backend; mod block_allocator; mod client; mod queue; +mod radix; use crate::client::{ClientError, ShardedClient}; pub(crate) use backend::BackendV3; diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index b457389c..13544235 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -46,6 +46,7 @@ impl Queue { pub(crate) fn new( requires_padding: bool, block_size: u32, + prefix_caching: bool, window_size: Option, speculate: u32, max_batch_total_tokens: u32, @@ -57,6 +58,7 @@ impl Queue { tokio::spawn(queue_task( requires_padding, block_size, + prefix_caching, window_size, speculate, max_batch_total_tokens, @@ -109,6 +111,7 @@ impl Queue { async fn queue_task( requires_padding: bool, block_size: u32, + prefix_caching: bool, window_size: Option, speculate: u32, max_batch_total_tokens: u32, @@ -117,6 +120,7 @@ async fn queue_task( let mut state = State::new( requires_padding, block_size, + prefix_caching, window_size, speculate, max_batch_total_tokens, @@ -176,12 +180,19 @@ impl State { fn new( requires_padding: bool, block_size: u32, + prefix_caching: bool, window_size: Option, speculate: u32, max_batch_total_tokens: u32, ) -> Self { - let block_allocator = (!requires_padding) - .then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size)); + let block_allocator = (!requires_padding).then(|| { + BlockAllocator::new( + max_batch_total_tokens, + block_size, + prefix_caching, + window_size, + ) + }); Self { entries: VecDeque::with_capacity(128), @@ -305,7 +316,10 @@ impl State { + self.speculate - 1; - match block_allocator.allocate(tokens).await { + match block_allocator + .allocate(tokens, entry.request.input_ids.clone()) + .await + { None => { // Entry is over budget // Add it back to the front @@ -331,11 +345,12 @@ impl State { // Update entry entry.temp_span = Some(entry_batch_span); - let (blocks, slots) = match &block_allocation { - None => (Vec::new(), Vec::new()), + let (blocks, slots, prefix_len) = match &block_allocation { + None => (Vec::new(), Vec::new(), 0), Some(block_allocation) => ( block_allocation.blocks.clone(), block_allocation.slots.clone(), + block_allocation.prefix_len, ), }; @@ -372,6 +387,7 @@ impl State { top_n_tokens: entry.request.top_n_tokens, blocks, slots, + prefix_len, adapter_id: entry.request.adapter_id.clone(), }); // Set batch_time @@ -480,6 +496,8 @@ impl From for StoppingCriteriaParameters { #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use tracing::info_span; @@ -492,6 +510,7 @@ mod tests { let entry = Entry { request: ValidGenerateRequest { inputs: vec![], + input_ids: Some(Arc::new(vec![])), input_length: 0, truncate: 0, decoder_input_details: false, @@ -527,7 +546,7 @@ mod tests { #[tokio::test] async fn test_append() { - let mut state = State::new(false, 1, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16); let (entry, _guard) = default_entry(); assert_eq!(state.next_id, 0); @@ -543,7 +562,7 @@ mod tests { #[tokio::test] async fn test_next_batch_empty() { - let mut state = State::new(false, 1, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16); assert!(state.next_batch(None, None, 1, 1).await.is_none()); assert!(state.next_batch(Some(1), None, 1, 1).await.is_none()); @@ -551,7 +570,7 @@ mod tests { #[tokio::test] async fn test_next_batch_min_size() { - let mut state = State::new(false, 1, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -583,7 +602,7 @@ mod tests { #[tokio::test] async fn test_next_batch_max_size() { - let mut state = State::new(false, 1, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -603,7 +622,7 @@ mod tests { #[tokio::test] async fn test_next_batch_token_budget() { - let mut state = State::new(false, 1, None, 0, 2); + let mut state = State::new(false, 1, false, None, 0, 2); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -636,14 +655,14 @@ mod tests { #[tokio::test] async fn test_queue_append() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); let (entry, _guard) = default_entry(); queue.append(entry); } #[tokio::test] async fn test_queue_next_batch_empty() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); assert!(queue.next_batch(None, None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none()); @@ -651,7 +670,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_min_size() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -684,7 +703,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_max_size() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -700,7 +719,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_budget() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -725,7 +744,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_speculate() { - let queue = Queue::new(false, 1, None, 2, 16); + let queue = Queue::new(false, 1, false, None, 2, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -744,7 +763,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_dropped_receiver() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); let (entry, _) = default_entry(); queue.append(entry); diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs new file mode 100644 index 00000000..0464b9f8 --- /dev/null +++ b/backends/v3/src/radix.rs @@ -0,0 +1,755 @@ +use std::{ + collections::{BTreeSet, HashMap}, + sync::Arc, +}; + +use slotmap::{DefaultKey, SlotMap}; + +use crate::block_allocator::{Allocator, BlockAllocation}; + +pub struct RadixAllocator { + allocation_id: u64, + + allocations: HashMap, + + cache_blocks: RadixTrie, + + /// Blocks that are immediately available for allocation. + free_blocks: Vec, +} + +impl RadixAllocator { + pub fn new(block_size: u32, n_blocks: u32, window_size: Option) -> Self { + assert_eq!( + block_size, 1, + "Radix tree allocator only works with block_size=1, was: {}", + block_size + ); + if window_size.is_some() { + unimplemented!("Window size not supported in the prefix-caching block allocator yet"); + } + + RadixAllocator { + allocation_id: 0, + allocations: HashMap::new(), + cache_blocks: RadixTrie::new(), + + // Block 0 is reserved for health checks. + free_blocks: (1..n_blocks).collect(), + } + } + + fn alloc_or_reclaim(&mut self, n_blocks_needed: usize) -> Option> { + if self.free_blocks.len() < n_blocks_needed { + // This is a bit annoying, we first extend the free list and then + // split it off again below. This is because we need to put it on + // the free list if we cannot allocate enough blocks. This is only + // temporary, the trie needs to be able to report whether it can + // allocate the requested amount. Just not implemented yet. + self.free_blocks.extend( + self.cache_blocks + .evict(n_blocks_needed - self.free_blocks.len()), + ); + } + + if self.free_blocks.len() >= n_blocks_needed { + Some( + self.free_blocks + .split_off(self.free_blocks.len() - n_blocks_needed), + ) + } else { + None + } + } +} + +impl Allocator for RadixAllocator { + fn allocate( + &mut self, + tokens: u32, + prefill_tokens: Option>>, + ) -> Option { + let mut blocks = vec![]; + let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() { + let node_id = self + .cache_blocks + .find(prefill_tokens.as_slice(), &mut blocks); + // Even if this allocation fails below, we need to increase he + // refcount to ensure that the prefix that was found is not evicted. + + node_id + } else { + self.cache_blocks.root_id() + }; + + self.cache_blocks + .incref(prefix_node) + .expect("Failed to increment refcount"); + + let prefix_len = blocks.len(); + let suffix_len = tokens - prefix_len as u32; + + match self.alloc_or_reclaim(suffix_len as usize) { + Some(suffix_blocks) => blocks.extend(suffix_blocks), + None => { + self.cache_blocks + .decref(prefix_node) + .expect("Failed to decrement refcount"); + return None; + } + } + + // 1:1 mapping of blocks and slots. + let slots = blocks.clone(); + + let allocation = RadixAllocation { + prefix_node, + cached_prefix_len: prefix_len, + prefill_tokens: prefill_tokens.clone(), + }; + + self.allocation_id += 1; + self.allocations.insert(self.allocation_id, allocation); + + Some(BlockAllocation { + allocation_id: self.allocation_id, + block_allocator: None, + blocks, + slots, + prefix_len: prefix_len as u32, + }) + } + + fn free(&mut self, blocks: Vec, allocation_id: u64) { + let allocation = match self.allocations.remove(&allocation_id) { + Some(allocation) => allocation, + None => unreachable!("Tried to free an unknown allocation."), + }; + + self.cache_blocks + .decref(allocation.prefix_node) + .expect("Failed to decrement refcount"); + + if let Some(prefill_tokens) = allocation.prefill_tokens { + let prefill_tokens = prefill_tokens.as_slice(); + + // If there are prefill tokens that did not come from the cache, + // add them to the cache. + if prefill_tokens.len() > allocation.cached_prefix_len { + let prefix_len = self + .cache_blocks + .insert(prefill_tokens, &blocks[..prefill_tokens.len()]) + // Unwrap, failing is a programming error. + .expect("Failed to store prefill tokens"); + + // We can have a prefill with the following structure: + // + // |---| From the prefix cache. + // A B C D E F G + //|--------| Found in the trie during insertion. + // + // This means that while processing this request there was a + // partially overlapping request that had A..=E in its + // prefill. In this case we need to free the blocks D E. + self.free_blocks + .extend(&blocks[allocation.cached_prefix_len..prefix_len]); + } + + // Free non-prefill blocks. + self.free_blocks.extend(&blocks[prefill_tokens.len()..]); + } else { + self.free_blocks.extend(blocks); + } + } +} + +struct RadixAllocation { + prefix_node: NodeId, + cached_prefix_len: usize, + prefill_tokens: Option>>, +} + +// Radix trie that is heavily inspired by radix attention from sglang. +// +// The trie is optimized for prefix caching: +// +// - A normal radix trie stores discrete values. In this radix trie, +// inserting *abc* with value *xyz* will also enable lookup for +// *a* (*x*) and *ab* (*xy*). +// - As a result, every value is required to have the same length as +// the key. +// - We store additional information in each node, such as last access +// time and a reference count. + +#[derive(Debug)] +pub enum TrieError { + InvalidNodeId, + RefCountUnderflow, + BlockTokenCountMismatch, +} + +pub type NodeId = DefaultKey; + +#[derive(Debug)] +pub struct RadixTrie { + /// Identifier of the root nod. + root: DefaultKey, + + /// Leave node identifiers ordered by increasing recency. + leaves: BTreeSet<(u64, NodeId)>, + + /// All trie nodes. + nodes: SlotMap, + + /// Time as a monotonically increating counter to avoid the system + /// call that a real time lookup would require. + time: u64, +} + +impl RadixTrie { + /// Construct a new radix trie. + pub fn new() -> Self { + let root = TrieNode::new(vec![], vec![], 0, None); + let mut nodes = SlotMap::new(); + let root = nodes.insert(root); + RadixTrie { + leaves: BTreeSet::new(), + nodes, + root, + time: 0, + } + } + + /// Find the prefix of the given tokens. + /// + /// The blocks corresponding to the part of the prefix that could be found + /// are writteng to `blocks`. The number of blocks is in `0..=tokens.len()`. + /// Returns the identifier of the trie node that contains the longest + /// prefix. The node identifier can be used by callers to e.g. increase its + /// reference count. + /// + /// Using this method will update the access time of the traversed nodes. + pub fn find(&mut self, key: &[u32], blocks: &mut Vec) -> NodeId { + self.time += 1; + self.find_(self.root, key, blocks) + } + + /// Find worker. + fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec) -> NodeId { + let node = &self.nodes[node_id]; + + if let Some(&child_id) = node.children.get(&key[0]) { + self.update_access_time(child_id); + let child = self.nodes.get(child_id).expect("Invalid child identifier"); + let shared_prefix_len = child.key.shared_prefix_len(key); + blocks.extend(&child.blocks[..shared_prefix_len]); + + let key = &key[shared_prefix_len..]; + if !key.is_empty() { + node_id = self.find_(child_id, key, blocks); + } + } + + node_id + } + + /// Decrease the reference count of a node. + pub fn decref(&mut self, node_id: NodeId) -> Result<(), TrieError> { + // We don't care about refcounting for root, since it will never + // be evicted. + if node_id == self.root { + return Ok(()); + } + + let node = self + .nodes + .get_mut(node_id) + .ok_or(TrieError::InvalidNodeId)?; + if node.ref_count == 0 { + return Err(TrieError::RefCountUnderflow); + } + + node.ref_count -= 1; + if node.ref_count == 0 { + self.leaves.insert((node.last_accessed, node_id)); + } + + Ok(()) + } + + /// Increase the reference count of a node. + pub fn incref(&mut self, node_id: NodeId) -> Result<(), TrieError> { + if node_id == self.root { + return Ok(()); + } + + let node = self + .nodes + .get_mut(node_id) + .ok_or(TrieError::InvalidNodeId)?; + if node.ref_count == 0 { + self.leaves.remove(&(node.last_accessed, node_id)); + } + node.ref_count += 1; + + Ok(()) + } + + /// Evict `n_blocks` from the trie. + /// + /// Returns the evicted blocks. When the length is less than `n_blocks`, + /// not enough blocks could beevicted. + pub fn evict(&mut self, n_blocks: usize) -> Vec { + // NOTE: we don't return Result here. If any of the unwrapping fails, + // it's a programming error in the trie implementation, not a user + // error caused by e.g. an invalid argument. + + // TODO: add some bookkeeping in the future to check whether we can + // evict n_blocks and return `None` if we can't. We are now needlessly + // evicting prefixes from the cache in such a case. + let mut evicted = Vec::new(); + + while let Some((last_access, node_id)) = self.leaves.pop_first() { + let blocks_needed = n_blocks - evicted.len(); + + let node = self.nodes.get(node_id).expect("Leave does not exist"); + if blocks_needed >= node.blocks.len() { + // We need to evict the whole node if we need more blocks than it has. + let node = self.remove_node(node_id); + evicted.extend(node.blocks); + + if evicted.len() >= n_blocks { + break; + } + } else { + // The node has more blocks than needed, so we'll just remove + // the required number of blocks and leave the remaining blocks + // untouched. + let node = self.nodes.get_mut(node_id).expect("Leave does not exist"); + node.key.truncate(node.blocks.len() - blocks_needed); + evicted.extend(node.blocks.split_off(node.blocks.len() - blocks_needed)); + self.leaves.insert((last_access, node_id)); + break; + } + } + + evicted + } + + /// Insert a prefill along with its blocks. + /// + /// This method returns the length of the prefix that was already + /// in the trie. E.g. if the length is 10, this means that for + /// the first 10 elements of the tree **the blocks are not updated**. + pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result { + self.time += 1; + self.insert_(self.root, tokens, blocks) + } + + /// Insertion worker. + fn insert_( + &mut self, + node_id: NodeId, + tokens: &[u32], + blocks: &[u32], + ) -> Result { + // TODO: in the future we may want to check that the blocks match for + // the part of the prefix that is already in the trie to detect + // mismatches. + + if tokens.len() != blocks.len() { + return Err(TrieError::BlockTokenCountMismatch); + } + + if let Some(&child_id) = self.nodes[node_id].children.get(&tokens[0]) { + self.update_access_time(child_id); + let child = self + .nodes + .get_mut(child_id) + // Unwrap here, since failure is a bug. + .expect("Child node does not exist"); + let shared_prefix_len = child.key.shared_prefix_len(tokens); + + // We are done, the prefix is already in the trie. + if shared_prefix_len == tokens.len() { + return Ok(shared_prefix_len); + } + + // The node's prefix is a prefix of the insertion prefix. + if shared_prefix_len == child.key.len() { + return Ok(shared_prefix_len + + self.insert_( + child_id, + &tokens[shared_prefix_len..], + &blocks[shared_prefix_len..], + )?); + } + + // The node's prefix and the insertion prefix only match partially, + // split the node to just contain the matching part. Then insert the + // remainder of the prefix into the node again + let child_id = self.split_node(child_id, shared_prefix_len); + let key = &tokens[shared_prefix_len..]; + let blocks = &blocks[shared_prefix_len..]; + Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?) + } else { + self.add_node(node_id, tokens, blocks); + Ok(0) + } + } + + fn split_node(&mut self, node_id: NodeId, prefix_len: usize) -> NodeId { + // We have to make the current node a child to ensure that its + // properties and node id stay the same. + + // This funcion unwraps, an invalid node_id is a programming error. + + let node = self + .nodes + .get_mut(node_id) + .expect("Node to-be split does not exist"); + let mut parent_key = node.key.split_off(prefix_len); + let mut parent_blocks = node.blocks.split_off(prefix_len); + + // Move first part of the prefix to the parent. We swap to avoid + // an allocation + copy for both splits of the key/blocks. + std::mem::swap(&mut node.key, &mut parent_key); + std::mem::swap(&mut node.blocks, &mut parent_blocks); + + let node_key = node.key[0]; + + let grandparent_id = node.parent.expect("Node does not have a parent"); + let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks); + self.add_node_to_parent(parent_id, node_key, node_id); + + // Reborrow to make the borrow checker happy. + let node = self + .nodes + .get_mut(node_id) + .expect("Node to-be split does not exist"); + node.parent = Some(parent_id); + + parent_id + } + + /// Create a node and add it to the parent. + fn add_node( + &mut self, + parent_id: NodeId, + key: impl Into>, + blocks: impl Into>, + ) -> NodeId { + let key = key.into(); + let blocks = blocks.into(); + let first = key[0]; + + let child = TrieNode::new(key, blocks, self.time, Some(parent_id)); + let child_id = self.nodes.insert(child); + + self.add_node_to_parent(parent_id, first, child_id); + self.leaves.insert((self.time, child_id)); + + child_id + } + + /// Add a node to the parent. + fn add_node_to_parent(&mut self, parent_id: NodeId, first: u32, child_id: NodeId) { + // Unwrap here, passing in an unknown id is a programming error. + let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); + if parent.children.insert(first, child_id).is_none() { + // Only increase reference count if child does not replace another child. + self.incref(parent_id) + .expect("Failed to increase parent refcount"); + } + } + + /// Remove a node from the trie. + fn remove_node(&mut self, node_id: NodeId) -> TrieNode { + // Unwrap here, passing in an unknown id is a programming error. + let node = self.nodes.remove(node_id).expect("Unknown node"); + let parent_id = node.parent.expect("Attempted to remove root node"); + let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); + parent.children.remove(&node.key[0]); + self.decref(parent_id) + .expect("Failed to decrease parent refcount"); + self.nodes.remove(node_id); + node + } + + fn update_access_time(&mut self, node_id: NodeId) { + // Unwrap here, passing in an unknown id is a programming error. + let node = self.nodes.get_mut(node_id).expect("Unknown node"); + + // Update the ordered leaves set if the node is a leave. + if self.leaves.remove(&(node.last_accessed, node_id)) { + self.leaves.insert((self.time, node_id)); + } + + node.last_accessed = self.time; + } + + #[allow(dead_code)] + #[doc(hidden)] + /// Print debugging output for the trie. + /// + /// In contrast to `Debug` nicely formatted. + pub fn print_debug(&self) { + self.print_debug_(self.root, 0); + } + + fn print_debug_(&self, node_id: NodeId, indent: usize) { + let node = &self.nodes[node_id]; + eprintln!( + "{}{:?}, key: {:?}, blocks: {:?}, ref_count: {}, last_accessed: {}, parent: {:?}, children: {:?}", + " ".repeat(indent), + node_id, + node.key, + node.blocks, + node.ref_count, + node.last_accessed, + node.parent, + node.children + ); + for child_id in self.nodes[node_id].children.values() { + self.print_debug_(*child_id, indent + 2); + } + } + + pub(crate) fn root_id(&self) -> DefaultKey { + self.root + } +} + +/// Trie node. +#[derive(Debug)] +struct TrieNode { + blocks: Vec, + children: HashMap, + key: Vec, + last_accessed: u64, + parent: Option, + ref_count: usize, +} + +impl TrieNode { + fn new(key: Vec, blocks: Vec, last_accessed: u64, parent: Option) -> Self { + TrieNode { + children: HashMap::new(), + key, + blocks, + last_accessed, + parent, + ref_count: 0, + } + } +} + +/// Helper trait to get the length of the shared prefix of two sequences. +trait SharedPrefixLen { + fn shared_prefix_len(&self, other: &Self) -> usize; +} + +impl SharedPrefixLen for [T] +where + T: PartialEq, +{ + fn shared_prefix_len(&self, other: &Self) -> usize { + self.iter().zip(other).take_while(|(a, b)| a == b).count() + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::block_allocator::Allocator; + + use super::RadixAllocator; + + #[test] + fn allocator_reuses_prefixes() { + let mut cache = RadixAllocator::new(1, 12, None); + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation.slots, allocation.slots); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation.prefix_len, 4); + } + + #[test] + fn allocator_collects_older_prefixes_first() { + let mut cache = RadixAllocator::new(1, 7, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]); + assert_eq!(allocation1.prefix_len, 0); + + let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap(); + assert_eq!(allocation2.blocks, vec![1, 2]); + assert_eq!(allocation2.prefix_len, 0); + + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + + // We should get the blocks of the first allocation, since they are more recent. + let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap(); + assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]); + assert_eq!(allocation3.prefix_len, 0); + } + + #[test] + fn allocator_frees_fully_overlapping_prefills() { + let mut cache = RadixAllocator::new(1, 10, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + + let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation3.prefix_len, 4); + + // 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks. + assert_eq!(cache.free_blocks.len(), 5); + } + + #[test] + fn allocator_frees_partially_overlapping_prefills() { + let mut cache = RadixAllocator::new(1, 20, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap(); + assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]); + assert_eq!(allocation1.prefix_len, 0); + + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + + let allocation2 = cache + .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5]))) + .unwrap(); + assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]); + assert_eq!(allocation2.prefix_len, 2); + + let allocation3 = cache + .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7]))) + .unwrap(); + assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation3.prefix_len, 2); + + cache.free(allocation3.blocks.clone(), allocation3.allocation_id); + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + + // 20 blocks, of which 1 reserved for health checks, 6 for allocation3, 2 for allocation2. + assert_eq!(cache.free_blocks.len(), 11); + + let allocation4 = cache + .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5]))) + .unwrap(); + assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]); + assert_eq!(allocation4.prefix_len, 6); + assert_eq!(cache.free_blocks.len(), 11); + + let allocation5 = cache + .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7]))) + .unwrap(); + assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]); + assert_eq!(allocation5.prefix_len, 6); + assert_eq!(cache.free_blocks.len(), 11); + } + + #[test] + fn trie_insertions_have_correct_prefix_len() { + let mut trie = super::RadixTrie::new(); + + assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0); + + // Already exists. + assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 3); + + // Completely new at root-level + assert_eq!(trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(), 0); + + // Contains full prefix, but longer. + assert_eq!(trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(), 3); + + // Shares partial prefix, we need a split. + assert_eq!( + trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) + .unwrap(), + 4 + ); + } + + #[test] + fn trie_get_returns_correct_blocks() { + let mut trie = super::RadixTrie::new(); + trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); + trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(); + trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(); + trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) + .unwrap(); + + let mut blocks = Vec::new(); + trie.find(&[0], &mut blocks); + assert_eq!(blocks, vec![0]); + + blocks.clear(); + trie.find(&[0, 1, 2], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2]); + + blocks.clear(); + trie.find(&[1, 2, 3], &mut blocks); + assert_eq!(blocks, vec![1, 2, 3]); + + blocks.clear(); + trie.find(&[0, 1, 2, 3], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3]); + + blocks.clear(); + trie.find(&[0, 1, 2, 3, 4], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3, 4]); + + blocks.clear(); + trie.find(&[0, 1, 2, 3, 5], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3, 5]); + } + + #[test] + fn trie_evict_removes_correct_blocks() { + let mut trie = super::RadixTrie::new(); + trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); + trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) + .unwrap(); + trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(); + trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(); + + let mut blocks = Vec::new(); + + // Remove less than the leave blocks. + assert_eq!(trie.evict(1), vec![7]); + trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3, 5, 6]); + + // Refresh other leaf. + trie.find(&[0, 1, 2, 3, 4], &mut blocks); + trie.find(&[1, 2, 3], &mut blocks); + + // Remove the leave blocks exactly. + assert_eq!(trie.evict(2), vec![5, 6]); + blocks.clear(); + trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3]); + + trie.find(&[1, 2, 3], &mut blocks); + + // Remove more than the leave blocks. + assert_eq!(trie.evict(3), vec![4, 3, 2]); + blocks.clear(); + trie.find(&[0, 1, 2, 3, 4], &mut blocks); + assert_eq!(blocks, vec![0, 1]); + + // Clear out the whole trie. + assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]); + } +} diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index 5e739703..7494d5b5 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -157,6 +157,7 @@ async fn prefill( top_n_tokens: top_n_tokens.unwrap_or(0), blocks: vec![], slots: vec![], + prefix_len: 0, adapter_id: None, }) .collect(); diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index 926c878e..68eea7ac 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -3,22 +3,23 @@ syntax = "proto3"; package generate.v3; service TextGenerationService { - /// Model Info - rpc Info (InfoRequest) returns (InfoResponse) {} - /// Service discovery - rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {} - /// Empties batch cache - rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse); - /// Remove requests from a cached batch - rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse); - /// Warmup the model and compute max cache size - rpc Warmup (WarmupRequest) returns (WarmupResponse); - /// Prefill batch and decode first token - rpc Prefill (PrefillRequest) returns (PrefillResponse); - /// Decode token for a list of prefilled batches - rpc Decode (DecodeRequest) returns (DecodeResponse); - /// Health check - rpc Health (HealthRequest) returns (HealthResponse); + /// Model Info + rpc Info(InfoRequest) returns (InfoResponse) {} + /// Service discovery + rpc ServiceDiscovery(ServiceDiscoveryRequest) + returns (ServiceDiscoveryResponse) {} + /// Empties batch cache + rpc ClearCache(ClearCacheRequest) returns (ClearCacheResponse); + /// Remove requests from a cached batch + rpc FilterBatch(FilterBatchRequest) returns (FilterBatchResponse); + /// Warmup the model and compute max cache size + rpc Warmup(WarmupRequest) returns (WarmupResponse); + /// Prefill batch and decode first token + rpc Prefill(PrefillRequest) returns (PrefillResponse); + /// Decode token for a list of prefilled batches + rpc Decode(DecodeRequest) returns (DecodeResponse); + /// Health check + rpc Health(HealthRequest) returns (HealthResponse); } message HealthRequest {} @@ -28,240 +29,239 @@ message HealthResponse {} message InfoRequest {} message InfoResponse { - bool requires_padding = 1; - string dtype = 2; - string device_type = 3; - optional uint32 window_size = 4; - uint32 speculate = 5; + bool requires_padding = 1; + string dtype = 2; + string device_type = 3; + optional uint32 window_size = 4; + uint32 speculate = 5; } /// Empty request message ServiceDiscoveryRequest {} message ServiceDiscoveryResponse { - /// Other shards urls - repeated string urls = 1; + /// Other shards urls + repeated string urls = 1; } message ClearCacheRequest { - /// Optional batch id - optional uint64 id = 1; + /// Optional batch id + optional uint64 id = 1; } /// Empty response message ClearCacheResponse {} message Image { - /// Binary image data. - bytes data = 1; + /// Binary image data. + bytes data = 1; - /// Image MIME type. - string mimetype = 2; + /// Image MIME type. + string mimetype = 2; } message InputChunk { - oneof chunk { - /// Plain text data - string text = 1; - /// Image data - Image image = 2; - } + oneof chunk { + /// Plain text data + string text = 1; + /// Image data + Image image = 2; + } } -message Input { - repeated InputChunk chunks = 1; - } +message Input { repeated InputChunk chunks = 1; } enum GrammarType { - GRAMMAR_TYPE_NONE = 0; - GRAMMAR_TYPE_JSON = 1; - GRAMMAR_TYPE_REGEX = 2; + GRAMMAR_TYPE_NONE = 0; + GRAMMAR_TYPE_JSON = 1; + GRAMMAR_TYPE_REGEX = 2; } message NextTokenChooserParameters { - /// exponential scaling output probability distribution - float temperature = 1; - /// restricting to the k highest probability elements - uint32 top_k = 2; - /// restricting to top tokens summing to prob_cut_off <= prob_cut_off - float top_p = 3; - /// restricting to top tokens summing to prob_cut_off <= prob_cut_off - float typical_p = 4; - /// apply sampling on the logits - bool do_sample = 5; - /// random seed for sampling - uint64 seed = 6; - /// repetition penalty - float repetition_penalty = 7; - /// frequency penalty - float frequency_penalty = 9; - /// token watermarking using "A Watermark for Large Language Models" - bool watermark = 8; - /// grammar (applied if not empty) - string grammar = 10; - /// grammar type - GrammarType grammar_type = 11; + /// exponential scaling output probability distribution + float temperature = 1; + /// restricting to the k highest probability elements + uint32 top_k = 2; + /// restricting to top tokens summing to prob_cut_off <= prob_cut_off + float top_p = 3; + /// restricting to top tokens summing to prob_cut_off <= prob_cut_off + float typical_p = 4; + /// apply sampling on the logits + bool do_sample = 5; + /// random seed for sampling + uint64 seed = 6; + /// repetition penalty + float repetition_penalty = 7; + /// frequency penalty + float frequency_penalty = 9; + /// token watermarking using "A Watermark for Large Language Models" + bool watermark = 8; + /// grammar (applied if not empty) + string grammar = 10; + /// grammar type + GrammarType grammar_type = 11; } message StoppingCriteriaParameters { - /// Maximum number of generated tokens - uint32 max_new_tokens = 1; - /// Optional stopping sequences - repeated string stop_sequences = 2; - /// Ignore end of sequence token - /// used for benchmarking - bool ignore_eos_token = 3; + /// Maximum number of generated tokens + uint32 max_new_tokens = 1; + /// Optional stopping sequences + repeated string stop_sequences = 2; + /// Ignore end of sequence token + /// used for benchmarking + bool ignore_eos_token = 3; } message Request { - /// Request ID - uint64 id = 1; - /// The generation context as chunks - Input input_chunks = 8; - /// The generation context, stringified input_chunks - string inputs = 2; - /// Context truncation - uint32 truncate = 3; - /// Next Token Chooser Parameters - NextTokenChooserParameters parameters = 4; - /// Stopping Criteria Parameters - StoppingCriteriaParameters stopping_parameters = 5; - /// Return prefill logprobs - bool prefill_logprobs = 6; - /// Return most likely n tokens - uint32 top_n_tokens = 7; - /// Paged attention blocks - repeated uint32 blocks = 9; - /// Paged attention slots - repeated uint32 slots = 10; - /// LORA adapter index - optional string adapter_id = 11; + /// Request ID + uint64 id = 1; + /// The generation context as chunks + Input input_chunks = 8; + /// The generation context, stringified input_chunks + string inputs = 2; + /// Context truncation + uint32 truncate = 3; + /// Next Token Chooser Parameters + NextTokenChooserParameters parameters = 4; + /// Stopping Criteria Parameters + StoppingCriteriaParameters stopping_parameters = 5; + /// Return prefill logprobs + bool prefill_logprobs = 6; + /// Return most likely n tokens + uint32 top_n_tokens = 7; + /// Paged attention blocks + repeated uint32 blocks = 9; + /// Paged attention slots + repeated uint32 slots = 10; + /// LORA adapter index + optional string adapter_id = 11; + /// Prefix length that can be retrieved from the KV cache. + uint32 prefix_len = 12; } message Batch { - /// Batch ID - uint64 id = 1; - /// Individual requests - repeated Request requests = 2; - /// Batch size (==len(requests)) - uint32 size = 3; - /// Maximum number of tokens this batch will grow to - uint32 max_tokens = 4; - /// Maximum number of Paged Attention blocks - uint32 max_blocks = 5; + /// Batch ID + uint64 id = 1; + /// Individual requests + repeated Request requests = 2; + /// Batch size (==len(requests)) + uint32 size = 3; + /// Maximum number of tokens this batch will grow to + uint32 max_tokens = 4; + /// Maximum number of Paged Attention blocks + uint32 max_blocks = 5; } message CachedBatch { - /// Batch ID - uint64 id = 1; - /// Individual requests ids - repeated uint64 request_ids = 2; - /// Batch size (==len(requests)) - uint32 size = 3; - /// Maximum number of tokens this batch will grow to - uint32 max_tokens = 4; + /// Batch ID + uint64 id = 1; + /// Individual requests ids + repeated uint64 request_ids = 2; + /// Batch size (==len(requests)) + uint32 size = 3; + /// Maximum number of tokens this batch will grow to + uint32 max_tokens = 4; } enum FinishReason { - FINISH_REASON_LENGTH = 0; - FINISH_REASON_EOS_TOKEN = 1; - FINISH_REASON_STOP_SEQUENCE = 2; + FINISH_REASON_LENGTH = 0; + FINISH_REASON_EOS_TOKEN = 1; + FINISH_REASON_STOP_SEQUENCE = 2; } message GeneratedText { - /// Output - string text = 1; - /// Number of generated tokens - uint32 generated_tokens = 2; - /// Finish reason - FinishReason finish_reason = 3; - /// Seed - optional uint64 seed = 4; + /// Output + string text = 1; + /// Number of generated tokens + uint32 generated_tokens = 2; + /// Finish reason + FinishReason finish_reason = 3; + /// Seed + optional uint64 seed = 4; } message Tokens { - /// Token IDs - repeated uint32 ids = 1; - /// Logprobs - repeated float logprobs = 2; - /// tokens - repeated string texts = 3; - /// special - repeated bool is_special = 4; + /// Token IDs + repeated uint32 ids = 1; + /// Logprobs + repeated float logprobs = 2; + /// tokens + repeated string texts = 3; + /// special + repeated bool is_special = 4; } message Generation { - /// Request ID - uint64 request_id = 1; - /// Prefill tokens (optional) - Tokens prefill_tokens = 2; - Tokens tokens = 3; - /// Complete generated text - optional GeneratedText generated_text = 4; - /// Top tokens - repeated Tokens top_tokens = 5; + /// Request ID + uint64 request_id = 1; + /// Prefill tokens (optional) + Tokens prefill_tokens = 2; + Tokens tokens = 3; + /// Complete generated text + optional GeneratedText generated_text = 4; + /// Top tokens + repeated Tokens top_tokens = 5; } message FilterBatchRequest { - /// Batch ID - uint64 batch_id = 1; - /// Requests to keep - repeated uint64 request_ids = 2; + /// Batch ID + uint64 batch_id = 1; + /// Requests to keep + repeated uint64 request_ids = 2; } message FilterBatchResponse { - /// Filtered Batch (cached) - CachedBatch batch = 1; + /// Filtered Batch (cached) + CachedBatch batch = 1; } - message PrefillRequest { - /// Batch - Batch batch = 1; + /// Batch + Batch batch = 1; } message PrefillResponse { - /// Generation - repeated Generation generations = 1; - /// Next batch (cached) - optional CachedBatch batch = 2; - /// Forward elapsed time in nanoseconds - uint64 forward_ns = 3; - /// Decode elapsed time in nanoseconds - uint64 decode_ns = 4; - /// Total elapsed time in nanoseconds - uint64 total_ns = 5; + /// Generation + repeated Generation generations = 1; + /// Next batch (cached) + optional CachedBatch batch = 2; + /// Forward elapsed time in nanoseconds + uint64 forward_ns = 3; + /// Decode elapsed time in nanoseconds + uint64 decode_ns = 4; + /// Total elapsed time in nanoseconds + uint64 total_ns = 5; } message DecodeRequest { - /// Cached batches - repeated CachedBatch batches = 1; + /// Cached batches + repeated CachedBatch batches = 1; } message DecodeResponse { - /// Decodes - repeated Generation generations = 1; - /// Next batch (cached) - optional CachedBatch batch = 2; - /// Forward elapsed time in nanoseconds - uint64 forward_ns = 3; - /// Decode elapsed time in nanoseconds - uint64 decode_ns = 4; - /// Total elapsed time in nanoseconds - uint64 total_ns = 5; - /// Concatenate elapsed time in nanoseconds - optional uint64 concat_ns = 6; + /// Decodes + repeated Generation generations = 1; + /// Next batch (cached) + optional CachedBatch batch = 2; + /// Forward elapsed time in nanoseconds + uint64 forward_ns = 3; + /// Decode elapsed time in nanoseconds + uint64 decode_ns = 4; + /// Total elapsed time in nanoseconds + uint64 total_ns = 5; + /// Concatenate elapsed time in nanoseconds + optional uint64 concat_ns = 6; } message WarmupRequest { - /// Batch to warmup on - Batch batch = 1; - uint32 max_input_length = 2; - uint32 max_prefill_tokens = 3; - uint32 max_total_tokens = 4; + /// Batch to warmup on + Batch batch = 1; + uint32 max_input_length = 2; + uint32 max_prefill_tokens = 3; + uint32 max_total_tokens = 4; } message WarmupResponse { - /// Maximum number of tokens supported by the model - optional uint32 max_supported_total_tokens = 1; + /// Maximum number of tokens supported by the model + optional uint32 max_supported_total_tokens = 1; } diff --git a/router/src/validation.rs b/router/src/validation.rs index 3d1a4103..5011158a 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -11,6 +11,7 @@ use rand::{thread_rng, Rng}; use serde_json::Value; use std::io::Cursor; use std::iter; +use std::sync::Arc; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; use tokio::sync::mpsc; @@ -115,13 +116,14 @@ impl Validation { } } + #[allow(clippy::type_complexity)] #[instrument(skip(self, inputs))] async fn validate_input( &self, inputs: String, truncate: Option, max_new_tokens: Option, - ) -> Result<(Vec, usize, u32), ValidationError> { + ) -> Result<(Vec, Option>, usize, u32), ValidationError> { // If we have a fast tokenizer if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? { // Create response channel @@ -156,8 +158,10 @@ impl Validation { )); } + let input_ids = encoding.get_ids()[..input_length].to_owned(); + metrics::histogram!("tgi_request_input_length").record(input_length as f64); - Ok((inputs, input_length, max_new_tokens)) + Ok((inputs, Some(input_ids), input_length, max_new_tokens)) } // Return inputs without validation else { @@ -180,7 +184,12 @@ impl Validation { input_length = input_length.saturating_sub(max_new_tokens as usize); } - Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens)) + Ok(( + vec![Chunk::Text(inputs)], + None, + input_length, + max_new_tokens, + )) } } @@ -314,7 +323,7 @@ impl Validation { .unwrap_or(Ok(None))?; // Validate inputs - let (inputs, input_length, max_new_tokens) = self + let (inputs, input_ids, input_length, max_new_tokens) = self .validate_input(request.inputs, truncate, max_new_tokens) .await?; @@ -391,6 +400,7 @@ impl Validation { Ok(ValidGenerateRequest { inputs, + input_ids: input_ids.map(Arc::new), decoder_input_details, input_length: input_length as u32, truncate: truncate.unwrap_or(self.max_input_length) as u32, @@ -707,6 +717,7 @@ pub struct ValidStoppingParameters { #[derive(Debug, Clone)] pub struct ValidGenerateRequest { pub inputs: Vec, + pub input_ids: Option>>, pub input_length: u32, pub truncate: u32, pub decoder_input_details: bool, diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py index f9b1715e..56fc5319 100644 --- a/server/text_generation_server/layers/attention/__init__.py +++ b/server/text_generation_server/layers/attention/__init__.py @@ -6,7 +6,12 @@ from .common import Seqlen if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": raise ImportError("`USE_FLASH_ATTENTION` is false.") if SYSTEM == "cuda": - from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING + from .cuda import ( + attention, + paged_attention, + reshape_and_cache, + SUPPORTS_WINDOWING, + ) elif SYSTEM == "rocm": from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING elif SYSTEM == "ipex": diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index d039e1e7..5fd0eff6 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -221,9 +221,11 @@ SUPPORTS_WINDOWING = V2 if ATTENTION == "flashinfer": def attention( - q, - k, - v, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, cu_seqlens, max_s, softmax_scale, @@ -231,14 +233,15 @@ if ATTENTION == "flashinfer": causal=True, softcap=0.0, ): - from text_generation_server.layers.attention.flash_infer import prefill_state + assert window_size_left == -1, "Windowing is not supported with flash infer" + from text_generation_server.layers.attention.flash_infer import ( + prefill_with_paged_kv_state, + ) - return prefill_state.get().forward( - q, - k, - v, + return prefill_with_paged_kv_state.get().forward( + q.contiguous(), causal=causal, - window_left=window_size_left, + paged_kv_cache=(key_cache, value_cache), logits_soft_cap=softcap, sm_scale=softmax_scale, ) @@ -249,6 +252,8 @@ elif V2: q, k, v, + key_cache: torch.Tensor, + value_cache: torch.Tensor, cu_seqlens, max_s, softmax_scale, @@ -289,6 +294,8 @@ else: q, k, v, + key_cache: torch.Tensor, + value_cache: torch.Tensor, cu_seqlens, max_s, softmax_scale, diff --git a/server/text_generation_server/layers/attention/flash_infer.py b/server/text_generation_server/layers/attention/flash_infer.py index 56b53b2c..e1ef62c5 100644 --- a/server/text_generation_server/layers/attention/flash_infer.py +++ b/server/text_generation_server/layers/attention/flash_infer.py @@ -9,6 +9,10 @@ prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = Con "prefill_state" ) +prefill_with_paged_kv_state: ContextVar[ + flashinfer.BatchPrefillWithPagedKVCacheWrapper +] = ContextVar("prefill_with_paged_kv_state") + decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar( "decode_state" ) @@ -24,6 +28,78 @@ def get_workspace(device): return workspace +def create_prefill_with_paged_kv_state( + *, + device: torch.device, +): + """Create a prefill state that uses the KV cache.""" + workspace_buffer = get_workspace(device) + return flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, kv_layout="NHD", use_cuda_graph=False + ) + + +@contextmanager +def use_prefill_with_paged_kv_state( + *, + state: flashinfer.BatchPrefillWithPagedKVCacheWrapper, + block_tables: torch.Tensor, + cu_seqlens: torch.Tensor, + input_lengths: torch.Tensor, + num_heads: int, + num_kv_heads: int, + head_size: int, + page_size: int, + query_dtype: str = "float16", +): + """ + Context manager to set the active flashinfer prefill state to the given + `state` and parameters. This state will be used by all calls to the + `attention` function while the context manager is active. + """ + + indptr = torch.zeros( + input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32 + ) + # Round up to page size and then calculate the cumulative sum to get + # the indices into the block table. + torch.add(input_lengths, page_size - 1, out=indptr[1:]) + indptr[1:].div_(page_size, rounding_mode="floor") + indptr[1:].cumsum_(-1) + + # Get the lengths of the last page in a block. + if page_size == 1: + last_page_len = torch.ones( + input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device + ) + else: + last_page_len = torch.empty( + input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device + ) + torch.sub(input_lengths, 1, out=last_page_len) + last_page_len.remainder_(page_size) + last_page_len += 1 + + token = prefill_with_paged_kv_state.set(state) + try: + state.begin_forward( + qo_indptr=cu_seqlens, + paged_kv_indptr=indptr, + paged_kv_indices=block_tables, + paged_kv_last_page_len=last_page_len, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_size, + q_data_type=query_dtype, + page_size=page_size, + ) + yield + finally: + state.end_forward() + if token is not None: + prefill_with_paged_kv_state.reset(token) + + def create_prefill_state( *, device: torch.device, diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index e02a31d9..1eb8c6c3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -298,6 +298,8 @@ class FlashCohereAttention(torch.nn.Module): query, key, value, + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index d3d1d1ef..fc0dca5b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -337,6 +337,8 @@ class DbrxAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 0905d3c2..b25becd5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -365,6 +365,8 @@ class DeepseekV2Attention(torch.nn.Module): query, key, value, + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 54d212e6..faf0f325 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -238,6 +238,8 @@ class FlashGemma2Attention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 178efadb..33738a59 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -232,6 +232,8 @@ class FlashGemmaAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index a19cff8c..d30b5a0a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -232,6 +232,8 @@ class FlashGPT2Attention(torch.nn.Module): query, key, value, + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 9ea19a87..91a835f3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -21,6 +21,7 @@ from contextlib import contextmanager from typing import List, Optional, Tuple +from loguru import logger import torch import torch.distributed @@ -220,6 +221,8 @@ class FlashLlamaAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index dda53ff3..5a150267 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -219,6 +219,8 @@ class MistralAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 85431c6c..ad426ffe 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -276,6 +276,8 @@ class MixtralAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 67237d5c..b684e035 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -173,6 +173,8 @@ class FlashNeoxAttention(torch.nn.Module): qkv[:, 0], qkv[:, 1], qkv[:, 2], + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index a9e18348..efe27c13 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -194,6 +194,8 @@ class FlashPhiAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 865cc85d..879b8abd 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -137,6 +137,8 @@ class Qwen2Attention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 10f995a3..c72a9b90 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -208,6 +208,8 @@ class FlashRWAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, @@ -326,6 +328,8 @@ class FlashRWLargeAttention(torch.nn.Module): query, torch.select(kv, dim=2, index=0), torch.select(kv, dim=2, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index c2676782..109304be 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -293,6 +293,8 @@ class FlashMQAttention(torch.nn.Module): query, torch.select(key_value, dim=1, index=0), torch.select(key_value, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index e562eb89..200d4ef0 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -242,6 +242,8 @@ class Starcoder2Attention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 21b66a68..2be753cd 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -20,6 +20,9 @@ from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, D from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +from text_generation_server.layers.attention.flash_infer import ( + create_prefill_with_paged_kv_state, +) from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models import Model @@ -43,6 +46,7 @@ from text_generation_server.models.globals import ( ATTENTION, BLOCK_SIZE, CUDA_GRAPHS, + PREFIX_CACHING, get_adapter_to_index, ) from text_generation_server.layers.attention import Seqlen @@ -108,6 +112,9 @@ class FlashCausalLMBatch(Batch): block_tables_tensor: torch.Tensor # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences slots: torch.Tensor + # size [b], containing the number of blocks that can be retrieved from the cache + prefix_lens: List[int] + prefix_lens_tensor: torch.Tensor max_seqlen: int @@ -116,6 +123,9 @@ class FlashCausalLMBatch(Batch): prefill_next_token_indices: Optional[torch.tensor] prefill_cu_outlens: Optional[List[int]] + # Prefixes + prefix_ids: List[List[int]] + # All tokens all_input_ids: List[List[int]] all_input_ids_tensor: torch.Tensor @@ -183,6 +193,7 @@ class FlashCausalLMBatch(Batch): prefix_offsets = [] read_offsets = [] all_input_ids = [] + prefix_ids = [] requests_idx_mapping = {} all_prefill_logprobs = True @@ -200,7 +211,7 @@ class FlashCausalLMBatch(Batch): # Cumulative length cumulative_length = 0 - cumulative_max_length = 0 + cumulative_slot_tokens = 0 prefill_out_cumulative_length = 0 num_blocks = 0 @@ -210,6 +221,7 @@ class FlashCausalLMBatch(Batch): block_tables = [] slots = [] + prefix_lens = [] # Parse batch for i, (r, tokenized_input) in enumerate( @@ -225,6 +237,19 @@ class FlashCausalLMBatch(Batch): ): tokenized_input = tokenized_input[1:] + orig_input_length = len(tokenized_input) + + if PREFIX_CACHING: + prefix_len = r.prefix_len + if prefix_len == orig_input_length: + assert prefix_len > 0 + prefix_len -= 1 + else: + prefix_len = 0 + + prefix_ids.append(tokenized_input[:prefix_len]) + tokenized_input = tokenized_input[prefix_len:] + input_length = len(tokenized_input) input_lengths.append(input_length) @@ -234,7 +259,9 @@ class FlashCausalLMBatch(Batch): all_input_ids.append(tokenized_input) # Position ids - request_position_ids = torch.arange(0, input_length, dtype=torch.int32) + request_position_ids = torch.arange( + prefix_len, orig_input_length, dtype=torch.int32 + ) position_ids.append(request_position_ids) # Add cumulative lengths of all previous inputs @@ -258,11 +285,17 @@ class FlashCausalLMBatch(Batch): # Remove one as the first token des not have a past speculative_length = get_speculate() speculative_length = 0 if speculative_length is None else speculative_length - total_tokens = input_length + max_new_tokens - 1 + speculative_length + + # Tokens that need to be mapped to blocks. + block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length + + # Tokens that need to be mapped to slots. We don't need slots for the + # cached prefix (if present). + slot_tokens = input_length + max_new_tokens - 1 + speculative_length # blocks and slots can be empty (for example in warmup) if not r.blocks: - needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) + needed_blocks = math.ceil(block_tokens / BLOCK_SIZE) request_blocks = [ b for b in range(num_blocks, num_blocks + needed_blocks) ] @@ -273,16 +306,20 @@ class FlashCausalLMBatch(Batch): ] else: request_blocks = r.blocks - request_slots = r.slots + request_slots = r.slots[ + prefix_len: #: orig_input_length + max_new_tokens + speculative_length + ] block_tables.append(request_blocks) - slots.extend(request_slots[:total_tokens]) + + slots.extend(request_slots) + prefix_lens.append(prefix_len) num_blocks += len(request_blocks) - start_slots.append(cumulative_max_length) + start_slots.append(cumulative_slot_tokens) request_slot_indices = torch.arange( - cumulative_max_length, - cumulative_max_length + input_length, + cumulative_slot_tokens, + cumulative_slot_tokens + input_length, dtype=torch.int64, ) slot_indices.append(request_slot_indices) @@ -318,7 +355,7 @@ class FlashCausalLMBatch(Batch): # Update cumulative_length += input_length - cumulative_max_length += total_tokens + cumulative_slot_tokens += slot_tokens max_seqlen = max(max_seqlen, input_length) max_blocks = max(max_blocks, len(request_blocks)) max_length = max( @@ -395,12 +432,14 @@ class FlashCausalLMBatch(Batch): ) slots = torch.tensor(slots, dtype=torch.int64, device=device) + block_tables_tensor = torch.zeros( (len(block_tables), max_blocks), dtype=torch.int32, device="cpu" ) for i, request_blocks in enumerate(block_tables): block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) block_tables_tensor = block_tables_tensor.to(device) + prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device) return cls( batch_id=pb.id, @@ -415,6 +454,8 @@ class FlashCausalLMBatch(Batch): block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, + prefix_lens=prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, max_seqlen=max_seqlen, prefill_head_indices=prefill_head_indices, prefill_next_token_indices=prefill_next_token_indices, @@ -425,6 +466,7 @@ class FlashCausalLMBatch(Batch): read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, + prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -480,6 +522,7 @@ class FlashCausalLMBatch(Batch): start_slots = [] block_tables = [] all_input_ids = [] + prefix_ids = [] input_lengths = [] prefix_offsets = [] @@ -506,6 +549,7 @@ class FlashCausalLMBatch(Batch): max_seqlen = max(max_seqlen, request_input_length) all_input_ids.append(self.all_input_ids[idx]) + prefix_ids.append(self.prefix_ids[idx]) input_lengths.append(request_input_length) prefix_offsets.append(self.prefix_offsets[idx]) @@ -591,6 +635,7 @@ class FlashCausalLMBatch(Batch): read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, + prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -651,6 +696,7 @@ class FlashCausalLMBatch(Batch): block_tables_tensor = batches[0].block_tables_tensor.new_zeros( (total_batch_size, max_blocks) ) + prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size) all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( (total_batch_size, max_length) ) @@ -668,7 +714,9 @@ class FlashCausalLMBatch(Batch): start_slots = [] block_tables = [] + prefix_lens = [] all_input_ids = [] + prefix_ids = [] input_lengths = [] prefix_offsets = [] @@ -730,10 +778,14 @@ class FlashCausalLMBatch(Batch): start_index:end_index, : batch.block_tables_tensor.shape[1] ] = batch.block_tables_tensor[:, :max_blocks] + prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor + start_slots.append(batch.start_slots + cumulative_slots) block_tables.extend(batch.block_tables) + prefix_lens.extend(batch.prefix_lens) all_input_ids.extend(batch.all_input_ids) + prefix_ids.extend(batch.prefix_ids) input_lengths.extend(batch.input_lengths) prefix_offsets.extend(batch.prefix_offsets) @@ -779,6 +831,8 @@ class FlashCausalLMBatch(Batch): slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, + prefix_lens=prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, slots=slots, max_seqlen=max_seqlen, prefill_head_indices=None, @@ -790,6 +844,7 @@ class FlashCausalLMBatch(Batch): read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, + prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -944,6 +999,9 @@ class FlashCausalLM(Model): ) self.prefill_state = create_prefill_state(device=device) + self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state( + device=device + ) if not CUDA_GRAPHS: self.decode_state = create_decode_state( @@ -1042,12 +1100,23 @@ class FlashCausalLM(Model): input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) slots = torch.arange(bs, dtype=torch.int64, device=self.device) - input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s - block_tables = ( - torch.arange(max_bt, dtype=torch.int32, device=self.device) - .repeat(bs) - .reshape((bs, max_bt)) + input_lengths = [max_s] * bs + prefix_lengths = [0] * bs + input_lengths_tensor = ( + torch.ones(bs, dtype=torch.int32, device=self.device) * max_s ) + prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device) + block_tables = torch.arange( + max_bt, dtype=torch.int32, device=self.device + ).repeat(bs) + block_tables = block_tables.reshape((bs, max_bt)) + + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=input_lengths, + prefix_lens=prefix_lengths, + ) self.cuda_graphs[bs] = { "input_ids": input_ids, @@ -1055,9 +1124,9 @@ class FlashCausalLM(Model): "kv_cache": self.kv_cache, "block_tables": block_tables, "slots": slots, - "input_lengths": input_lengths, + "input_lengths": input_lengths_tensor, } - input_lengths_ = Seqlen(input_lengths=input_lengths) + input_lengths_ = Seqlen(input_lengths=input_lengths_tensor) graph = torch.cuda.CUDAGraph() self.cuda_graphs[bs]["graph"] = graph @@ -1072,7 +1141,7 @@ class FlashCausalLM(Model): last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device) state = create_decode_state_cuda_graphs( device=input_ids.device, - block_tables=block_tables.view(-1), + block_tables=block_tables, block_tables_ptr=block_tables_ptr, last_page_len=last_page_len, num_heads=self.num_heads, @@ -1088,7 +1157,10 @@ class FlashCausalLM(Model): block_tables=block_tables, cu_seqlen_prefill=None, input_lengths=input_lengths, + input_lengths_tensor=input_lengths_tensor, state=state, + prefix_lens=prefix_lengths, + prefix_lens_tensor=prefix_lengths_tensor, ): self.model.forward( input_ids=input_ids, @@ -1106,7 +1178,7 @@ class FlashCausalLM(Model): torch.cuda.synchronize() with torch.cuda.graph(graph, pool=MEM_POOL): - input_lengths = Seqlen(input_lengths=input_lengths) + input_lengths_tensor = Seqlen(input_lengths=input_lengths_tensor) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -1114,7 +1186,7 @@ class FlashCausalLM(Model): kv_cache=self.kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + input_lengths=input_lengths_tensor, max_s=max_s, prefill_cache_indices=None, lm_head_indices=None, @@ -1343,7 +1415,10 @@ class FlashCausalLM(Model): with self._forward_context( block_tables=block_tables, cu_seqlen_prefill=cu_seqlen_prefill, - input_lengths=input_lengths, + input_lengths=batch.input_lengths, + input_lengths_tensor=input_lengths, + prefix_lens=batch.prefix_lens, + prefix_lens_tensor=batch.prefix_lens_tensor, ): input_lengths = Seqlen(input_lengths=input_lengths) logits, speculative_logits = self.model.forward( @@ -1367,19 +1442,32 @@ class FlashCausalLM(Model): # Static inputs are potentially padded cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids - cuda_graph["block_tables"][ - : block_tables.shape[0], : block_tables.shape[1] - ] = block_tables + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + prefix_lens=batch.prefix_lens, + ) + cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables + else: + cuda_graph["block_tables"][ + : block_tables.shape[0], : block_tables.shape[1] + ] = block_tables cuda_graph["slots"].fill_(-1) cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( + input_lengths + batch.prefix_lens_tensor + ) state = cuda_graph.get("state") with self._forward_context( block_tables=block_tables, cu_seqlen_prefill=None, - input_lengths=input_lengths, + input_lengths=batch.input_lengths, + input_lengths_tensor=input_lengths, + prefix_lens=batch.prefix_lens, + prefix_lens_tensor=batch.prefix_lens_tensor, state=state, ): # Replay the graph @@ -1578,6 +1666,7 @@ class FlashCausalLM(Model): batch.read_offsets, batch.stopping_criterias, batch.all_input_ids, + batch.prefix_ids, batch.next_token_chooser.do_sample, batch.next_token_chooser.seeds, batch.top_n_tokens, @@ -1595,6 +1684,7 @@ class FlashCausalLM(Model): read_offset, stopping_criteria, all_input_ids, + prefix_ids, do_sample, seed, top_n_tokens, @@ -1669,18 +1759,18 @@ class FlashCausalLM(Model): out_end_index = batch.prefill_cu_outlens[i + 1] # Remove generated token to only have prefill and add nan for first prompt token - request_prefill_logprobs = [float("nan")] + prefill_logprobs[ - out_start_index : out_end_index - 1 - ] + request_prefill_logprobs = ( + [float("nan")] * (len(prefix_ids) + 1) + ) + prefill_logprobs[out_start_index : out_end_index - 1] prefill_token_ids = all_input_ids[:-1] prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, + prefix_ids + prefill_token_ids, clean_up_tokenization_spaces=False, skip_special_tokens=False, ) prefill_tokens = Tokens( - prefill_token_ids, + prefix_ids + prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special=[], @@ -1762,7 +1852,10 @@ class FlashCausalLM(Model): *, block_tables: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], - input_lengths: torch.Tensor, + input_lengths: List[int], + input_lengths_tensor: torch.Tensor, + prefix_lens: List[int], + prefix_lens_tensor: torch.Tensor, state: Optional[Any] = None, ) -> ContextManager: if ATTENTION != "flashinfer": @@ -1771,24 +1864,65 @@ class FlashCausalLM(Model): from text_generation_server.layers.attention.flash_infer import ( use_decode_state, use_prefill_state, + use_prefill_with_paged_kv_state, ) + # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens) + if cu_seqlen_prefill is not None: - return use_prefill_state( - state=state if state is not None else self.prefill_state, - cu_seqlens=cu_seqlen_prefill, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - head_size=self.head_size, - ) + if True: # has_prefix_lens: + return use_prefill_with_paged_kv_state( + state=( + state if state is not None else self.prefill_with_paged_kv_state + ), + block_tables=block_tables_to_ragged( + block_tables=block_tables, + input_lengths=input_lengths, + prefix_lens=prefix_lens, + ), + cu_seqlens=cu_seqlen_prefill, + input_lengths=input_lengths_tensor + prefix_lens_tensor, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + page_size=BLOCK_SIZE, + ) + else: + return use_prefill_state( + state=state if state is not None else self.prefill_state, + cu_seqlens=cu_seqlen_prefill, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + ) else: - assert input_lengths is not None + assert input_lengths_tensor is not None return use_decode_state( state=state if state is not None else self.decode_state, - input_lengths=input_lengths, - block_tables=block_tables.view(-1), + input_lengths=input_lengths_tensor + prefix_lens_tensor, + block_tables=block_tables, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_size=self.head_size, page_size=BLOCK_SIZE, ) + + +def block_tables_to_ragged( + *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int] +) -> torch.Tensor: + """Convert block table to ragged format compatible with FlashInfer.""" + assert len(input_lengths) == len(prefix_lens) + + total_len = sum(input_lengths) + sum(prefix_lens) + block_tables_ragged = torch.empty( + total_len, dtype=torch.int32, device=block_tables.device + ) + + offset = 0 + for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)): + seq_len = prefix_len + input_length + block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len] + offset += seq_len + + return block_tables_ragged diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index b58a5b80..7edac27f 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -5,17 +5,29 @@ from typing import Dict, Optional from text_generation_server.utils.log import log_master -ATTENTION = os.getenv("ATTENTION", "paged") +PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", False) +log_master(logger.info, f"Using Attention = {PREFIX_CACHING}") + +ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged") _expected = {"paged", "flashdecoding", "flashinfer"} assert ( ATTENTION in _expected ), f"Attention is not valid {ATTENTION}, expected {_expected}" log_master(logger.info, f"Using Attention = {ATTENTION}") +if PREFIX_CACHING and ATTENTION != "flashinfer": + raise RuntimeError("Prefix caching is only supported with flashinfer") + MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli -BLOCK_SIZE: int = 256 if ATTENTION == "flashdecoding" else 16 +BLOCK_SIZE: int +if ATTENTION == "flashdecoding": + BLOCK_SIZE = 256 +elif ATTENTION == "flashinfer": + BLOCK_SIZE = 1 +else: + BLOCK_SIZE = 16 cuda_graphs = os.getenv("CUDA_GRAPHS") if cuda_graphs is not None: