Prefix caching WIP

This commit is contained in:
Daniël de Kok 2024-08-09 11:47:14 +00:00
parent 7a48a84784
commit 7735b385dc
34 changed files with 1451 additions and 307 deletions

Cargo.lock generated (+1)
View File

@@ -4045,6 +4045,7 @@ dependencies = [
"reqwest",
"serde",
"serde_json",
+ "slotmap",
"text-generation-router",
"thiserror",
"tokenizers",

View File

@@ -156,6 +156,7 @@ impl Client {
// Blocks and slots will be set on the server side if we use paged attention
blocks: vec![],
slots: vec![],
+ prefix_len: 0,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,

View File

@@ -244,6 +244,7 @@ impl Health for ShardedClient {
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
+ prefix_len: 0,
adapter_id: None,
};
let batch = Batch {

View File

@@ -33,6 +33,7 @@ rand = "0.8.5"
reqwest = { version = "0.11.20", features = [] }
serde = "1.0.188"
serde_json = "1.0.107"
+ slotmap = "1.0.7"
thiserror = "1.0.48"
tokenizers = { workspace = true}
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }

View File

@@ -35,15 +35,24 @@ impl BackendV3 {
window_size: Option<u32>,
speculate: u32,
) -> Self {
+ let prefix_caching = if let Ok(prefix_caching) = std::env::var("USE_PREFIX_CACHING") {
+ matches!(prefix_caching.as_str(), "true" | "1")
+ } else {
+ false
+ };
let attention = if let Ok(attention) = std::env::var("ATTENTION") {
attention
.parse()
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
+ } else if prefix_caching {
+ Attention::FlashInfer
} else {
Attention::Paged
};
let block_size = if attention == Attention::FlashDecoding {
256
+ } else if attention == Attention::FlashInfer {
+ 1
} else {
16
};
@@ -51,6 +60,7 @@ impl BackendV3 {
let queue = Queue::new(
requires_padding,
block_size,
+ prefix_caching,
window_size,
speculate,
max_batch_total_tokens,
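
For illustration only, a hedged, self-contained sketch of the precedence implemented above; the `AttentionKind` enum and the string values are stand-ins, not part of this commit. An explicit ATTENTION value always wins, otherwise USE_PREFIX_CACHING selects FlashInfer, whose block size is 1.

#[derive(Debug, PartialEq)]
enum AttentionKind {
    Paged,         // block size 16
    FlashDecoding, // block size 256
    FlashInfer,    // block size 1, used when prefix caching is enabled
}

fn select_attention(attention_env: Option<&str>, prefix_caching: bool) -> AttentionKind {
    match attention_env {
        // An explicit ATTENTION value always wins; an unknown value panics.
        Some("paged") => AttentionKind::Paged,
        Some("flashdecoding") => AttentionKind::FlashDecoding,
        Some("flashinfer") => AttentionKind::FlashInfer,
        Some(other) => panic!("Invalid attention was specified :`{other}`"),
        // Otherwise USE_PREFIX_CACHING decides the default.
        None if prefix_caching => AttentionKind::FlashInfer,
        None => AttentionKind::Paged,
    }
}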

View File

@@ -1,16 +1,26 @@
- use std::cmp::min;
+ use std::{cmp::min, sync::Arc};
use tokio::sync::{mpsc, oneshot};
+ use crate::radix::RadixAllocator;
#[derive(Debug, Clone)]
pub(crate) struct BlockAllocation {
+ pub allocation_id: u64,
pub blocks: Vec<u32>,
pub slots: Vec<u32>,
- block_allocator: BlockAllocator,
+ /// Prefix that was cached and for which the KV does not have to
+ /// be recomputed.
+ pub prefix_len: u32,
+ pub(crate) block_allocator: Option<BlockAllocator>,
}
impl Drop for BlockAllocation {
fn drop(&mut self) {
- self.block_allocator.free(self.blocks.clone())
+ if let Some(block_allocator) = self.block_allocator.as_mut() {
+ block_allocator.free(self.blocks.clone(), self.allocation_id)
+ }
}
}
@@ -24,6 +34,7 @@ impl BlockAllocator {
pub(crate) fn new(
max_batch_total_tokens: u32,
block_size: u32,
+ prefix_caching: bool,
window_size: Option<u32>,
) -> Self {
// Create channel
@@ -33,6 +44,7 @@ impl BlockAllocator {
tokio::spawn(block_allocator_task(
max_batch_total_tokens / block_size,
block_size,
+ prefix_caching,
window_size,
receiver,
));
@@ -42,28 +54,32 @@ impl BlockAllocator {
}
}
- pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
+ pub(crate) async fn allocate(
+ &self,
+ tokens: u32,
+ prefill_tokens: Option<Arc<Vec<u32>>>,
+ ) -> Option<BlockAllocation> {
let (response_sender, response_receiver) = oneshot::channel();
self.block_allocator
.send(BlockAllocatorCommand::Allocate {
tokens,
+ prefill_tokens,
response_sender,
})
.unwrap();
- response_receiver
- .await
- .unwrap()
- .map(|(blocks, slots)| BlockAllocation {
- blocks,
- slots,
- block_allocator: self.clone(),
+ response_receiver.await.unwrap().map(|mut allocation| {
+ allocation.block_allocator = Some(self.clone());
+ allocation
})
}
- pub(crate) fn free(&self, blocks: Vec<u32>) {
+ pub(crate) fn free(&self, blocks: Vec<u32>, allocation_id: u64) {
self.block_allocator
- .send(BlockAllocatorCommand::Free { blocks })
+ .send(BlockAllocatorCommand::Free {
+ allocation_id,
+ blocks,
+ })
.unwrap();
}
}
@@ -71,54 +87,29 @@ impl BlockAllocator {
async fn block_allocator_task(
blocks: u32,
block_size: u32,
+ prefix_caching: bool,
window_size: Option<u32>,
mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
) {
- // Block 0 is reserved for health checks
- let mut free_blocks: Vec<u32> = (1..blocks).collect();
+ let mut allocator: Box<dyn Allocator + Send> = if prefix_caching {
+ Box::new(RadixAllocator::new(block_size, blocks, window_size))
+ } else {
+ Box::new(SimpleAllocator::new(blocks, block_size, window_size))
+ };
while let Some(cmd) = receiver.recv().await {
match cmd {
- BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
+ BlockAllocatorCommand::Free {
+ blocks,
+ allocation_id,
+ } => allocator.free(blocks, allocation_id),
BlockAllocatorCommand::Allocate {
tokens,
+ prefill_tokens,
response_sender,
} => {
- // Apply window size
- let (required_blocks, repeats) = {
- let (tokens, repeats) = match window_size {
- None => (tokens, 1),
- Some(window_size) => {
- let repeats = (tokens + window_size - 1) / window_size;
- let tokens = min(tokens, window_size);
- (tokens, repeats as usize)
- }
- };
- // Pad to a multiple of block size
- let required_blocks = (tokens + block_size - 1) / block_size;
- (required_blocks, repeats)
- };
- let tokens = tokens as usize;
- let allocation = if required_blocks > free_blocks.len() as u32 {
- None
- } else {
- let blocks =
- free_blocks.split_off(free_blocks.len() - required_blocks as usize);
- let mut slots = Vec::with_capacity(
- (required_blocks * block_size * repeats as u32) as usize,
- );
- 'slots: for block_id in blocks.repeat(repeats).iter() {
- for s in (block_id * block_size)..((block_id + 1) * block_size) {
- slots.push(s);
- if slots.len() == tokens {
- break 'slots;
- }
- }
- }
- Some((blocks, slots))
- };
- response_sender.send(allocation).unwrap();
+ response_sender
+ .send(allocator.allocate(tokens, prefill_tokens))
+ .unwrap();
}
}
}
@@ -128,9 +119,92 @@ async fn block_allocator_task(
enum BlockAllocatorCommand {
Free {
blocks: Vec<u32>,
+ allocation_id: u64,
},
Allocate {
tokens: u32,
- response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
+ prefill_tokens: Option<Arc<Vec<u32>>>,
+ response_sender: oneshot::Sender<Option<BlockAllocation>>,
},
}
pub(crate) trait Allocator {
fn allocate(
&mut self,
tokens: u32,
prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation>;
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
}
pub struct SimpleAllocator {
free_blocks: Vec<u32>,
block_size: u32,
window_size: Option<u32>,
}
impl SimpleAllocator {
fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
SimpleAllocator {
block_size,
// Block 0 is reserved for health checks
free_blocks: (1..blocks).collect(),
window_size,
}
}
}
impl Allocator for SimpleAllocator {
fn allocate(
&mut self,
tokens: u32,
_prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation> {
// Apply window size
let (required_blocks, repeats) = {
let (tokens, repeats) = match self.window_size {
None => (tokens, 1),
Some(window_size) => {
let repeats = (tokens + window_size - 1) / window_size;
let tokens = min(tokens, window_size);
(tokens, repeats as usize)
}
};
// Pad to a multiple of block size
let required_blocks = (tokens + self.block_size - 1) / self.block_size;
(required_blocks, repeats)
};
let tokens = tokens as usize;
if required_blocks > self.free_blocks.len() as u32 {
None
} else {
let blocks = self
.free_blocks
.split_off(self.free_blocks.len() - required_blocks as usize);
let mut slots =
Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
'slots: for block_id in blocks.repeat(repeats).iter() {
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
slots.push(s);
if slots.len() == tokens {
break 'slots;
}
}
}
Some(BlockAllocation {
allocation_id: 0,
blocks,
slots,
prefix_len: 0,
block_allocator: None,
})
}
}
fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {
self.free_blocks.extend(blocks)
}
}
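
A usage sketch (hypothetical test module, not part of this commit) that exercises the Allocator trait through SimpleAllocator; the numbers simply follow the arithmetic of the implementation above.

#[cfg(test)]
mod simple_allocator_usage {
    use super::*;

    #[test]
    fn simple_allocator_round_trip() {
        // 8 blocks of 16 slots each; block 0 stays reserved for health checks,
        // leaving 7 free blocks.
        let mut allocator = SimpleAllocator::new(8, 16, None);
        let allocation = allocator.allocate(20, None).expect("enough free blocks");
        // 20 tokens need ceil(20 / 16) = 2 blocks and exactly 20 slots.
        assert_eq!(allocation.blocks.len(), 2);
        assert_eq!(allocation.slots.len(), 20);
        // The simple allocator never reuses cached prefixes.
        assert_eq!(allocation.prefix_len, 0);
        allocator.free(allocation.blocks.clone(), allocation.allocation_id);
    }
}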

View File

@@ -157,6 +157,7 @@ impl Client {
// Blocks and slots will be set on the server side if we use paged attention
blocks: vec![],
slots: vec![],
+ prefix_len: 0,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,

View File

@@ -245,6 +245,7 @@ impl Health for ShardedClient {
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
+ prefix_len: 0,
adapter_id: None,
};
let batch = Batch {

View File

@@ -2,6 +2,7 @@ mod backend;
mod block_allocator;
mod client;
mod queue;
+ mod radix;
use crate::client::{ClientError, ShardedClient};
pub(crate) use backend::BackendV3;

View File

@@ -46,6 +46,7 @@ impl Queue {
pub(crate) fn new(
requires_padding: bool,
block_size: u32,
+ prefix_caching: bool,
window_size: Option<u32>,
speculate: u32,
max_batch_total_tokens: u32,
@@ -57,6 +58,7 @@ impl Queue {
tokio::spawn(queue_task(
requires_padding,
block_size,
+ prefix_caching,
window_size,
speculate,
max_batch_total_tokens,
@@ -109,6 +111,7 @@ impl Queue {
async fn queue_task(
requires_padding: bool,
block_size: u32,
+ prefix_caching: bool,
window_size: Option<u32>,
speculate: u32,
max_batch_total_tokens: u32,
@@ -117,6 +120,7 @@ async fn queue_task(
let mut state = State::new(
requires_padding,
block_size,
+ prefix_caching,
window_size,
speculate,
max_batch_total_tokens,
@@ -176,12 +180,19 @@ impl State {
fn new(
requires_padding: bool,
block_size: u32,
+ prefix_caching: bool,
window_size: Option<u32>,
speculate: u32,
max_batch_total_tokens: u32,
) -> Self {
- let block_allocator = (!requires_padding)
- .then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size));
+ let block_allocator = (!requires_padding).then(|| {
+ BlockAllocator::new(
+ max_batch_total_tokens,
+ block_size,
+ prefix_caching,
+ window_size,
+ )
+ });
Self {
entries: VecDeque::with_capacity(128),
@@ -305,7 +316,10 @@ impl State {
+ self.speculate
- 1;
- match block_allocator.allocate(tokens).await {
+ match block_allocator
+ .allocate(tokens, entry.request.input_ids.clone())
+ .await
+ {
None => {
// Entry is over budget
// Add it back to the front
@@ -331,11 +345,12 @@ impl State {
// Update entry
entry.temp_span = Some(entry_batch_span);
- let (blocks, slots) = match &block_allocation {
- None => (Vec::new(), Vec::new()),
+ let (blocks, slots, prefix_len) = match &block_allocation {
+ None => (Vec::new(), Vec::new(), 0),
Some(block_allocation) => (
block_allocation.blocks.clone(),
block_allocation.slots.clone(),
+ block_allocation.prefix_len,
),
};
@@ -372,6 +387,7 @@ impl State {
top_n_tokens: entry.request.top_n_tokens,
blocks,
slots,
+ prefix_len,
adapter_id: entry.request.adapter_id.clone(),
});
// Set batch_time
@@ -480,6 +496,8 @@ impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
#[cfg(test)]
mod tests {
+ use std::sync::Arc;
use super::*;
use tracing::info_span;
@@ -492,6 +510,7 @@ mod tests {
let entry = Entry {
request: ValidGenerateRequest {
inputs: vec![],
+ input_ids: Some(Arc::new(vec![])),
input_length: 0,
truncate: 0,
decoder_input_details: false,
@@ -527,7 +546,7 @@ mod tests {
#[tokio::test]
async fn test_append() {
- let mut state = State::new(false, 1, None, 0, 16);
+ let mut state = State::new(false, 1, false, None, 0, 16);
let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0);
@@ -543,7 +562,7 @@ mod tests {
#[tokio::test]
async fn test_next_batch_empty() {
- let mut state = State::new(false, 1, None, 0, 16);
+ let mut state = State::new(false, 1, false, None, 0, 16);
assert!(state.next_batch(None, None, 1, 1).await.is_none());
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
@@ -551,7 +570,7 @@ mod tests {
#[tokio::test]
async fn test_next_batch_min_size() {
- let mut state = State::new(false, 1, None, 0, 16);
+ let mut state = State::new(false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
@@ -583,7 +602,7 @@ mod tests {
#[tokio::test]
async fn test_next_batch_max_size() {
- let mut state = State::new(false, 1, None, 0, 16);
+ let mut state = State::new(false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
@@ -603,7 +622,7 @@ mod tests {
#[tokio::test]
async fn test_next_batch_token_budget() {
- let mut state = State::new(false, 1, None, 0, 2);
+ let mut state = State::new(false, 1, false, None, 0, 2);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
@@ -636,14 +655,14 @@ mod tests {
#[tokio::test]
async fn test_queue_append() {
- let queue = Queue::new(false, 1, None, 0, 16);
+ let queue = Queue::new(false, 1, false, None, 0, 16);
let (entry, _guard) = default_entry();
queue.append(entry);
}
#[tokio::test]
async fn test_queue_next_batch_empty() {
- let queue = Queue::new(false, 1, None, 0, 16);
+ let queue = Queue::new(false, 1, false, None, 0, 16);
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
@@ -651,7 +670,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_min_size() {
- let queue = Queue::new(false, 1, None, 0, 16);
+ let queue = Queue::new(false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@@ -684,7 +703,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_max_size() {
- let queue = Queue::new(false, 1, None, 0, 16);
+ let queue = Queue::new(false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@@ -700,7 +719,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_token_budget() {
- let queue = Queue::new(false, 1, None, 0, 16);
+ let queue = Queue::new(false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@@ -725,7 +744,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_token_speculate() {
- let queue = Queue::new(false, 1, None, 2, 16);
+ let queue = Queue::new(false, 1, false, None, 2, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@@ -744,7 +763,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_dropped_receiver() {
- let queue = Queue::new(false, 1, None, 0, 16);
+ let queue = Queue::new(false, 1, false, None, 0, 16);
let (entry, _) = default_entry();
queue.append(entry);
backends/v3/src/radix.rs Normal file (+755)
View File

@@ -0,0 +1,755 @@
use std::{
collections::{BTreeSet, HashMap},
sync::Arc,
};
use slotmap::{DefaultKey, SlotMap};
use crate::block_allocator::{Allocator, BlockAllocation};
pub struct RadixAllocator {
allocation_id: u64,
allocations: HashMap<u64, RadixAllocation>,
cache_blocks: RadixTrie,
/// Blocks that are immediately available for allocation.
free_blocks: Vec<u32>,
}
impl RadixAllocator {
pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self {
assert_eq!(
block_size, 1,
"Radix tree allocator only works with block_size=1, was: {}",
block_size
);
if window_size.is_some() {
unimplemented!("Window size not supported in the prefix-caching block allocator yet");
}
RadixAllocator {
allocation_id: 0,
allocations: HashMap::new(),
cache_blocks: RadixTrie::new(),
// Block 0 is reserved for health checks.
free_blocks: (1..n_blocks).collect(),
}
}
fn alloc_or_reclaim(&mut self, n_blocks_needed: usize) -> Option<Vec<u32>> {
if self.free_blocks.len() < n_blocks_needed {
// This is a bit annoying, we first extend the free list and then
// split it off again below. This is because we need to put it on
// the free list if we cannot allocate enough blocks. This is only
// temporary, the trie needs to be able to report whether it can
// allocate the requested amount. Just not implemented yet.
self.free_blocks.extend(
self.cache_blocks
.evict(n_blocks_needed - self.free_blocks.len()),
);
}
if self.free_blocks.len() >= n_blocks_needed {
Some(
self.free_blocks
.split_off(self.free_blocks.len() - n_blocks_needed),
)
} else {
None
}
}
}
impl Allocator for RadixAllocator {
fn allocate(
&mut self,
tokens: u32,
prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation> {
let mut blocks = vec![];
let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {
let node_id = self
.cache_blocks
.find(prefill_tokens.as_slice(), &mut blocks);
// Even if this allocation fails below, we need to increase the
// refcount to ensure that the prefix that was found is not evicted.
node_id
} else {
self.cache_blocks.root_id()
};
self.cache_blocks
.incref(prefix_node)
.expect("Failed to increment refcount");
let prefix_len = blocks.len();
let suffix_len = tokens - prefix_len as u32;
match self.alloc_or_reclaim(suffix_len as usize) {
Some(suffix_blocks) => blocks.extend(suffix_blocks),
None => {
self.cache_blocks
.decref(prefix_node)
.expect("Failed to decrement refcount");
return None;
}
}
// 1:1 mapping of blocks and slots.
let slots = blocks.clone();
let allocation = RadixAllocation {
prefix_node,
cached_prefix_len: prefix_len,
prefill_tokens: prefill_tokens.clone(),
};
self.allocation_id += 1;
self.allocations.insert(self.allocation_id, allocation);
Some(BlockAllocation {
allocation_id: self.allocation_id,
block_allocator: None,
blocks,
slots,
prefix_len: prefix_len as u32,
})
}
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
let allocation = match self.allocations.remove(&allocation_id) {
Some(allocation) => allocation,
None => unreachable!("Tried to free an unknown allocation."),
};
self.cache_blocks
.decref(allocation.prefix_node)
.expect("Failed to decrement refcount");
if let Some(prefill_tokens) = allocation.prefill_tokens {
let prefill_tokens = prefill_tokens.as_slice();
// If there are prefill tokens that did not come from the cache,
// add them to the cache.
if prefill_tokens.len() > allocation.cached_prefix_len {
let prefix_len = self
.cache_blocks
.insert(prefill_tokens, &blocks[..prefill_tokens.len()])
// Unwrap, failing is a programming error.
.expect("Failed to store prefill tokens");
// We can have a prefill with the following structure:
//
// |---| From the prefix cache.
// A B C D E F G
//|--------| Found in the trie during insertion.
//
// This means that while processing this request there was a
// partially overlapping request that had A..=E in its
// prefill. In this case we need to free the blocks D E.
self.free_blocks
.extend(&blocks[allocation.cached_prefix_len..prefix_len]);
}
// Free non-prefill blocks.
self.free_blocks.extend(&blocks[prefill_tokens.len()..]);
} else {
self.free_blocks.extend(blocks);
}
}
}
struct RadixAllocation {
prefix_node: NodeId,
cached_prefix_len: usize,
prefill_tokens: Option<Arc<Vec<u32>>>,
}
// Radix trie that is heavily inspired by radix attention from sglang.
//
// The trie is optimized for prefix caching:
//
// - A normal radix trie stores discrete values. In this radix trie,
// inserting *abc* with value *xyz* will also enable lookup for
// *a* (*x*) and *ab* (*xy*).
// - As a result, every value is required to have the same length as
// the key.
// - We store additional information in each node, such as last access
// time and a reference count.
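// - For example, after inserting the key [1, 2, 3] with blocks [10, 11, 12],
//   a lookup of [1, 2] returns the blocks [10, 11].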
#[derive(Debug)]
pub enum TrieError {
InvalidNodeId,
RefCountUnderflow,
BlockTokenCountMismatch,
}
pub type NodeId = DefaultKey;
#[derive(Debug)]
pub struct RadixTrie {
/// Identifier of the root node.
root: DefaultKey,
/// Leaf node identifiers ordered by increasing recency.
leaves: BTreeSet<(u64, NodeId)>,
/// All trie nodes.
nodes: SlotMap<NodeId, TrieNode>,
/// Time as a monotonically increasing counter to avoid the system
/// call that a real time lookup would require.
time: u64,
}
impl RadixTrie {
/// Construct a new radix trie.
pub fn new() -> Self {
let root = TrieNode::new(vec![], vec![], 0, None);
let mut nodes = SlotMap::new();
let root = nodes.insert(root);
RadixTrie {
leaves: BTreeSet::new(),
nodes,
root,
time: 0,
}
}
/// Find the prefix of the given tokens.
///
/// The blocks corresponding to the part of the prefix that could be found
/// are written to `blocks`. The number of blocks is in `0..=tokens.len()`.
/// Returns the identifier of the trie node that contains the longest
/// prefix. The node identifier can be used by callers to e.g. increase its
/// reference count.
///
/// Using this method will update the access time of the traversed nodes.
pub fn find(&mut self, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
self.time += 1;
self.find_(self.root, key, blocks)
}
/// Find worker.
fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
let node = &self.nodes[node_id];
if let Some(&child_id) = node.children.get(&key[0]) {
self.update_access_time(child_id);
let child = self.nodes.get(child_id).expect("Invalid child identifier");
let shared_prefix_len = child.key.shared_prefix_len(key);
blocks.extend(&child.blocks[..shared_prefix_len]);
let key = &key[shared_prefix_len..];
if !key.is_empty() {
node_id = self.find_(child_id, key, blocks);
}
}
node_id
}
/// Decrease the reference count of a node.
pub fn decref(&mut self, node_id: NodeId) -> Result<(), TrieError> {
// We don't care about refcounting for root, since it will never
// be evicted.
if node_id == self.root {
return Ok(());
}
let node = self
.nodes
.get_mut(node_id)
.ok_or(TrieError::InvalidNodeId)?;
if node.ref_count == 0 {
return Err(TrieError::RefCountUnderflow);
}
node.ref_count -= 1;
if node.ref_count == 0 {
self.leaves.insert((node.last_accessed, node_id));
}
Ok(())
}
/// Increase the reference count of a node.
pub fn incref(&mut self, node_id: NodeId) -> Result<(), TrieError> {
if node_id == self.root {
return Ok(());
}
let node = self
.nodes
.get_mut(node_id)
.ok_or(TrieError::InvalidNodeId)?;
if node.ref_count == 0 {
self.leaves.remove(&(node.last_accessed, node_id));
}
node.ref_count += 1;
Ok(())
}
/// Evict `n_blocks` from the trie.
///
/// Returns the evicted blocks. When the length is less than `n_blocks`,
/// not enough blocks could be evicted.
pub fn evict(&mut self, n_blocks: usize) -> Vec<u32> {
// NOTE: we don't return Result here. If any of the unwrapping fails,
// it's a programming error in the trie implementation, not a user
// error caused by e.g. an invalid argument.
// TODO: add some bookkeeping in the future to check whether we can
// evict n_blocks and return `None` if we can't. We are now needlessly
// evicting prefixes from the cache in such a case.
let mut evicted = Vec::new();
while let Some((last_access, node_id)) = self.leaves.pop_first() {
let blocks_needed = n_blocks - evicted.len();
let node = self.nodes.get(node_id).expect("Leaf does not exist");
if blocks_needed >= node.blocks.len() {
// We need to evict the whole node if we need more blocks than it has.
let node = self.remove_node(node_id);
evicted.extend(node.blocks);
if evicted.len() >= n_blocks {
break;
}
} else {
// The node has more blocks than needed, so we'll just remove
// the required number of blocks and leave the remaining blocks
// untouched.
let node = self.nodes.get_mut(node_id).expect("Leaf does not exist");
node.key.truncate(node.blocks.len() - blocks_needed);
evicted.extend(node.blocks.split_off(node.blocks.len() - blocks_needed));
self.leaves.insert((last_access, node_id));
break;
}
}
evicted
}
/// Insert a prefill along with its blocks.
///
/// This method returns the length of the prefix that was already
/// in the trie. E.g. if the length is 10, this means that for
/// the first 10 elements of the tree **the blocks are not updated**.
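///
/// For example, inserting the same three-token prefill twice returns 0 on
/// the first call and 3 on the second (see `trie_insertions_have_correct_prefix_len`).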
pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result<usize, TrieError> {
self.time += 1;
self.insert_(self.root, tokens, blocks)
}
/// Insertion worker.
fn insert_(
&mut self,
node_id: NodeId,
tokens: &[u32],
blocks: &[u32],
) -> Result<usize, TrieError> {
// TODO: in the future we may want to check that the blocks match for
// the part of the prefix that is already in the trie to detect
// mismatches.
if tokens.len() != blocks.len() {
return Err(TrieError::BlockTokenCountMismatch);
}
if let Some(&child_id) = self.nodes[node_id].children.get(&tokens[0]) {
self.update_access_time(child_id);
let child = self
.nodes
.get_mut(child_id)
// Unwrap here, since failure is a bug.
.expect("Child node does not exist");
let shared_prefix_len = child.key.shared_prefix_len(tokens);
// We are done, the prefix is already in the trie.
if shared_prefix_len == tokens.len() {
return Ok(shared_prefix_len);
}
// The node's prefix is a prefix of the insertion prefix.
if shared_prefix_len == child.key.len() {
return Ok(shared_prefix_len
+ self.insert_(
child_id,
&tokens[shared_prefix_len..],
&blocks[shared_prefix_len..],
)?);
}
// The node's prefix and the insertion prefix only match partially,
// split the node to just contain the matching part. Then insert the
// remainder of the prefix into the node again
let child_id = self.split_node(child_id, shared_prefix_len);
let key = &tokens[shared_prefix_len..];
let blocks = &blocks[shared_prefix_len..];
Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?)
} else {
self.add_node(node_id, tokens, blocks);
Ok(0)
}
}
fn split_node(&mut self, node_id: NodeId, prefix_len: usize) -> NodeId {
// We have to make the current node a child to ensure that its
// properties and node id stay the same.
// This function unwraps; an invalid node_id is a programming error.
let node = self
.nodes
.get_mut(node_id)
.expect("Node to-be split does not exist");
let mut parent_key = node.key.split_off(prefix_len);
let mut parent_blocks = node.blocks.split_off(prefix_len);
// Move first part of the prefix to the parent. We swap to avoid
// an allocation + copy for both splits of the key/blocks.
std::mem::swap(&mut node.key, &mut parent_key);
std::mem::swap(&mut node.blocks, &mut parent_blocks);
let node_key = node.key[0];
let grandparent_id = node.parent.expect("Node does not have a parent");
let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks);
self.add_node_to_parent(parent_id, node_key, node_id);
// Reborrow to make the borrow checker happy.
let node = self
.nodes
.get_mut(node_id)
.expect("Node to-be split does not exist");
node.parent = Some(parent_id);
parent_id
}
/// Create a node and add it to the parent.
fn add_node(
&mut self,
parent_id: NodeId,
key: impl Into<Vec<u32>>,
blocks: impl Into<Vec<u32>>,
) -> NodeId {
let key = key.into();
let blocks = blocks.into();
let first = key[0];
let child = TrieNode::new(key, blocks, self.time, Some(parent_id));
let child_id = self.nodes.insert(child);
self.add_node_to_parent(parent_id, first, child_id);
self.leaves.insert((self.time, child_id));
child_id
}
/// Add a node to the parent.
fn add_node_to_parent(&mut self, parent_id: NodeId, first: u32, child_id: NodeId) {
// Unwrap here, passing in an unknown id is a programming error.
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
if parent.children.insert(first, child_id).is_none() {
// Only increase reference count if child does not replace another child.
self.incref(parent_id)
.expect("Failed to increase parent refcount");
}
}
/// Remove a node from the trie.
fn remove_node(&mut self, node_id: NodeId) -> TrieNode {
// Unwrap here, passing in an unknown id is a programming error.
let node = self.nodes.remove(node_id).expect("Unknown node");
let parent_id = node.parent.expect("Attempted to remove root node");
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
parent.children.remove(&node.key[0]);
self.decref(parent_id)
.expect("Failed to decrease parent refcount");
self.nodes.remove(node_id);
node
}
fn update_access_time(&mut self, node_id: NodeId) {
// Unwrap here, passing in an unknown id is a programming error.
let node = self.nodes.get_mut(node_id).expect("Unknown node");
// Update the ordered leaves set if the node is a leaf.
if self.leaves.remove(&(node.last_accessed, node_id)) {
self.leaves.insert((self.time, node_id));
}
node.last_accessed = self.time;
}
#[allow(dead_code)]
#[doc(hidden)]
/// Print debugging output for the trie.
///
/// In contrast to `Debug` nicely formatted.
pub fn print_debug(&self) {
self.print_debug_(self.root, 0);
}
fn print_debug_(&self, node_id: NodeId, indent: usize) {
let node = &self.nodes[node_id];
eprintln!(
"{}{:?}, key: {:?}, blocks: {:?}, ref_count: {}, last_accessed: {}, parent: {:?}, children: {:?}",
" ".repeat(indent),
node_id,
node.key,
node.blocks,
node.ref_count,
node.last_accessed,
node.parent,
node.children
);
for child_id in self.nodes[node_id].children.values() {
self.print_debug_(*child_id, indent + 2);
}
}
pub(crate) fn root_id(&self) -> DefaultKey {
self.root
}
}
/// Trie node.
#[derive(Debug)]
struct TrieNode {
blocks: Vec<u32>,
children: HashMap<u32, NodeId>,
key: Vec<u32>,
last_accessed: u64,
parent: Option<NodeId>,
ref_count: usize,
}
impl TrieNode {
fn new(key: Vec<u32>, blocks: Vec<u32>, last_accessed: u64, parent: Option<NodeId>) -> Self {
TrieNode {
children: HashMap::new(),
key,
blocks,
last_accessed,
parent,
ref_count: 0,
}
}
}
/// Helper trait to get the length of the shared prefix of two sequences.
trait SharedPrefixLen {
fn shared_prefix_len(&self, other: &Self) -> usize;
}
impl<T> SharedPrefixLen for [T]
where
T: PartialEq,
{
fn shared_prefix_len(&self, other: &Self) -> usize {
self.iter().zip(other).take_while(|(a, b)| a == b).count()
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use crate::block_allocator::Allocator;
use super::RadixAllocator;
#[test]
fn allocator_reuses_prefixes() {
let mut cache = RadixAllocator::new(1, 12, None);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
assert_eq!(allocation.slots, allocation.blocks);
assert_eq!(allocation.prefix_len, 0);
cache.free(allocation.blocks.clone(), allocation.allocation_id);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
assert_eq!(allocation.prefix_len, 4);
}
#[test]
fn allocator_collects_older_prefixes_first() {
let mut cache = RadixAllocator::new(1, 7, None);
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]);
assert_eq!(allocation1.prefix_len, 0);
let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap();
assert_eq!(allocation2.blocks, vec![1, 2]);
assert_eq!(allocation2.prefix_len, 0);
cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
// We should get the blocks of the first allocation, since its prefix was the
// least recently used and is therefore evicted first.
let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap();
assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]);
assert_eq!(allocation3.prefix_len, 0);
}
#[test]
fn allocator_frees_fully_overlapping_prefills() {
let mut cache = RadixAllocator::new(1, 10, None);
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation3.prefix_len, 4);
// 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks.
assert_eq!(cache.free_blocks.len(), 5);
}
#[test]
fn allocator_frees_partially_overlapping_prefills() {
let mut cache = RadixAllocator::new(1, 20, None);
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap();
assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]);
assert_eq!(allocation1.prefix_len, 0);
cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
let allocation2 = cache
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
.unwrap();
assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]);
assert_eq!(allocation2.prefix_len, 2);
let allocation3 = cache
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
.unwrap();
assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]);
assert_eq!(allocation3.prefix_len, 2);
cache.free(allocation3.blocks.clone(), allocation3.allocation_id);
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
// 20 blocks, of which 1 reserved for health checks, 6 for allocation3, 2 for allocation2.
assert_eq!(cache.free_blocks.len(), 11);
let allocation4 = cache
.allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
.unwrap();
assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]);
assert_eq!(allocation4.prefix_len, 6);
assert_eq!(cache.free_blocks.len(), 11);
let allocation5 = cache
.allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
.unwrap();
assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]);
assert_eq!(allocation5.prefix_len, 6);
assert_eq!(cache.free_blocks.len(), 11);
}
#[test]
fn trie_insertions_have_correct_prefix_len() {
let mut trie = super::RadixTrie::new();
assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0);
// Already exists.
assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 3);
// Completely new at root-level
assert_eq!(trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(), 0);
// Contains full prefix, but longer.
assert_eq!(trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(), 3);
// Shares partial prefix, we need a split.
assert_eq!(
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
.unwrap(),
4
);
}
#[test]
fn trie_get_returns_correct_blocks() {
let mut trie = super::RadixTrie::new();
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
.unwrap();
let mut blocks = Vec::new();
trie.find(&[0], &mut blocks);
assert_eq!(blocks, vec![0]);
blocks.clear();
trie.find(&[0, 1, 2], &mut blocks);
assert_eq!(blocks, vec![0, 1, 2]);
blocks.clear();
trie.find(&[1, 2, 3], &mut blocks);
assert_eq!(blocks, vec![1, 2, 3]);
blocks.clear();
trie.find(&[0, 1, 2, 3], &mut blocks);
assert_eq!(blocks, vec![0, 1, 2, 3]);
blocks.clear();
trie.find(&[0, 1, 2, 3, 4], &mut blocks);
assert_eq!(blocks, vec![0, 1, 2, 3, 4]);
blocks.clear();
trie.find(&[0, 1, 2, 3, 5], &mut blocks);
assert_eq!(blocks, vec![0, 1, 2, 3, 5]);
}
#[test]
fn trie_evict_removes_correct_blocks() {
let mut trie = super::RadixTrie::new();
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
.unwrap();
trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
let mut blocks = Vec::new();
// Remove fewer blocks than the leaf holds.
assert_eq!(trie.evict(1), vec![7]);
trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
assert_eq!(blocks, vec![0, 1, 2, 3, 5, 6]);
// Refresh other leaf.
trie.find(&[0, 1, 2, 3, 4], &mut blocks);
trie.find(&[1, 2, 3], &mut blocks);
// Remove exactly the leaf's blocks.
assert_eq!(trie.evict(2), vec![5, 6]);
blocks.clear();
trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
assert_eq!(blocks, vec![0, 1, 2, 3]);
trie.find(&[1, 2, 3], &mut blocks);
// Remove more blocks than the leaf holds.
assert_eq!(trie.evict(3), vec![4, 3, 2]);
blocks.clear();
trie.find(&[0, 1, 2, 3, 4], &mut blocks);
assert_eq!(blocks, vec![0, 1]);
// Clear out the whole trie.
assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]);
}
}

View File

@@ -157,6 +157,7 @@ async fn prefill(
top_n_tokens: top_n_tokens.unwrap_or(0),
blocks: vec![],
slots: vec![],
+ prefix_len: 0,
adapter_id: None,
})
.collect();

View File

@@ -4,21 +4,22 @@ package generate.v3;
service TextGenerationService {
/// Model Info
- rpc Info (InfoRequest) returns (InfoResponse) {}
+ rpc Info(InfoRequest) returns (InfoResponse) {}
/// Service discovery
- rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
+ rpc ServiceDiscovery(ServiceDiscoveryRequest)
+ returns (ServiceDiscoveryResponse) {}
/// Empties batch cache
- rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
+ rpc ClearCache(ClearCacheRequest) returns (ClearCacheResponse);
/// Remove requests from a cached batch
- rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
+ rpc FilterBatch(FilterBatchRequest) returns (FilterBatchResponse);
/// Warmup the model and compute max cache size
- rpc Warmup (WarmupRequest) returns (WarmupResponse);
+ rpc Warmup(WarmupRequest) returns (WarmupResponse);
/// Prefill batch and decode first token
- rpc Prefill (PrefillRequest) returns (PrefillResponse);
+ rpc Prefill(PrefillRequest) returns (PrefillResponse);
/// Decode token for a list of prefilled batches
- rpc Decode (DecodeRequest) returns (DecodeResponse);
+ rpc Decode(DecodeRequest) returns (DecodeResponse);
/// Health check
- rpc Health (HealthRequest) returns (HealthResponse);
+ rpc Health(HealthRequest) returns (HealthResponse);
}
message HealthRequest {}
@@ -68,9 +69,7 @@ message InputChunk {
}
}
- message Input {
- repeated InputChunk chunks = 1;
- }
+ message Input { repeated InputChunk chunks = 1; }
enum GrammarType {
GRAMMAR_TYPE_NONE = 0;
@@ -136,6 +135,8 @@ message Request {
repeated uint32 slots = 10;
/// LORA adapter index
optional string adapter_id = 11;
+ /// Prefix length that can be retrieved from the KV cache.
+ uint32 prefix_len = 12;
}
message Batch {
@@ -214,7 +215,6 @@ message FilterBatchResponse {
CachedBatch batch = 1;
}
-
message PrefillRequest {
/// Batch
Batch batch = 1;

View File

@@ -11,6 +11,7 @@ use rand::{thread_rng, Rng};
use serde_json::Value;
use std::io::Cursor;
use std::iter;
+ use std::sync::Arc;
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::mpsc;
@@ -115,13 +116,14 @@ impl Validation {
}
}
+ #[allow(clippy::type_complexity)]
#[instrument(skip(self, inputs))]
async fn validate_input(
&self,
inputs: String,
truncate: Option<usize>,
max_new_tokens: Option<u32>,
- ) -> Result<(Vec<Chunk>, usize, u32), ValidationError> {
+ ) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
// If we have a fast tokenizer
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
// Create response channel
@@ -156,8 +158,10 @@ impl Validation {
));
}
+ let input_ids = encoding.get_ids()[..input_length].to_owned();
metrics::histogram!("tgi_request_input_length").record(input_length as f64);
- Ok((inputs, input_length, max_new_tokens))
+ Ok((inputs, Some(input_ids), input_length, max_new_tokens))
}
// Return inputs without validation
else {
@@ -180,7 +184,12 @@ impl Validation {
input_length = input_length.saturating_sub(max_new_tokens as usize);
}
- Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens))
+ Ok((
+ vec![Chunk::Text(inputs)],
+ None,
+ input_length,
+ max_new_tokens,
+ ))
}
}
@@ -314,7 +323,7 @@ impl Validation {
.unwrap_or(Ok(None))?;
// Validate inputs
- let (inputs, input_length, max_new_tokens) = self
+ let (inputs, input_ids, input_length, max_new_tokens) = self
.validate_input(request.inputs, truncate, max_new_tokens)
.await?;
@@ -391,6 +400,7 @@ impl Validation {
Ok(ValidGenerateRequest {
inputs,
+ input_ids: input_ids.map(Arc::new),
decoder_input_details,
input_length: input_length as u32,
truncate: truncate.unwrap_or(self.max_input_length) as u32,
@@ -707,6 +717,7 @@ pub struct ValidStoppingParameters {
#[derive(Debug, Clone)]
pub struct ValidGenerateRequest {
pub inputs: Vec<Chunk>,
+ pub input_ids: Option<Arc<Vec<u32>>>,
pub input_length: u32,
pub truncate: u32,
pub decoder_input_details: bool,

View File

@@ -6,7 +6,12 @@ from .common import Seqlen
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":
- from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
+ from .cuda import (
+ attention,
+ paged_attention,
+ reshape_and_cache,
+ SUPPORTS_WINDOWING,
+ )
elif SYSTEM == "rocm":
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "ipex":

View File

@@ -221,9 +221,11 @@ SUPPORTS_WINDOWING = V2
if ATTENTION == "flashinfer":
def attention(
- q,
- k,
- v,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ key_cache: torch.Tensor,
+ value_cache: torch.Tensor,
cu_seqlens,
max_s,
softmax_scale,
@@ -231,14 +233,15 @@ if ATTENTION == "flashinfer":
causal=True,
softcap=0.0,
):
- from text_generation_server.layers.attention.flash_infer import prefill_state
+ assert window_size_left == -1, "Windowing is not supported with flash infer"
+ from text_generation_server.layers.attention.flash_infer import (
+ prefill_with_paged_kv_state,
+ )
- return prefill_state.get().forward(
- q,
- k,
- v,
+ return prefill_with_paged_kv_state.get().forward(
+ q.contiguous(),
causal=causal,
- window_left=window_size_left,
+ paged_kv_cache=(key_cache, value_cache),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
)
@@ -249,6 +252,8 @@ elif V2:
q,
k,
v,
+ key_cache: torch.Tensor,
+ value_cache: torch.Tensor,
cu_seqlens,
max_s,
softmax_scale,
@@ -289,6 +294,8 @@ else:
q,
k,
v,
+ key_cache: torch.Tensor,
+ value_cache: torch.Tensor,
cu_seqlens,
max_s,
softmax_scale,

View File

@@ -9,6 +9,10 @@ prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = Con
"prefill_state"
)
+ prefill_with_paged_kv_state: ContextVar[
+ flashinfer.BatchPrefillWithPagedKVCacheWrapper
+ ] = ContextVar("prefill_with_paged_kv_state")
decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
"decode_state"
)
@@ -24,6 +28,78 @@ def get_workspace(device):
return workspace
def create_prefill_with_paged_kv_state(
*,
device: torch.device,
):
"""Create a prefill state that uses the KV cache."""
workspace_buffer = get_workspace(device)
return flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout="NHD", use_cuda_graph=False
)
@contextmanager
def use_prefill_with_paged_kv_state(
*,
state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
block_tables: torch.Tensor,
cu_seqlens: torch.Tensor,
input_lengths: torch.Tensor,
num_heads: int,
num_kv_heads: int,
head_size: int,
page_size: int,
query_dtype: str = "float16",
):
"""
Context manager to set the active flashinfer prefill state to the given
`state` and parameters. This state will be used by all calls to the
`attention` function while the context manager is active.
"""
indptr = torch.zeros(
input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
)
# Round up to page size and then calculate the cumulative sum to get
# the indices into the block table.
torch.add(input_lengths, page_size - 1, out=indptr[1:])
indptr[1:].div_(page_size, rounding_mode="floor")
indptr[1:].cumsum_(-1)
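# For example, input_lengths = [5, 8] with page_size = 4 needs ceil(5/4) = 2
# and ceil(8/4) = 2 pages, so indptr becomes [0, 2, 4].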
# Get the lengths of the last page in a block.
if page_size == 1:
last_page_len = torch.ones(
input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
)
else:
last_page_len = torch.empty(
input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
)
torch.sub(input_lengths, 1, out=last_page_len)
last_page_len.remainder_(page_size)
last_page_len += 1
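# For example, with input_lengths = [5, 8] and page_size = 4, the last pages
# hold (5 - 1) % 4 + 1 = 1 and (8 - 1) % 4 + 1 = 4 tokens respectively.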
token = prefill_with_paged_kv_state.set(state)
try:
state.begin_forward(
qo_indptr=cu_seqlens,
paged_kv_indptr=indptr,
paged_kv_indices=block_tables,
paged_kv_last_page_len=last_page_len,
num_qo_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
q_data_type=query_dtype,
page_size=page_size,
)
yield
finally:
state.end_forward()
if token is not None:
prefill_with_paged_kv_state.reset(token)
def create_prefill_state(
*,
device: torch.device,

View File

@@ -298,6 +298,8 @@ class FlashCohereAttention(torch.nn.Module):
query,
key,
value,
+ kv_cache[0],
+ kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,

View File

@@ -337,6 +337,8 @@ class DbrxAttention(torch.nn.Module):
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
+ kv_cache[0],
+ kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,

View File

@@ -365,6 +365,8 @@ class DeepseekV2Attention(torch.nn.Module):
query,
key,
value,
+ kv_cache[0],
+ kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,

View File

@@ -238,6 +238,8 @@ class FlashGemma2Attention(torch.nn.Module):
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
+ kv_cache[0],
+ kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,

View File

@@ -232,6 +232,8 @@ class FlashGemmaAttention(torch.nn.Module):
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
+ kv_cache[0],
+ kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,

View File

@@ -232,6 +232,8 @@ class FlashGPT2Attention(torch.nn.Module):
query,
key,
value,
+ kv_cache[0],
+ kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,

View File

@@ -21,6 +21,7 @@
from contextlib import contextmanager
from typing import List, Optional, Tuple
+ from loguru import logger
import torch
import torch.distributed
@@ -220,6 +221,8 @@ class FlashLlamaAttention(torch.nn.Module):
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
+ kv_cache[0],
+ kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,

View File

@@ -219,6 +219,8 @@ class MistralAttention(torch.nn.Module):
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
+ kv_cache[0],
+ kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,

View File

@@ -276,6 +276,8 @@ class MixtralAttention(torch.nn.Module):
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
+ kv_cache[0],
+ kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,

View File

@@ -173,6 +173,8 @@ class FlashNeoxAttention(torch.nn.Module):
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
+ kv_cache[0],
+ kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,

View File

@@ -194,6 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
+ kv_cache[0],
+ kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,

View File

@@ -137,6 +137,8 @@ class Qwen2Attention(torch.nn.Module):
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
+ kv_cache[0],
+ kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,

View File

@@ -208,6 +208,8 @@ class FlashRWAttention(torch.nn.Module):
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
+ kv_cache[0],
+ kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
@@ -326,6 +328,8 @@ class FlashRWLargeAttention(torch.nn.Module):
query,
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
+ kv_cache[0],
+ kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,

View File

@@ -293,6 +293,8 @@ class FlashMQAttention(torch.nn.Module):
query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
+ kv_cache[0],
+ kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,

View File

@@ -242,6 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
+ kv_cache[0],
+ kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,

View File

@ -20,6 +20,9 @@ from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, D
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from text_generation_server.layers.attention.flash_infer import (
create_prefill_with_paged_kv_state,
)
from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models import Model from text_generation_server.models import Model
@ -43,6 +46,7 @@ from text_generation_server.models.globals import (
ATTENTION, ATTENTION,
BLOCK_SIZE, BLOCK_SIZE,
CUDA_GRAPHS, CUDA_GRAPHS,
PREFIX_CACHING,
get_adapter_to_index, get_adapter_to_index,
) )
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
@ -108,6 +112,9 @@ class FlashCausalLMBatch(Batch):
block_tables_tensor: torch.Tensor block_tables_tensor: torch.Tensor
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
slots: torch.Tensor slots: torch.Tensor
# size [b], number of prefix tokens whose KV can be retrieved from the cache
prefix_lens: List[int]
prefix_lens_tensor: torch.Tensor
max_seqlen: int max_seqlen: int
@ -116,6 +123,9 @@ class FlashCausalLMBatch(Batch):
prefill_next_token_indices: Optional[torch.tensor] prefill_next_token_indices: Optional[torch.tensor]
prefill_cu_outlens: Optional[List[int]] prefill_cu_outlens: Optional[List[int]]
# Prefixes
prefix_ids: List[List[int]]
# All tokens # All tokens
all_input_ids: List[List[int]] all_input_ids: List[List[int]]
all_input_ids_tensor: torch.Tensor all_input_ids_tensor: torch.Tensor
@ -183,6 +193,7 @@ class FlashCausalLMBatch(Batch):
prefix_offsets = [] prefix_offsets = []
read_offsets = [] read_offsets = []
all_input_ids = [] all_input_ids = []
prefix_ids = []
requests_idx_mapping = {} requests_idx_mapping = {}
all_prefill_logprobs = True all_prefill_logprobs = True
@ -200,7 +211,7 @@ class FlashCausalLMBatch(Batch):
# Cumulative length # Cumulative length
cumulative_length = 0 cumulative_length = 0
cumulative_max_length = 0 cumulative_slot_tokens = 0
prefill_out_cumulative_length = 0 prefill_out_cumulative_length = 0
num_blocks = 0 num_blocks = 0
@ -210,6 +221,7 @@ class FlashCausalLMBatch(Batch):
block_tables = [] block_tables = []
slots = [] slots = []
prefix_lens = []
# Parse batch # Parse batch
for i, (r, tokenized_input) in enumerate( for i, (r, tokenized_input) in enumerate(
@ -225,6 +237,19 @@ class FlashCausalLMBatch(Batch):
): ):
tokenized_input = tokenized_input[1:] tokenized_input = tokenized_input[1:]
orig_input_length = len(tokenized_input)
if PREFIX_CACHING:
prefix_len = r.prefix_len
if prefix_len == orig_input_length:
assert prefix_len > 0
prefix_len -= 1
else:
prefix_len = 0
prefix_ids.append(tokenized_input[:prefix_len])
tokenized_input = tokenized_input[prefix_len:]
input_length = len(tokenized_input) input_length = len(tokenized_input)
input_lengths.append(input_length) input_lengths.append(input_length)
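A small standalone sketch of the trimming rule introduced above (assuming prefix caching is enabled): the cached part of the prompt is split off as prefix_ids and only the suffix goes through prefill; if the entire prompt is cached, the last token is deliberately left uncached so prefill still produces logits for the next token.

```python
# Sketch of the prefix trimming rule; mirrors the hunk above but stands alone.
from typing import List, Tuple

def trim_cached_prefix(
    tokenized_input: List[int], prefix_len: int, prefix_caching: bool = True
) -> Tuple[List[int], List[int]]:
    if prefix_caching:
        if prefix_len == len(tokenized_input):
            # Keep at least one uncached token so prefill has something to do.
            assert prefix_len > 0
            prefix_len -= 1
    else:
        prefix_len = 0
    return tokenized_input[:prefix_len], tokenized_input[prefix_len:]

# 4 of 6 prompt tokens already cached: only the last two are prefilled.
print(trim_cached_prefix([10, 20, 30, 40, 50, 60], prefix_len=4))
# ([10, 20, 30, 40], [50, 60])
```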
@ -234,7 +259,9 @@ class FlashCausalLMBatch(Batch):
all_input_ids.append(tokenized_input) all_input_ids.append(tokenized_input)
# Position ids # Position ids
request_position_ids = torch.arange(0, input_length, dtype=torch.int32) request_position_ids = torch.arange(
prefix_len, orig_input_length, dtype=torch.int32
)
position_ids.append(request_position_ids) position_ids.append(request_position_ids)
# Add cumulative lengths of all previous inputs # Add cumulative lengths of all previous inputs
@ -258,11 +285,17 @@ class FlashCausalLMBatch(Batch):
# Remove one as the first token does not have a past # Remove one as the first token does not have a past
speculative_length = get_speculate() speculative_length = get_speculate()
speculative_length = 0 if speculative_length is None else speculative_length speculative_length = 0 if speculative_length is None else speculative_length
total_tokens = input_length + max_new_tokens - 1 + speculative_length
# Tokens that need to be mapped to blocks.
block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length
# Tokens that need to be mapped to slots. We don't need slots for the
# cached prefix (if present).
slot_tokens = input_length + max_new_tokens - 1 + speculative_length
# blocks and slots can be empty (for example in warmup) # blocks and slots can be empty (for example in warmup)
if not r.blocks: if not r.blocks:
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)
request_blocks = [ request_blocks = [
b for b in range(num_blocks, num_blocks + needed_blocks) b for b in range(num_blocks, num_blocks + needed_blocks)
] ]
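The distinction above between block_tokens and slot_tokens can be summarized with a small, self-contained helper (illustrative only): blocks must cover the whole sequence, cached prefix included, while slots are only needed for tokens whose KV still has to be written during this request.

```python
import math

def plan_allocation(
    orig_input_length: int,
    prefix_len: int,
    max_new_tokens: int,
    speculative_length: int,
    block_size: int,
):
    input_length = orig_input_length - prefix_len  # tokens left to prefill
    block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length
    slot_tokens = input_length + max_new_tokens - 1 + speculative_length
    needed_blocks = math.ceil(block_tokens / block_size)
    return needed_blocks, slot_tokens

# 100-token prompt with 64 tokens cached, 20 new tokens, BLOCK_SIZE == 16:
print(plan_allocation(100, 64, 20, 0, 16))  # (8, 55)
```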
@ -273,16 +306,20 @@ class FlashCausalLMBatch(Batch):
] ]
else: else:
request_blocks = r.blocks request_blocks = r.blocks
request_slots = r.slots request_slots = r.slots[
prefix_len: #: orig_input_length + max_new_tokens + speculative_length
]
block_tables.append(request_blocks) block_tables.append(request_blocks)
slots.extend(request_slots[:total_tokens])
slots.extend(request_slots)
prefix_lens.append(prefix_len)
num_blocks += len(request_blocks) num_blocks += len(request_blocks)
start_slots.append(cumulative_max_length) start_slots.append(cumulative_slot_tokens)
request_slot_indices = torch.arange( request_slot_indices = torch.arange(
cumulative_max_length, cumulative_slot_tokens,
cumulative_max_length + input_length, cumulative_slot_tokens + input_length,
dtype=torch.int64, dtype=torch.int64,
) )
slot_indices.append(request_slot_indices) slot_indices.append(request_slot_indices)
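The slot bookkeeping above can be illustrated in isolation (values are made up): the router hands out slots for the whole budget, the batch keeps only the slots past the cached prefix, and slot indices advance by slot_tokens rather than by total tokens.

```python
import torch

r_slots = list(range(100, 119))    # hypothetical slots granted by the router
prefix_len, input_length = 4, 6    # 4 cached tokens, 6 tokens to prefill
cumulative_slot_tokens = 32        # slots already used by earlier requests

request_slots = r_slots[prefix_len:]  # the cached prefix needs no new slots
request_slot_indices = torch.arange(
    cumulative_slot_tokens,
    cumulative_slot_tokens + input_length,
    dtype=torch.int64,
)
print(len(request_slots), request_slot_indices.tolist())
# 15 [32, 33, 34, 35, 36, 37]
```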
@ -318,7 +355,7 @@ class FlashCausalLMBatch(Batch):
# Update # Update
cumulative_length += input_length cumulative_length += input_length
cumulative_max_length += total_tokens cumulative_slot_tokens += slot_tokens
max_seqlen = max(max_seqlen, input_length) max_seqlen = max(max_seqlen, input_length)
max_blocks = max(max_blocks, len(request_blocks)) max_blocks = max(max_blocks, len(request_blocks))
max_length = max( max_length = max(
@ -395,12 +432,14 @@ class FlashCausalLMBatch(Batch):
) )
slots = torch.tensor(slots, dtype=torch.int64, device=device) slots = torch.tensor(slots, dtype=torch.int64, device=device)
block_tables_tensor = torch.zeros( block_tables_tensor = torch.zeros(
(len(block_tables), max_blocks), dtype=torch.int32, device="cpu" (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
) )
for i, request_blocks in enumerate(block_tables): for i, request_blocks in enumerate(block_tables):
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
block_tables_tensor = block_tables_tensor.to(device) block_tables_tensor = block_tables_tensor.to(device)
prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
return cls( return cls(
batch_id=pb.id, batch_id=pb.id,
@ -415,6 +454,8 @@ class FlashCausalLMBatch(Batch):
block_tables=block_tables, block_tables=block_tables,
block_tables_tensor=block_tables_tensor, block_tables_tensor=block_tables_tensor,
slots=slots, slots=slots,
prefix_lens=prefix_lens,
prefix_lens_tensor=prefix_lens_tensor,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
prefill_head_indices=prefill_head_indices, prefill_head_indices=prefill_head_indices,
prefill_next_token_indices=prefill_next_token_indices, prefill_next_token_indices=prefill_next_token_indices,
@ -425,6 +466,7 @@ class FlashCausalLMBatch(Batch):
read_offsets=read_offsets, read_offsets=read_offsets,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
prefix_ids=prefix_ids,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens, top_n_tokens=top_n_tokens,
@ -480,6 +522,7 @@ class FlashCausalLMBatch(Batch):
start_slots = [] start_slots = []
block_tables = [] block_tables = []
all_input_ids = [] all_input_ids = []
prefix_ids = []
input_lengths = [] input_lengths = []
prefix_offsets = [] prefix_offsets = []
@ -506,6 +549,7 @@ class FlashCausalLMBatch(Batch):
max_seqlen = max(max_seqlen, request_input_length) max_seqlen = max(max_seqlen, request_input_length)
all_input_ids.append(self.all_input_ids[idx]) all_input_ids.append(self.all_input_ids[idx])
prefix_ids.append(self.prefix_ids[idx])
input_lengths.append(request_input_length) input_lengths.append(request_input_length)
prefix_offsets.append(self.prefix_offsets[idx]) prefix_offsets.append(self.prefix_offsets[idx])
@ -591,6 +635,7 @@ class FlashCausalLMBatch(Batch):
read_offsets=read_offsets, read_offsets=read_offsets,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
prefix_ids=prefix_ids,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens, top_n_tokens=top_n_tokens,
@ -651,6 +696,7 @@ class FlashCausalLMBatch(Batch):
block_tables_tensor = batches[0].block_tables_tensor.new_zeros( block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
(total_batch_size, max_blocks) (total_batch_size, max_blocks)
) )
prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size)
all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
(total_batch_size, max_length) (total_batch_size, max_length)
) )
@ -668,7 +714,9 @@ class FlashCausalLMBatch(Batch):
start_slots = [] start_slots = []
block_tables = [] block_tables = []
prefix_lens = []
all_input_ids = [] all_input_ids = []
prefix_ids = []
input_lengths = [] input_lengths = []
prefix_offsets = [] prefix_offsets = []
@ -730,10 +778,14 @@ class FlashCausalLMBatch(Batch):
start_index:end_index, : batch.block_tables_tensor.shape[1] start_index:end_index, : batch.block_tables_tensor.shape[1]
] = batch.block_tables_tensor[:, :max_blocks] ] = batch.block_tables_tensor[:, :max_blocks]
prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor
start_slots.append(batch.start_slots + cumulative_slots) start_slots.append(batch.start_slots + cumulative_slots)
block_tables.extend(batch.block_tables) block_tables.extend(batch.block_tables)
prefix_lens.extend(batch.prefix_lens)
all_input_ids.extend(batch.all_input_ids) all_input_ids.extend(batch.all_input_ids)
prefix_ids.extend(batch.prefix_ids)
input_lengths.extend(batch.input_lengths) input_lengths.extend(batch.input_lengths)
prefix_offsets.extend(batch.prefix_offsets) prefix_offsets.extend(batch.prefix_offsets)
@ -779,6 +831,8 @@ class FlashCausalLMBatch(Batch):
slot_indices=slot_indices, slot_indices=slot_indices,
block_tables=block_tables, block_tables=block_tables,
block_tables_tensor=block_tables_tensor, block_tables_tensor=block_tables_tensor,
prefix_lens=prefix_lens,
prefix_lens_tensor=prefix_lens_tensor,
slots=slots, slots=slots,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
prefill_head_indices=None, prefill_head_indices=None,
@ -790,6 +844,7 @@ class FlashCausalLMBatch(Batch):
read_offsets=read_offsets, read_offsets=read_offsets,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
prefix_ids=prefix_ids,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens, top_n_tokens=top_n_tokens,
@ -944,6 +999,9 @@ class FlashCausalLM(Model):
) )
self.prefill_state = create_prefill_state(device=device) self.prefill_state = create_prefill_state(device=device)
self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
device=device
)
if not CUDA_GRAPHS: if not CUDA_GRAPHS:
self.decode_state = create_decode_state( self.decode_state = create_decode_state(
@ -1042,11 +1100,22 @@ class FlashCausalLM(Model):
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
slots = torch.arange(bs, dtype=torch.int64, device=self.device) slots = torch.arange(bs, dtype=torch.int64, device=self.device)
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s input_lengths = [max_s] * bs
block_tables = ( prefix_lengths = [0] * bs
torch.arange(max_bt, dtype=torch.int32, device=self.device) input_lengths_tensor = (
.repeat(bs) torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
.reshape((bs, max_bt)) )
prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
block_tables = torch.arange(
max_bt, dtype=torch.int32, device=self.device
).repeat(bs)
block_tables = block_tables.reshape((bs, max_bt))
if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=input_lengths,
prefix_lens=prefix_lengths,
) )
self.cuda_graphs[bs] = { self.cuda_graphs[bs] = {
@ -1055,9 +1124,9 @@ class FlashCausalLM(Model):
"kv_cache": self.kv_cache, "kv_cache": self.kv_cache,
"block_tables": block_tables, "block_tables": block_tables,
"slots": slots, "slots": slots,
"input_lengths": input_lengths, "input_lengths": input_lengths_tensor,
} }
input_lengths_ = Seqlen(input_lengths=input_lengths) input_lengths_ = Seqlen(input_lengths=input_lengths_tensor)
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs]["graph"] = graph self.cuda_graphs[bs]["graph"] = graph
@ -1072,7 +1141,7 @@ class FlashCausalLM(Model):
last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device) last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
state = create_decode_state_cuda_graphs( state = create_decode_state_cuda_graphs(
device=input_ids.device, device=input_ids.device,
block_tables=block_tables.view(-1), block_tables=block_tables,
block_tables_ptr=block_tables_ptr, block_tables_ptr=block_tables_ptr,
last_page_len=last_page_len, last_page_len=last_page_len,
num_heads=self.num_heads, num_heads=self.num_heads,
@ -1088,7 +1157,10 @@ class FlashCausalLM(Model):
block_tables=block_tables, block_tables=block_tables,
cu_seqlen_prefill=None, cu_seqlen_prefill=None,
input_lengths=input_lengths, input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor,
state=state, state=state,
prefix_lens=prefix_lengths,
prefix_lens_tensor=prefix_lengths_tensor,
): ):
self.model.forward( self.model.forward(
input_ids=input_ids, input_ids=input_ids,
@ -1106,7 +1178,7 @@ class FlashCausalLM(Model):
torch.cuda.synchronize() torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL): with torch.cuda.graph(graph, pool=MEM_POOL):
input_lengths = Seqlen(input_lengths=input_lengths) input_lengths_tensor = Seqlen(input_lengths=input_lengths_tensor)
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
@ -1114,7 +1186,7 @@ class FlashCausalLM(Model):
kv_cache=self.kv_cache, kv_cache=self.kv_cache,
block_tables=block_tables, block_tables=block_tables,
slots=slots, slots=slots,
input_lengths=input_lengths, input_lengths=input_lengths_tensor,
max_s=max_s, max_s=max_s,
prefill_cache_indices=None, prefill_cache_indices=None,
lm_head_indices=None, lm_head_indices=None,
@ -1343,7 +1415,10 @@ class FlashCausalLM(Model):
with self._forward_context( with self._forward_context(
block_tables=block_tables, block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths=input_lengths, input_lengths=batch.input_lengths,
input_lengths_tensor=input_lengths,
prefix_lens=batch.prefix_lens,
prefix_lens_tensor=batch.prefix_lens_tensor,
): ):
input_lengths = Seqlen(input_lengths=input_lengths) input_lengths = Seqlen(input_lengths=input_lengths)
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
@ -1367,19 +1442,32 @@ class FlashCausalLM(Model):
# Static inputs are potentially padded # Static inputs are potentially padded
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens,
)
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
else:
cuda_graph["block_tables"][ cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1] : block_tables.shape[0], : block_tables.shape[1]
] = block_tables ] = block_tables
cuda_graph["slots"].fill_(-1) cuda_graph["slots"].fill_(-1)
cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_() cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
input_lengths + batch.prefix_lens_tensor
)
state = cuda_graph.get("state") state = cuda_graph.get("state")
with self._forward_context( with self._forward_context(
block_tables=block_tables, block_tables=block_tables,
cu_seqlen_prefill=None, cu_seqlen_prefill=None,
input_lengths=input_lengths, input_lengths=batch.input_lengths,
input_lengths_tensor=input_lengths,
prefix_lens=batch.prefix_lens,
prefix_lens_tensor=batch.prefix_lens_tensor,
state=state, state=state,
): ):
# Replay the graph # Replay the graph
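For the CUDA-graph decode path above, the static buffers captured with the graph are simply refilled before replay; the prefix-caching-specific change is that the per-sequence lengths written into the graph's buffer must include the cached prefix. A toy illustration with made-up shapes and values:

```python
import torch

graph_bs = 8                                                       # captured batch size
input_lengths = torch.tensor([5, 9, 3], dtype=torch.int32)         # uncached tokens
prefix_lens_tensor = torch.tensor([16, 0, 32], dtype=torch.int32)  # cached tokens

static_input_lengths = torch.zeros(graph_bs, dtype=torch.int32)    # graph buffer
static_input_lengths[: input_lengths.shape[0]] = input_lengths + prefix_lens_tensor
print(static_input_lengths.tolist())  # [21, 9, 35, 0, 0, 0, 0, 0]
```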
@ -1578,6 +1666,7 @@ class FlashCausalLM(Model):
batch.read_offsets, batch.read_offsets,
batch.stopping_criterias, batch.stopping_criterias,
batch.all_input_ids, batch.all_input_ids,
batch.prefix_ids,
batch.next_token_chooser.do_sample, batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds, batch.next_token_chooser.seeds,
batch.top_n_tokens, batch.top_n_tokens,
@ -1595,6 +1684,7 @@ class FlashCausalLM(Model):
read_offset, read_offset,
stopping_criteria, stopping_criteria,
all_input_ids, all_input_ids,
prefix_ids,
do_sample, do_sample,
seed, seed,
top_n_tokens, top_n_tokens,
@ -1669,18 +1759,18 @@ class FlashCausalLM(Model):
out_end_index = batch.prefill_cu_outlens[i + 1] out_end_index = batch.prefill_cu_outlens[i + 1]
# Remove generated token to only have prefill and add nan for first prompt token # Remove generated token to only have prefill and add nan for first prompt token
request_prefill_logprobs = [float("nan")] + prefill_logprobs[ request_prefill_logprobs = (
out_start_index : out_end_index - 1 [float("nan")] * (len(prefix_ids) + 1)
] ) + prefill_logprobs[out_start_index : out_end_index - 1]
prefill_token_ids = all_input_ids[:-1] prefill_token_ids = all_input_ids[:-1]
prefill_texts = self.tokenizer.batch_decode( prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids, prefix_ids + prefill_token_ids,
clean_up_tokenization_spaces=False, clean_up_tokenization_spaces=False,
skip_special_tokens=False, skip_special_tokens=False,
) )
prefill_tokens = Tokens( prefill_tokens = Tokens(
prefill_token_ids, prefix_ids + prefill_token_ids,
request_prefill_logprobs, request_prefill_logprobs,
prefill_texts, prefill_texts,
is_special=[], is_special=[],
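The prefill-logprob handling above can be checked with a standalone example: cached prefix tokens never went through prefill, so they receive NaN logprobs just like the first prompt token, and the returned token ids are prefix_ids followed by the prefilled ids.

```python
prefix_ids = [101, 102]              # tokens served from the prefix cache
prefill_token_ids = [103, 104, 105]  # all_input_ids[:-1] for this request
prefill_logprobs = [-0.7, -1.2]      # logprobs of tokens 104 and 105

request_prefill_logprobs = [float("nan")] * (len(prefix_ids) + 1) + prefill_logprobs
returned_ids = prefix_ids + prefill_token_ids

assert len(request_prefill_logprobs) == len(returned_ids)
print(returned_ids, request_prefill_logprobs)
# [101, 102, 103, 104, 105] [nan, nan, nan, -0.7, -1.2]
```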
@ -1762,7 +1852,10 @@ class FlashCausalLM(Model):
*, *,
block_tables: torch.Tensor, block_tables: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
input_lengths: torch.Tensor, input_lengths: List[int],
input_lengths_tensor: torch.Tensor,
prefix_lens: List[int],
prefix_lens_tensor: torch.Tensor,
state: Optional[Any] = None, state: Optional[Any] = None,
) -> ContextManager: ) -> ContextManager:
if ATTENTION != "flashinfer": if ATTENTION != "flashinfer":
@ -1771,9 +1864,30 @@ class FlashCausalLM(Model):
from text_generation_server.layers.attention.flash_infer import ( from text_generation_server.layers.attention.flash_infer import (
use_decode_state, use_decode_state,
use_prefill_state, use_prefill_state,
use_prefill_with_paged_kv_state,
) )
# has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens)
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
if True: # has_prefix_lens:
return use_prefill_with_paged_kv_state(
state=(
state if state is not None else self.prefill_with_paged_kv_state
),
block_tables=block_tables_to_ragged(
block_tables=block_tables,
input_lengths=input_lengths,
prefix_lens=prefix_lens,
),
cu_seqlens=cu_seqlen_prefill,
input_lengths=input_lengths_tensor + prefix_lens_tensor,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
page_size=BLOCK_SIZE,
)
else:
return use_prefill_state( return use_prefill_state(
state=state if state is not None else self.prefill_state, state=state if state is not None else self.prefill_state,
cu_seqlens=cu_seqlen_prefill, cu_seqlens=cu_seqlen_prefill,
@ -1782,13 +1896,33 @@ class FlashCausalLM(Model):
head_size=self.head_size, head_size=self.head_size,
) )
else: else:
assert input_lengths is not None assert input_lengths_tensor is not None
return use_decode_state( return use_decode_state(
state=state if state is not None else self.decode_state, state=state if state is not None else self.decode_state,
input_lengths=input_lengths, input_lengths=input_lengths_tensor + prefix_lens_tensor,
block_tables=block_tables.view(-1), block_tables=block_tables,
num_heads=self.num_heads, num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
head_size=self.head_size, head_size=self.head_size,
page_size=BLOCK_SIZE, page_size=BLOCK_SIZE,
) )
def block_tables_to_ragged(
*, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int]
) -> torch.Tensor:
"""Convert block table to ragged format compatible with FlashInfer."""
assert len(input_lengths) == len(prefix_lens)
total_len = sum(input_lengths) + sum(prefix_lens)
block_tables_ragged = torch.empty(
total_len, dtype=torch.int32, device=block_tables.device
)
offset = 0
for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)):
seq_len = prefix_len + input_length
block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
offset += seq_len
return block_tables_ragged
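A worked example for block_tables_to_ragged above (made-up values): two sequences, the first with a 2-token cached prefix and 1 input token, the second with no prefix and 2 input tokens. The helper is equivalent to concatenating the first prefix_len + input_length entries of each row:

```python
import torch

block_tables = torch.tensor([[7, 8, 9, 0],
                             [3, 4, 0, 0]], dtype=torch.int32)
input_lengths = [1, 2]
prefix_lens = [2, 0]

ragged = torch.cat(
    [block_tables[i, : prefix_lens[i] + input_lengths[i]]
     for i in range(len(input_lengths))]
)
print(ragged.tolist())  # [7, 8, 9, 3, 4], what block_tables_to_ragged returns
```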
View File
@ -5,17 +5,29 @@ from typing import Dict, Optional
from text_generation_server.utils.log import log_master from text_generation_server.utils.log import log_master
ATTENTION = os.getenv("ATTENTION", "paged") PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "0").lower() in {"1", "true"}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged")
_expected = {"paged", "flashdecoding", "flashinfer"} _expected = {"paged", "flashdecoding", "flashinfer"}
assert ( assert (
ATTENTION in _expected ATTENTION in _expected
), f"Attention is not valid {ATTENTION}, expected {_expected}" ), f"Attention is not valid {ATTENTION}, expected {_expected}"
log_master(logger.info, f"Using Attention = {ATTENTION}") log_master(logger.info, f"Using Attention = {ATTENTION}")
if PREFIX_CACHING and ATTENTION != "flashinfer":
raise RuntimeError("Prefix caching is only supported with flashinfer")
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
# This is overridden by the cli # This is overridden by the cli
BLOCK_SIZE: int = 256 if ATTENTION == "flashdecoding" else 16
BLOCK_SIZE: int
if ATTENTION == "flashdecoding":
BLOCK_SIZE = 256
elif ATTENTION == "flashinfer":
BLOCK_SIZE = 1
else:
BLOCK_SIZE = 16
cuda_graphs = os.getenv("CUDA_GRAPHS") cuda_graphs = os.getenv("CUDA_GRAPHS")
if cuda_graphs is not None: if cuda_graphs is not None:
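A hedged, self-contained sketch of how the environment variables above interact (it mirrors the module-level logic; the exact parsing of USE_PREFIX_CACHING is an assumption, since os.getenv returns strings): enabling prefix caching defaults the backend to flashinfer, and the backend in turn fixes the block size.

```python
import os

def resolve_attention_config(env=None):
    env = os.environ if env is None else env
    prefix_caching = str(env.get("USE_PREFIX_CACHING", "")).lower() in {"1", "true"}
    attention = env.get("ATTENTION", "flashinfer" if prefix_caching else "paged")
    if attention not in {"paged", "flashdecoding", "flashinfer"}:
        raise ValueError(f"Attention is not valid {attention}")
    if prefix_caching and attention != "flashinfer":
        raise RuntimeError("Prefix caching is only supported with flashinfer")
    block_size = {"flashdecoding": 256, "flashinfer": 1}.get(attention, 16)
    return prefix_caching, attention, block_size

print(resolve_attention_config({"USE_PREFIX_CACHING": "1"}))
# (True, 'flashinfer', 1)
```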