Prefix caching WIP

This commit is contained in:
Daniël de Kok 2024-08-09 11:47:14 +00:00
parent 7a48a84784
commit 7735b385dc
34 changed files with 1451 additions and 307 deletions

1
Cargo.lock generated
View File

@ -4045,6 +4045,7 @@ dependencies = [
"reqwest", "reqwest",
"serde", "serde",
"serde_json", "serde_json",
"slotmap",
"text-generation-router", "text-generation-router",
"thiserror", "thiserror",
"tokenizers", "tokenizers",

View File

@ -156,6 +156,7 @@ impl Client {
// Blocks and slots will be set on the server side if we use paged attention // Blocks and slots will be set on the server side if we use paged attention
blocks: vec![], blocks: vec![],
slots: vec![], slots: vec![],
prefix_len: 0,
// Set sampling parameters to also take these ops into account in the max memory // Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters { parameters: Some(NextTokenChooserParameters {
temperature: 0.9, temperature: 0.9,

View File

@ -244,6 +244,7 @@ impl Health for ShardedClient {
// Block 0 is reserved for health checks // Block 0 is reserved for health checks
blocks: vec![0], blocks: vec![0],
slots: (0..16).collect(), slots: (0..16).collect(),
prefix_len: 0,
adapter_id: None, adapter_id: None,
}; };
let batch = Batch { let batch = Batch {

View File

@ -33,6 +33,7 @@ rand = "0.8.5"
reqwest = { version = "0.11.20", features = [] } reqwest = { version = "0.11.20", features = [] }
serde = "1.0.188" serde = "1.0.188"
serde_json = "1.0.107" serde_json = "1.0.107"
slotmap = "1.0.7"
thiserror = "1.0.48" thiserror = "1.0.48"
tokenizers = { workspace = true} tokenizers = { workspace = true}
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }

View File

@ -35,15 +35,24 @@ impl BackendV3 {
window_size: Option<u32>, window_size: Option<u32>,
speculate: u32, speculate: u32,
) -> Self { ) -> Self {
let prefix_caching = if let Ok(prefix_caching) = std::env::var("USE_PREFIX_CACHING") {
matches!(prefix_caching.as_str(), "true" | "1")
} else {
false
};
let attention = if let Ok(attention) = std::env::var("ATTENTION") { let attention = if let Ok(attention) = std::env::var("ATTENTION") {
attention attention
.parse() .parse()
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")) .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
} else if prefix_caching {
Attention::FlashInfer
} else { } else {
Attention::Paged Attention::Paged
}; };
let block_size = if attention == Attention::FlashDecoding { let block_size = if attention == Attention::FlashDecoding {
256 256
} else if attention == Attention::FlashInfer {
1
} else { } else {
16 16
}; };
@ -51,6 +60,7 @@ impl BackendV3 {
let queue = Queue::new( let queue = Queue::new(
requires_padding, requires_padding,
block_size, block_size,
prefix_caching,
window_size, window_size,
speculate, speculate,
max_batch_total_tokens, max_batch_total_tokens,

View File

@ -1,16 +1,26 @@
use std::cmp::min; use std::{cmp::min, sync::Arc};
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
use crate::radix::RadixAllocator;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct BlockAllocation { pub(crate) struct BlockAllocation {
pub allocation_id: u64,
pub blocks: Vec<u32>, pub blocks: Vec<u32>,
pub slots: Vec<u32>, pub slots: Vec<u32>,
block_allocator: BlockAllocator,
/// Prefix that was cached and for which the KV does not have to
/// be recomputed.
pub prefix_len: u32,
pub(crate) block_allocator: Option<BlockAllocator>,
} }
impl Drop for BlockAllocation { impl Drop for BlockAllocation {
fn drop(&mut self) { fn drop(&mut self) {
self.block_allocator.free(self.blocks.clone()) if let Some(block_allocator) = self.block_allocator.as_mut() {
block_allocator.free(self.blocks.clone(), self.allocation_id)
}
} }
} }
@ -24,6 +34,7 @@ impl BlockAllocator {
pub(crate) fn new( pub(crate) fn new(
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
block_size: u32, block_size: u32,
prefix_caching: bool,
window_size: Option<u32>, window_size: Option<u32>,
) -> Self { ) -> Self {
// Create channel // Create channel
@ -33,6 +44,7 @@ impl BlockAllocator {
tokio::spawn(block_allocator_task( tokio::spawn(block_allocator_task(
max_batch_total_tokens / block_size, max_batch_total_tokens / block_size,
block_size, block_size,
prefix_caching,
window_size, window_size,
receiver, receiver,
)); ));
@ -42,28 +54,32 @@ impl BlockAllocator {
} }
} }
pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> { pub(crate) async fn allocate(
&self,
tokens: u32,
prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation> {
let (response_sender, response_receiver) = oneshot::channel(); let (response_sender, response_receiver) = oneshot::channel();
self.block_allocator self.block_allocator
.send(BlockAllocatorCommand::Allocate { .send(BlockAllocatorCommand::Allocate {
tokens, tokens,
prefill_tokens,
response_sender, response_sender,
}) })
.unwrap(); .unwrap();
response_receiver response_receiver.await.unwrap().map(|mut allocation| {
.await allocation.block_allocator = Some(self.clone());
.unwrap() allocation
.map(|(blocks, slots)| BlockAllocation { })
blocks,
slots,
block_allocator: self.clone(),
})
} }
pub(crate) fn free(&self, blocks: Vec<u32>) { pub(crate) fn free(&self, blocks: Vec<u32>, allocation_id: u64) {
self.block_allocator self.block_allocator
.send(BlockAllocatorCommand::Free { blocks }) .send(BlockAllocatorCommand::Free {
allocation_id,
blocks,
})
.unwrap(); .unwrap();
} }
} }
@ -71,54 +87,29 @@ impl BlockAllocator {
async fn block_allocator_task( async fn block_allocator_task(
blocks: u32, blocks: u32,
block_size: u32, block_size: u32,
prefix_caching: bool,
window_size: Option<u32>, window_size: Option<u32>,
mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>, mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
) { ) {
// Block 0 is reserved for health checks let mut allocator: Box<dyn Allocator + Send> = if prefix_caching {
let mut free_blocks: Vec<u32> = (1..blocks).collect(); Box::new(RadixAllocator::new(block_size, blocks, window_size))
} else {
Box::new(SimpleAllocator::new(blocks, block_size, window_size))
};
while let Some(cmd) = receiver.recv().await { while let Some(cmd) = receiver.recv().await {
match cmd { match cmd {
BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks), BlockAllocatorCommand::Free {
blocks,
allocation_id,
} => allocator.free(blocks, allocation_id),
BlockAllocatorCommand::Allocate { BlockAllocatorCommand::Allocate {
tokens, tokens,
prefill_tokens,
response_sender, response_sender,
} => { } => {
// Apply window size response_sender
let (required_blocks, repeats) = { .send(allocator.allocate(tokens, prefill_tokens))
let (tokens, repeats) = match window_size { .unwrap();
None => (tokens, 1),
Some(window_size) => {
let repeats = (tokens + window_size - 1) / window_size;
let tokens = min(tokens, window_size);
(tokens, repeats as usize)
}
};
// Pad to a multiple of block size
let required_blocks = (tokens + block_size - 1) / block_size;
(required_blocks, repeats)
};
let tokens = tokens as usize;
let allocation = if required_blocks > free_blocks.len() as u32 {
None
} else {
let blocks =
free_blocks.split_off(free_blocks.len() - required_blocks as usize);
let mut slots = Vec::with_capacity(
(required_blocks * block_size * repeats as u32) as usize,
);
'slots: for block_id in blocks.repeat(repeats).iter() {
for s in (block_id * block_size)..((block_id + 1) * block_size) {
slots.push(s);
if slots.len() == tokens {
break 'slots;
}
}
}
Some((blocks, slots))
};
response_sender.send(allocation).unwrap();
} }
} }
} }
@ -128,9 +119,92 @@ async fn block_allocator_task(
enum BlockAllocatorCommand { enum BlockAllocatorCommand {
Free { Free {
blocks: Vec<u32>, blocks: Vec<u32>,
allocation_id: u64,
}, },
Allocate { Allocate {
tokens: u32, tokens: u32,
response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>, prefill_tokens: Option<Arc<Vec<u32>>>,
response_sender: oneshot::Sender<Option<BlockAllocation>>,
}, },
} }
/// Strategy used to hand out and take back KV-cache blocks.
pub(crate) trait Allocator {
/// Allocate blocks for `tokens` tokens.
///
/// `prefill_tokens` are the token ids of the prompt, when available;
/// implementations may use them to reuse previously cached blocks.
/// Returns `None` when the request cannot be satisfied.
fn allocate(
&mut self,
tokens: u32,
prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation>;
/// Return the `blocks` of the allocation identified by `allocation_id`.
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
}
/// Allocator that hands out blocks from a free list without any prefix
/// caching.
pub struct SimpleAllocator {
/// Blocks that are currently available for allocation.
free_blocks: Vec<u32>,
/// Number of slots in a single block.
block_size: u32,
/// Optional window size; when set, at most this many tokens need distinct
/// slots and the slots are repeated for longer requests.
window_size: Option<u32>,
}
impl SimpleAllocator {
fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
SimpleAllocator {
block_size,
// Block 0 is reserved for health checks
free_blocks: (1..blocks).collect(),
window_size,
}
}
}
impl Allocator for SimpleAllocator {
    /// Take enough blocks from the free list to hold `tokens` tokens.
    ///
    /// The prefill tokens are ignored, this allocator never reuses blocks
    /// between requests. Returns `None` when the free list is too short.
    fn allocate(
        &mut self,
        tokens: u32,
        _prefill_tokens: Option<Arc<Vec<u32>>>,
    ) -> Option<BlockAllocation> {
        // Apply window size: only `window_size` tokens need distinct slots,
        // the same slots are then repeated to cover the whole request.
        let (windowed_tokens, repeats) = match self.window_size {
            Some(window_size) => {
                let repeats = (tokens + window_size - 1) / window_size;
                (min(tokens, window_size), repeats as usize)
            }
            None => (tokens, 1),
        };

        // Pad to a multiple of block size
        let required_blocks = (windowed_tokens + self.block_size - 1) / self.block_size;
        let tokens = tokens as usize;

        if required_blocks > self.free_blocks.len() as u32 {
            return None;
        }

        let first_taken = self.free_blocks.len() - required_blocks as usize;
        let blocks = self.free_blocks.split_off(first_taken);

        let block_size = self.block_size;
        let mut slots =
            Vec::with_capacity((required_blocks * block_size * repeats as u32) as usize);
        // Enumerate the slots of every (repeated) block and stop as soon as
        // each token has a slot.
        slots.extend(
            blocks
                .repeat(repeats)
                .iter()
                .flat_map(|&block_id| (block_id * block_size)..((block_id + 1) * block_size))
                .take(tokens),
        );

        Some(BlockAllocation {
            allocation_id: 0,
            blocks,
            slots,
            prefix_len: 0,
            block_allocator: None,
        })
    }

    /// Put `blocks` back on the free list; the allocation id is not used.
    fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {
        self.free_blocks.extend(blocks)
    }
}

View File

@ -157,6 +157,7 @@ impl Client {
// Blocks and slots will be set on the server side if we use paged attention // Blocks and slots will be set on the server side if we use paged attention
blocks: vec![], blocks: vec![],
slots: vec![], slots: vec![],
prefix_len: 0,
// Set sampling parameters to also take these ops into account in the max memory // Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters { parameters: Some(NextTokenChooserParameters {
temperature: 0.9, temperature: 0.9,

View File

@ -245,6 +245,7 @@ impl Health for ShardedClient {
// Block 0 is reserved for health checks // Block 0 is reserved for health checks
blocks: vec![0], blocks: vec![0],
slots: (0..16).collect(), slots: (0..16).collect(),
prefix_len: 0,
adapter_id: None, adapter_id: None,
}; };
let batch = Batch { let batch = Batch {

View File

@ -2,6 +2,7 @@ mod backend;
mod block_allocator; mod block_allocator;
mod client; mod client;
mod queue; mod queue;
mod radix;
use crate::client::{ClientError, ShardedClient}; use crate::client::{ClientError, ShardedClient};
pub(crate) use backend::BackendV3; pub(crate) use backend::BackendV3;

View File

@ -46,6 +46,7 @@ impl Queue {
pub(crate) fn new( pub(crate) fn new(
requires_padding: bool, requires_padding: bool,
block_size: u32, block_size: u32,
prefix_caching: bool,
window_size: Option<u32>, window_size: Option<u32>,
speculate: u32, speculate: u32,
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
@ -57,6 +58,7 @@ impl Queue {
tokio::spawn(queue_task( tokio::spawn(queue_task(
requires_padding, requires_padding,
block_size, block_size,
prefix_caching,
window_size, window_size,
speculate, speculate,
max_batch_total_tokens, max_batch_total_tokens,
@ -109,6 +111,7 @@ impl Queue {
async fn queue_task( async fn queue_task(
requires_padding: bool, requires_padding: bool,
block_size: u32, block_size: u32,
prefix_caching: bool,
window_size: Option<u32>, window_size: Option<u32>,
speculate: u32, speculate: u32,
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
@ -117,6 +120,7 @@ async fn queue_task(
let mut state = State::new( let mut state = State::new(
requires_padding, requires_padding,
block_size, block_size,
prefix_caching,
window_size, window_size,
speculate, speculate,
max_batch_total_tokens, max_batch_total_tokens,
@ -176,12 +180,19 @@ impl State {
fn new( fn new(
requires_padding: bool, requires_padding: bool,
block_size: u32, block_size: u32,
prefix_caching: bool,
window_size: Option<u32>, window_size: Option<u32>,
speculate: u32, speculate: u32,
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
) -> Self { ) -> Self {
let block_allocator = (!requires_padding) let block_allocator = (!requires_padding).then(|| {
.then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size)); BlockAllocator::new(
max_batch_total_tokens,
block_size,
prefix_caching,
window_size,
)
});
Self { Self {
entries: VecDeque::with_capacity(128), entries: VecDeque::with_capacity(128),
@ -305,7 +316,10 @@ impl State {
+ self.speculate + self.speculate
- 1; - 1;
match block_allocator.allocate(tokens).await { match block_allocator
.allocate(tokens, entry.request.input_ids.clone())
.await
{
None => { None => {
// Entry is over budget // Entry is over budget
// Add it back to the front // Add it back to the front
@ -331,11 +345,12 @@ impl State {
// Update entry // Update entry
entry.temp_span = Some(entry_batch_span); entry.temp_span = Some(entry_batch_span);
let (blocks, slots) = match &block_allocation { let (blocks, slots, prefix_len) = match &block_allocation {
None => (Vec::new(), Vec::new()), None => (Vec::new(), Vec::new(), 0),
Some(block_allocation) => ( Some(block_allocation) => (
block_allocation.blocks.clone(), block_allocation.blocks.clone(),
block_allocation.slots.clone(), block_allocation.slots.clone(),
block_allocation.prefix_len,
), ),
}; };
@ -372,6 +387,7 @@ impl State {
top_n_tokens: entry.request.top_n_tokens, top_n_tokens: entry.request.top_n_tokens,
blocks, blocks,
slots, slots,
prefix_len,
adapter_id: entry.request.adapter_id.clone(), adapter_id: entry.request.adapter_id.clone(),
}); });
// Set batch_time // Set batch_time
@ -480,6 +496,8 @@ impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::sync::Arc;
use super::*; use super::*;
use tracing::info_span; use tracing::info_span;
@ -492,6 +510,7 @@ mod tests {
let entry = Entry { let entry = Entry {
request: ValidGenerateRequest { request: ValidGenerateRequest {
inputs: vec![], inputs: vec![],
input_ids: Some(Arc::new(vec![])),
input_length: 0, input_length: 0,
truncate: 0, truncate: 0,
decoder_input_details: false, decoder_input_details: false,
@ -527,7 +546,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_append() { async fn test_append() {
let mut state = State::new(false, 1, None, 0, 16); let mut state = State::new(false, 1, false, None, 0, 16);
let (entry, _guard) = default_entry(); let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0); assert_eq!(state.next_id, 0);
@ -543,7 +562,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_next_batch_empty() { async fn test_next_batch_empty() {
let mut state = State::new(false, 1, None, 0, 16); let mut state = State::new(false, 1, false, None, 0, 16);
assert!(state.next_batch(None, None, 1, 1).await.is_none()); assert!(state.next_batch(None, None, 1, 1).await.is_none());
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none()); assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
@ -551,7 +570,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_next_batch_min_size() { async fn test_next_batch_min_size() {
let mut state = State::new(false, 1, None, 0, 16); let mut state = State::new(false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
@ -583,7 +602,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_next_batch_max_size() { async fn test_next_batch_max_size() {
let mut state = State::new(false, 1, None, 0, 16); let mut state = State::new(false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
@ -603,7 +622,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_next_batch_token_budget() { async fn test_next_batch_token_budget() {
let mut state = State::new(false, 1, None, 0, 2); let mut state = State::new(false, 1, false, None, 0, 2);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
@ -636,14 +655,14 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_append() { async fn test_queue_append() {
let queue = Queue::new(false, 1, None, 0, 16); let queue = Queue::new(false, 1, false, None, 0, 16);
let (entry, _guard) = default_entry(); let (entry, _guard) = default_entry();
queue.append(entry); queue.append(entry);
} }
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_empty() { async fn test_queue_next_batch_empty() {
let queue = Queue::new(false, 1, None, 0, 16); let queue = Queue::new(false, 1, false, None, 0, 16);
assert!(queue.next_batch(None, None, 1, 1).await.is_none()); assert!(queue.next_batch(None, None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
@ -651,7 +670,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_min_size() { async fn test_queue_next_batch_min_size() {
let queue = Queue::new(false, 1, None, 0, 16); let queue = Queue::new(false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -684,7 +703,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_max_size() { async fn test_queue_next_batch_max_size() {
let queue = Queue::new(false, 1, None, 0, 16); let queue = Queue::new(false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -700,7 +719,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_token_budget() { async fn test_queue_next_batch_token_budget() {
let queue = Queue::new(false, 1, None, 0, 16); let queue = Queue::new(false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -725,7 +744,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_token_speculate() { async fn test_queue_next_batch_token_speculate() {
let queue = Queue::new(false, 1, None, 2, 16); let queue = Queue::new(false, 1, false, None, 2, 16);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -744,7 +763,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_dropped_receiver() { async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new(false, 1, None, 0, 16); let queue = Queue::new(false, 1, false, None, 0, 16);
let (entry, _) = default_entry(); let (entry, _) = default_entry();
queue.append(entry); queue.append(entry);

755
backends/v3/src/radix.rs Normal file
View File

@ -0,0 +1,755 @@
use std::{
collections::{BTreeSet, HashMap},
sync::Arc,
};
use slotmap::{DefaultKey, SlotMap};
use crate::block_allocator::{Allocator, BlockAllocation};
/// Block allocator that keeps the blocks of prefills in a radix trie, so
/// that requests sharing a prefix can reuse them.
pub struct RadixAllocator {
/// Identifier of the most recent allocation; incremented for every
/// successful allocation.
allocation_id: u64,
/// Bookkeeping of live allocations, keyed by allocation identifier.
allocations: HashMap<u64, RadixAllocation>,
/// Trie with the cached prefill tokens and their blocks.
cache_blocks: RadixTrie,
/// Blocks that are immediately available for allocation.
free_blocks: Vec<u32>,
}
impl RadixAllocator {
    /// Create an allocator that manages `n_blocks` blocks.
    ///
    /// # Panics
    ///
    /// Panics when `block_size` is not 1 or when a window size is given,
    /// neither is supported by this allocator.
    pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self {
        assert_eq!(
            block_size, 1,
            "Radix tree allocator only works with block_size=1, was: {}",
            block_size
        );

        if window_size.is_some() {
            unimplemented!("Window size not supported in the prefix-caching block allocator yet");
        }

        // Block 0 is reserved for health checks.
        let free_blocks: Vec<u32> = (1..n_blocks).collect();

        Self {
            allocation_id: 0,
            allocations: HashMap::new(),
            cache_blocks: RadixTrie::new(),
            free_blocks,
        }
    }

    /// Take `n_blocks_needed` blocks from the free list, evicting cached
    /// prefixes first when the free list is too short.
    ///
    /// Returns `None` when not enough blocks are available even after
    /// eviction. Evicted blocks stay on the free list in that case.
    fn alloc_or_reclaim(&mut self, n_blocks_needed: usize) -> Option<Vec<u32>> {
        let available = self.free_blocks.len();
        if available < n_blocks_needed {
            // The evicted blocks go through the free list, since that is
            // where they have to end up if we still cannot satisfy the
            // request. This is only temporary: the trie should be able to
            // report whether it can free the requested amount, but that is
            // not implemented yet.
            let reclaimed = self.cache_blocks.evict(n_blocks_needed - available);
            self.free_blocks.extend(reclaimed);
        }

        // `checked_sub` fails exactly when the free list is still too short.
        let split_at = self.free_blocks.len().checked_sub(n_blocks_needed)?;
        Some(self.free_blocks.split_off(split_at))
    }
}
impl Allocator for RadixAllocator {
    /// Allocate blocks for a request of `tokens` tokens.
    ///
    /// When `prefill_tokens` is given, the blocks of the longest prefix that
    /// is already in the cache are reused and only the remaining tokens get
    /// fresh blocks. Returns `None` when not enough blocks are available,
    /// even after evicting unreferenced prefixes from the cache.
    fn allocate(
        &mut self,
        tokens: u32,
        prefill_tokens: Option<Arc<Vec<u32>>>,
    ) -> Option<BlockAllocation> {
        let mut blocks = vec![];
        let prefix_node = match prefill_tokens.as_deref() {
            // An empty prefill cannot match anything in the cache, so the
            // lookup is skipped and the allocation is anchored at the root.
            Some(prefill_tokens) if !prefill_tokens.is_empty() => self
                .cache_blocks
                .find(prefill_tokens.as_slice(), &mut blocks),
            _ => self.cache_blocks.root_id(),
        };

        // Even if this allocation fails below, we need to increase the
        // refcount to ensure that the prefix that was found is not evicted
        // while reclaiming blocks for the suffix.
        self.cache_blocks
            .incref(prefix_node)
            .expect("Failed to increment refcount");

        let prefix_len = blocks.len();
        // Saturate rather than underflow: a wrapped-around suffix length
        // would evict the whole cache in a release build before failing.
        let suffix_len = tokens.saturating_sub(prefix_len as u32);

        match self.alloc_or_reclaim(suffix_len as usize) {
            Some(suffix_blocks) => blocks.extend(suffix_blocks),
            None => {
                self.cache_blocks
                    .decref(prefix_node)
                    .expect("Failed to decrement refcount");
                return None;
            }
        }

        // 1:1 mapping of blocks and slots.
        let slots = blocks.clone();

        let allocation = RadixAllocation {
            prefix_node,
            cached_prefix_len: prefix_len,
            prefill_tokens,
        };

        self.allocation_id += 1;
        self.allocations.insert(self.allocation_id, allocation);

        Some(BlockAllocation {
            allocation_id: self.allocation_id,
            block_allocator: None,
            blocks,
            slots,
            prefix_len: prefix_len as u32,
        })
    }

    /// Release the allocation identified by `allocation_id`.
    ///
    /// The blocks that hold prefill tokens are moved into the prefix cache,
    /// all other blocks are returned to the free list.
    ///
    /// # Panics
    ///
    /// Panics when `allocation_id` does not belong to a live allocation.
    fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
        let allocation = match self.allocations.remove(&allocation_id) {
            Some(allocation) => allocation,
            None => unreachable!("Tried to free an unknown allocation."),
        };

        self.cache_blocks
            .decref(allocation.prefix_node)
            .expect("Failed to decrement refcount");

        if let Some(prefill_tokens) = allocation.prefill_tokens {
            let prefill_tokens = prefill_tokens.as_slice();

            // If there are prefill tokens that did not come from the cache,
            // add them to the cache.
            if prefill_tokens.len() > allocation.cached_prefix_len {
                let prefix_len = self
                    .cache_blocks
                    .insert(prefill_tokens, &blocks[..prefill_tokens.len()])
                    // Unwrap, failing is a programming error.
                    .expect("Failed to store prefill tokens");

                // We can have a prefill with the following structure:
                //
                // |---| From the prefix cache.
                // A B C D E F G
                //|--------| Found in the trie during insertion.
                //
                // This means that while processing this request there was a
                // partially overlapping request that had A..=E in its
                // prefill. In this case we need to free the blocks D E.
                self.free_blocks
                    .extend(&blocks[allocation.cached_prefix_len..prefix_len]);
            }

            // Free non-prefill blocks.
            self.free_blocks.extend(&blocks[prefill_tokens.len()..]);
        } else {
            self.free_blocks.extend(blocks);
        }
    }
}
/// Bookkeeping for a single live allocation of the radix allocator.
struct RadixAllocation {
/// Trie node whose reference count was increased for this allocation.
prefix_node: NodeId,
/// Number of leading blocks that were taken from the prefix cache.
cached_prefix_len: usize,
/// Prefill tokens of the request; they are added to the cache when the
/// allocation is freed.
prefill_tokens: Option<Arc<Vec<u32>>>,
}
// Radix trie that is heavily inspired by radix attention from sglang.
//
// The trie is optimized for prefix caching:
//
// - A normal radix trie stores discrete values. In this radix trie,
// inserting *abc* with value *xyz* will also enable lookup for
// *a* (*x*) and *ab* (*xy*).
// - As a result, every value is required to have the same length as
// the key.
// - We store additional information in each node, such as last access
// time and a reference count.
/// Errors reported by the radix trie operations.
#[derive(Debug)]
pub enum TrieError {
/// The node identifier does not refer to a node in the trie.
InvalidNodeId,
/// A reference count that is already zero was decremented.
RefCountUnderflow,
/// The number of blocks differs from the number of tokens on insertion.
BlockTokenCountMismatch,
}
/// Identifier of a node in the radix trie (a slot map key).
pub type NodeId = DefaultKey;
/// Radix trie that maps token prefixes to the blocks that hold them.
#[derive(Debug)]
pub struct RadixTrie {
/// Identifier of the root node.
root: DefaultKey,
/// Identifiers of the nodes with a reference count of zero, ordered by
/// increasing recency. These are the candidates for eviction.
leaves: BTreeSet<(u64, NodeId)>,
/// All trie nodes.
nodes: SlotMap<NodeId, TrieNode>,
/// Time as a monotonically increasing counter to avoid the system
/// call that a real time lookup would require.
time: u64,
}
impl RadixTrie {
/// Create an empty trie that only contains the root node.
pub fn new() -> Self {
    let mut nodes = SlotMap::new();
    // The root has an empty key, no blocks and no parent.
    let root = nodes.insert(TrieNode::new(vec![], vec![], 0, None));

    RadixTrie {
        root,
        nodes,
        leaves: BTreeSet::new(),
        time: 0,
    }
}
/// Look up the longest cached prefix of `key`.
///
/// The blocks that belong to the part of the prefix that was found are
/// appended to `blocks`; their number is in `0..=key.len()`. The return
/// value is the identifier of a trie node on the matched path, which
/// callers can use to e.g. increase its reference count.
///
/// Every lookup advances the trie clock and updates the access time of the
/// traversed nodes.
pub fn find(&mut self, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
    self.time += 1;
    let root = self.root;
    self.find_(root, key, blocks)
}
/// Find worker.
///
/// Descends from `node_id` along `key`, appending the blocks of the matched
/// tokens to `blocks`. Returns the deepest node that contributed blocks, or
/// `node_id` itself when no child matches (including when `key` is empty).
fn find_(&mut self, node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
    // An empty key matches nothing below this node.
    let Some(&first) = key.first() else {
        return node_id;
    };
    let Some(&child_id) = self.nodes[node_id].children.get(&first) else {
        return node_id;
    };

    self.update_access_time(child_id);
    let child = self.nodes.get(child_id).expect("Invalid child identifier");
    let shared_prefix_len = child.key.shared_prefix_len(key);
    let child_fully_matched = shared_prefix_len == child.key.len();
    blocks.extend(&child.blocks[..shared_prefix_len]);

    // A node represents the prefix of its children, so only descend when the
    // whole key of the child matched. Otherwise a grandchild could match
    // tokens that do not actually follow the matched prefix.
    let remainder = &key[shared_prefix_len..];
    if child_fully_matched && !remainder.is_empty() {
        self.find_(child_id, remainder, blocks)
    } else {
        // The child holds the last matched blocks. Returning it, rather
        // than its parent, lets the caller's reference count keep exactly
        // those blocks from being evicted.
        child_id
    }
}
/// Drop one reference to `node_id`.
///
/// A node whose count reaches zero becomes a candidate for eviction.
/// Fails with `InvalidNodeId` for an unknown node and with
/// `RefCountUnderflow` when the count is already zero.
pub fn decref(&mut self, node_id: NodeId) -> Result<(), TrieError> {
    // The root is never evicted, so it is not reference counted.
    if node_id == self.root {
        return Ok(());
    }

    let node = self
        .nodes
        .get_mut(node_id)
        .ok_or(TrieError::InvalidNodeId)?;

    node.ref_count = node
        .ref_count
        .checked_sub(1)
        .ok_or(TrieError::RefCountUnderflow)?;

    if node.ref_count == 0 {
        self.leaves.insert((node.last_accessed, node_id));
    }

    Ok(())
}
/// Add one reference to `node_id`.
///
/// A node that was unreferenced so far is taken off the eviction
/// candidates. Fails with `InvalidNodeId` for an unknown node.
pub fn incref(&mut self, node_id: NodeId) -> Result<(), TrieError> {
    // The root is never evicted, so it is not reference counted.
    if node_id == self.root {
        return Ok(());
    }

    match self.nodes.get_mut(node_id) {
        None => Err(TrieError::InvalidNodeId),
        Some(node) => {
            if node.ref_count == 0 {
                self.leaves.remove(&(node.last_accessed, node_id));
            }
            node.ref_count += 1;
            Ok(())
        }
    }
}
/// Evict up to `n_blocks` blocks from the trie, least recently used first.
///
/// Returns the evicted blocks. When fewer than `n_blocks` blocks are
/// returned, not enough blocks could be evicted.
pub fn evict(&mut self, n_blocks: usize) -> Vec<u32> {
    // NOTE: we don't return Result here. If any of the unwrapping fails,
    // it's a programming error in the trie implementation, not a user
    // error caused by e.g. an invalid argument.

    // TODO: add some bookkeeping in the future to check whether we can
    // evict n_blocks and return `None` if we can't. We are now needlessly
    // evicting prefixes from the cache in such a case.
    let mut evicted: Vec<u32> = Vec::new();

    loop {
        let Some((last_access, node_id)) = self.leaves.pop_first() else {
            break;
        };

        let still_needed = n_blocks - evicted.len();
        let leaf_len = self
            .nodes
            .get(node_id)
            .expect("Leave does not exist")
            .blocks
            .len();

        if leaf_len > still_needed {
            // The node has more blocks than needed: only cut off the tail
            // and keep the node, with its old access time, as a candidate.
            let keep = leaf_len - still_needed;
            let node = self.nodes.get_mut(node_id).expect("Leave does not exist");
            node.key.truncate(keep);
            evicted.extend(node.blocks.split_off(keep));
            self.leaves.insert((last_access, node_id));
            break;
        }

        // All blocks of the node are needed, so the node goes as a whole.
        let removed = self.remove_node(node_id);
        evicted.extend(removed.blocks);
        if evicted.len() >= n_blocks {
            break;
        }
    }

    evicted
}
/// Store a prefill and the blocks that hold it.
///
/// Returns the length of the prefix that was already present in the
/// trie. E.g. when 10 is returned, **the blocks of the first 10 tokens
/// are not updated**. Every insertion advances the trie clock.
pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result<usize, TrieError> {
    self.time += 1;
    let root = self.root;
    self.insert_(root, tokens, blocks)
}
/// Insertion worker.
///
/// Inserts `tokens`/`blocks` below `node_id` and returns the number of
/// leading tokens that were already present. Fails with
/// `BlockTokenCountMismatch` when the two slices differ in length; empty
/// input is a no-op.
fn insert_(
    &mut self,
    node_id: NodeId,
    tokens: &[u32],
    blocks: &[u32],
) -> Result<usize, TrieError> {
    // TODO: in the future we may want to check that the blocks match for
    // the part of the prefix that is already in the trie to detect
    // mismatches.
    if tokens.len() != blocks.len() {
        return Err(TrieError::BlockTokenCountMismatch);
    }

    // Nothing to insert; also avoids indexing into an empty slice.
    let Some(&first) = tokens.first() else {
        return Ok(0);
    };

    let Some(&child_id) = self.nodes[node_id].children.get(&first) else {
        // No child starts with the first token: the whole remainder becomes
        // a new node.
        self.add_node(node_id, tokens, blocks);
        return Ok(0);
    };

    self.update_access_time(child_id);
    let child_key = &self
        .nodes
        .get(child_id)
        // Unwrap here, since failure is a bug.
        .expect("Child node does not exist")
        .key;
    let shared_prefix_len = child_key.shared_prefix_len(tokens);
    let child_key_len = child_key.len();

    // We are done, the prefix is already in the trie.
    if shared_prefix_len == tokens.len() {
        return Ok(shared_prefix_len);
    }

    let next_id = if shared_prefix_len == child_key_len {
        // The node's prefix is a prefix of the insertion prefix, continue
        // below the node.
        child_id
    } else {
        // The node's prefix and the insertion prefix only match partially,
        // split the node to just contain the matching part. The remainder
        // is then inserted below the node holding the matching part.
        self.split_node(child_id, shared_prefix_len)
    };

    Ok(shared_prefix_len
        + self.insert_(
            next_id,
            &tokens[shared_prefix_len..],
            &blocks[shared_prefix_len..],
        )?)
}
/// Split a node after its first `prefix_len` tokens.
///
/// A new node holding the first `prefix_len` tokens and blocks is inserted
/// between the node and its former parent; `node_id` keeps the remainder.
/// Returns the identifier of the new node.
fn split_node(&mut self, node_id: NodeId, prefix_len: usize) -> NodeId {
// We have to make the current node a child to ensure that its
// properties and node id stay the same.
// This function unwraps, an invalid node_id is a programming error.
let node = self
.nodes
.get_mut(node_id)
.expect("Node to-be split does not exist");
let mut parent_key = node.key.split_off(prefix_len);
let mut parent_blocks = node.blocks.split_off(prefix_len);
// Move first part of the prefix to the parent. We swap to avoid
// an allocation + copy for both splits of the key/blocks.
std::mem::swap(&mut node.key, &mut parent_key);
std::mem::swap(&mut node.blocks, &mut parent_blocks);
// First token of the remainder, under which the node is registered in the
// new parent.
let node_key = node.key[0];
let grandparent_id = node.parent.expect("Node does not have a parent");
// The new parent starts with the same token as the node did before the
// split, so this replaces the node's entry in the grandparent.
let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks);
self.add_node_to_parent(parent_id, node_key, node_id);
// Reborrow to make the borrow checker happy.
let node = self
.nodes
.get_mut(node_id)
.expect("Node to-be split does not exist");
node.parent = Some(parent_id);
parent_id
}
/// Create a node for `key`/`blocks` and register it as a child of
/// `parent_id`.
///
/// The new node gets the current trie time as its access time and starts
/// out as an eviction candidate. Panics when `key` is empty.
fn add_node(
    &mut self,
    parent_id: NodeId,
    key: impl Into<Vec<u32>>,
    blocks: impl Into<Vec<u32>>,
) -> NodeId {
    let key: Vec<u32> = key.into();
    let blocks: Vec<u32> = blocks.into();
    let first_token = key[0];
    let now = self.time;

    let child_id = self
        .nodes
        .insert(TrieNode::new(key, blocks, now, Some(parent_id)));
    self.add_node_to_parent(parent_id, first_token, child_id);
    self.leaves.insert((now, child_id));

    child_id
}
/// Register `child_id` as the child of `parent_id` under the key entry `first`.
///
/// The parent's reference count is only increased when the slot for `first`
/// was previously empty.
///
/// # Panics
///
/// Panics when `parent_id` is unknown (a programming error) or when
/// increasing the reference count fails.
fn add_node_to_parent(&mut self, parent_id: NodeId, first: u32, child_id: NodeId) {
    let replaced = self
        .nodes
        .get_mut(parent_id)
        .expect("Unknown parent node")
        .children
        .insert(first, child_id);

    if replaced.is_some() {
        // The child took the place of another child, so the number of
        // children of the parent did not change.
        return;
    }

    self.incref(parent_id)
        .expect("Failed to increase parent refcount");
}
/// Remove a node from the trie and return it.
///
/// The node is taken out of the node storage, unregistered from its parent's
/// children (where it is keyed by the first entry of its key) and the
/// parent's reference count is decreased.
///
/// # Panics
///
/// Panics when `node_id` is unknown, when the node is the root (it has no
/// parent), when the parent is unknown, or when decreasing the parent's
/// reference count fails. All of these are programming errors.
fn remove_node(&mut self, node_id: NodeId) -> TrieNode {
    // Unwrap here, passing in an unknown id is a programming error.
    // This is the only removal that is needed: the node is owned by us
    // from here on, so removing the same id a second time would be a no-op.
    let node = self.nodes.remove(node_id).expect("Unknown node");
    let parent_id = node.parent.expect("Attempted to remove root node");
    let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
    parent.children.remove(&node.key[0]);
    self.decref(parent_id)
        .expect("Failed to decrease parent refcount");
    node
}
/// Stamp the node `node_id` with the trie's current time.
///
/// When the node is present in the ordered set of leaves, its entry is
/// re-inserted under the new time so that the ordering stays correct.
///
/// # Panics
///
/// Panics when `node_id` is unknown, which is a programming error.
fn update_access_time(&mut self, node_id: NodeId) {
    let node = self.nodes.get_mut(node_id).expect("Unknown node");
    let previous = node.last_accessed;
    node.last_accessed = self.time;

    // Leaves are ordered by (time, id), so the entry has to be looked up
    // with the time the node carried before this update.
    let was_leaf = self.leaves.remove(&(previous, node_id));
    if was_leaf {
        self.leaves.insert((self.time, node_id));
    }
}
#[allow(dead_code)]
#[doc(hidden)]
/// Dump the trie to standard error for debugging.
///
/// Unlike the `Debug` representation, every node is printed on its own line
/// and indented according to its depth, starting from the root.
pub fn print_debug(&self) {
    let root = self.root;
    self.print_debug_(root, 0);
}
fn print_debug_(&self, node_id: NodeId, indent: usize) {
let node = &self.nodes[node_id];
eprintln!(
"{}{:?}, key: {:?}, blocks: {:?}, ref_count: {}, last_accessed: {}, parent: {:?}, children: {:?}",
" ".repeat(indent),
node_id,
node.key,
node.blocks,
node.ref_count,
node.last_accessed,
node.parent,
node.children
);
for child_id in self.nodes[node_id].children.values() {
self.print_debug_(*child_id, indent + 2);
}
}
/// Return the id of the root node of the trie.
// NOTE(review): the other methods spell node ids as `NodeId`; presumably that
// is an alias of `DefaultKey` -- confirm and consider using it here as well.
pub(crate) fn root_id(&self) -> DefaultKey {
    self.root
}
}
/// Trie node.
///
/// A node stores a segment of a key together with the blocks that belong to
/// that segment. `key` and `blocks` are split at the same index when a node
/// is split.
#[derive(Debug)]
struct TrieNode {
    /// Blocks that belong to `key`.
    blocks: Vec<u32>,
    /// Child nodes, looked up by the first entry of the child's key.
    children: HashMap<u32, NodeId>,
    /// Segment of the key that is stored in this node.
    key: Vec<u32>,
    /// Value of the trie's time counter when the node was created or when
    /// its access time was last updated.
    last_accessed: u64,
    /// Parent of the node; removing a node without a parent is treated as
    /// an attempt to remove the root.
    parent: Option<NodeId>,
    /// Reference count, starts at zero for a new node.
    // NOTE(review): maintained by `incref`/`decref`, which are not shown
    // here -- presumably counts children plus active users; confirm.
    ref_count: usize,
}
impl TrieNode {
fn new(key: Vec<u32>, blocks: Vec<u32>, last_accessed: u64, parent: Option<NodeId>) -> Self {
TrieNode {
children: HashMap::new(),
key,
blocks,
last_accessed,
parent,
ref_count: 0,
}
}
}
/// Helper trait to compute how many leading elements two sequences share.
trait SharedPrefixLen {
    /// Return the number of elements at the start of `self` and `other`
    /// that compare equal, stopping at the first mismatch or at the end of
    /// the shorter sequence.
    fn shared_prefix_len(&self, other: &Self) -> usize;
}

impl<T: PartialEq> SharedPrefixLen for [T] {
    fn shared_prefix_len(&self, other: &Self) -> usize {
        let mut shared = 0;
        for (ours, theirs) in self.iter().zip(other) {
            if ours != theirs {
                break;
            }
            shared += 1;
        }
        shared
    }
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use crate::block_allocator::Allocator;
use super::RadixAllocator;
// An allocation that is freed leaves its prefix in the cache, so that a
// second allocation with the same prefix gets the same blocks back and
// reports the prefix as cached.
#[test]
fn allocator_reuses_prefixes() {
    let mut cache = RadixAllocator::new(1, 12, None);
    let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
    assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
    // Compare the slots against the blocks; comparing the slots with
    // themselves can never fail.
    // NOTE(review): assumes that slots and blocks coincide for the block
    // size used by this test -- confirm against `allocate`.
    assert_eq!(allocation.blocks, allocation.slots);
    assert_eq!(allocation.prefix_len, 0);
    cache.free(allocation.blocks.clone(), allocation.allocation_id);

    // Same prefix again: the same blocks are handed out and the four prefix
    // tokens are reported as cached.
    let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
    assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
    assert_eq!(allocation.prefix_len, 4);
}
// When the cache has to make room for a new allocation, the blocks of the
// prefix that was allocated first are reclaimed before those of the prefix
// that was allocated later.
#[test]
fn allocator_collects_older_prefixes_first() {
    let mut cache = RadixAllocator::new(1, 7, None);
    let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
    assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]);
    assert_eq!(allocation1.prefix_len, 0);

    let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap();
    assert_eq!(allocation2.blocks, vec![1, 2]);
    assert_eq!(allocation2.prefix_len, 0);

    // After freeing, both prefixes stay cached and all blocks are in use.
    cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
    cache.free(allocation2.blocks.clone(), allocation2.allocation_id);

    // We should get the blocks of the first allocation, since they are older
    // than the blocks of the second allocation.
    let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap();
    assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]);
    assert_eq!(allocation3.prefix_len, 0);
}
// Two allocations with an identical prefix are freed in reverse order; the
// prefix must afterwards be cached exactly once.
#[test]
fn allocator_frees_fully_overlapping_prefills() {
    let mut cache = RadixAllocator::new(1, 10, None);
    let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
    let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
    cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
    cache.free(allocation1.blocks.clone(), allocation1.allocation_id);

    // A third allocation with the same prefix is served fully from the cache.
    let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
    assert_eq!(allocation3.prefix_len, 4);

    // 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks.
    assert_eq!(cache.free_blocks.len(), 5);
}
// Allocations whose prefixes overlap only partially share the blocks of the
// common part; freeing them keeps every distinct prefix cached once.
#[test]
fn allocator_frees_partially_overlapping_prefills() {
    let mut cache = RadixAllocator::new(1, 20, None);

    // Seed the cache with the short prefix [0, 1].
    let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap();
    assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]);
    assert_eq!(allocation1.prefix_len, 0);

    cache.free(allocation1.blocks.clone(), allocation1.allocation_id);

    // Both longer prefixes reuse the two cached blocks of [0, 1].
    let allocation2 = cache
        .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
        .unwrap();
    assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]);
    assert_eq!(allocation2.prefix_len, 2);

    let allocation3 = cache
        .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
        .unwrap();
    assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]);
    assert_eq!(allocation3.prefix_len, 2);

    cache.free(allocation3.blocks.clone(), allocation3.allocation_id);
    cache.free(allocation2.blocks.clone(), allocation2.allocation_id);

    // 20 blocks, of which 1 reserved for health checks, 6 for allocation3, 2 for allocation2.
    assert_eq!(cache.free_blocks.len(), 11);

    // Both six-token prefixes are now fully cached, so allocating exactly
    // six blocks for either of them does not consume any free block.
    let allocation4 = cache
        .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
        .unwrap();
    assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]);
    assert_eq!(allocation4.prefix_len, 6);
    assert_eq!(cache.free_blocks.len(), 11);

    let allocation5 = cache
        .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
        .unwrap();
    assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]);
    assert_eq!(allocation5.prefix_len, 6);
    assert_eq!(cache.free_blocks.len(), 11);
}
// `insert` returns the length of the part of the inserted key that was
// already present in the trie.
#[test]
fn trie_insertions_have_correct_prefix_len() {
    let mut trie = super::RadixTrie::new();

    // Nothing is present in an empty trie.
    assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0);

    // Already exists.
    assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 3);

    // Completely new at root-level
    assert_eq!(trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(), 0);

    // Contains full prefix, but longer.
    assert_eq!(trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(), 3);

    // Shares partial prefix, we need a split.
    assert_eq!(
        trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
            .unwrap(),
        4
    );
}
// `find` appends the blocks of the matching part of the key to the given
// vector, for full keys as well as for keys that end inside a node.
#[test]
fn trie_get_returns_correct_blocks() {
    let mut trie = super::RadixTrie::new();
    trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
    trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
    trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
    trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
        .unwrap();

    // The vector is cleared between lookups, since `find` only appends.
    let mut blocks = Vec::new();
    trie.find(&[0], &mut blocks);
    assert_eq!(blocks, vec![0]);

    blocks.clear();
    trie.find(&[0, 1, 2], &mut blocks);
    assert_eq!(blocks, vec![0, 1, 2]);

    blocks.clear();
    trie.find(&[1, 2, 3], &mut blocks);
    assert_eq!(blocks, vec![1, 2, 3]);

    blocks.clear();
    trie.find(&[0, 1, 2, 3], &mut blocks);
    assert_eq!(blocks, vec![0, 1, 2, 3]);

    blocks.clear();
    trie.find(&[0, 1, 2, 3, 4], &mut blocks);
    assert_eq!(blocks, vec![0, 1, 2, 3, 4]);

    blocks.clear();
    trie.find(&[0, 1, 2, 3, 5], &mut blocks);
    assert_eq!(blocks, vec![0, 1, 2, 3, 5]);
}
// `evict` returns the requested number of blocks, taking them from the
// leaves of the trie; lookups in between refresh the leaves they touch and
// thereby influence which leaf is evicted next.
#[test]
fn trie_evict_removes_correct_blocks() {
    let mut trie = super::RadixTrie::new();
    trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
    trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
        .unwrap();
    trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
    trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();

    let mut blocks = Vec::new();

    // Remove less than the leaf blocks.
    assert_eq!(trie.evict(1), vec![7]);
    trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
    assert_eq!(blocks, vec![0, 1, 2, 3, 5, 6]);

    // Refresh other leaf. The blocks collected here are not inspected, the
    // vector is cleared before the next assertion.
    trie.find(&[0, 1, 2, 3, 4], &mut blocks);
    trie.find(&[1, 2, 3], &mut blocks);

    // Remove the leaf blocks exactly.
    assert_eq!(trie.evict(2), vec![5, 6]);
    blocks.clear();
    trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
    assert_eq!(blocks, vec![0, 1, 2, 3]);

    trie.find(&[1, 2, 3], &mut blocks);

    // Remove more than the leaf blocks.
    assert_eq!(trie.evict(3), vec![4, 3, 2]);
    blocks.clear();
    trie.find(&[0, 1, 2, 3, 4], &mut blocks);
    assert_eq!(blocks, vec![0, 1]);

    // Clear out the whole trie.
    assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]);
}
}

View File

@ -157,6 +157,7 @@ async fn prefill(
top_n_tokens: top_n_tokens.unwrap_or(0), top_n_tokens: top_n_tokens.unwrap_or(0),
blocks: vec![], blocks: vec![],
slots: vec![], slots: vec![],
prefix_len: 0,
adapter_id: None, adapter_id: None,
}) })
.collect(); .collect();

View File

@ -3,22 +3,23 @@ syntax = "proto3";
package generate.v3; package generate.v3;
service TextGenerationService { service TextGenerationService {
/// Model Info /// Model Info
rpc Info (InfoRequest) returns (InfoResponse) {} rpc Info(InfoRequest) returns (InfoResponse) {}
/// Service discovery /// Service discovery
rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {} rpc ServiceDiscovery(ServiceDiscoveryRequest)
/// Empties batch cache returns (ServiceDiscoveryResponse) {}
rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse); /// Empties batch cache
/// Remove requests from a cached batch rpc ClearCache(ClearCacheRequest) returns (ClearCacheResponse);
rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse); /// Remove requests from a cached batch
/// Warmup the model and compute max cache size rpc FilterBatch(FilterBatchRequest) returns (FilterBatchResponse);
rpc Warmup (WarmupRequest) returns (WarmupResponse); /// Warmup the model and compute max cache size
/// Prefill batch and decode first token rpc Warmup(WarmupRequest) returns (WarmupResponse);
rpc Prefill (PrefillRequest) returns (PrefillResponse); /// Prefill batch and decode first token
/// Decode token for a list of prefilled batches rpc Prefill(PrefillRequest) returns (PrefillResponse);
rpc Decode (DecodeRequest) returns (DecodeResponse); /// Decode token for a list of prefilled batches
/// Health check rpc Decode(DecodeRequest) returns (DecodeResponse);
rpc Health (HealthRequest) returns (HealthResponse); /// Health check
rpc Health(HealthRequest) returns (HealthResponse);
} }
message HealthRequest {} message HealthRequest {}
@ -28,240 +29,239 @@ message HealthResponse {}
message InfoRequest {} message InfoRequest {}
message InfoResponse { message InfoResponse {
bool requires_padding = 1; bool requires_padding = 1;
string dtype = 2; string dtype = 2;
string device_type = 3; string device_type = 3;
optional uint32 window_size = 4; optional uint32 window_size = 4;
uint32 speculate = 5; uint32 speculate = 5;
} }
/// Empty request /// Empty request
message ServiceDiscoveryRequest {} message ServiceDiscoveryRequest {}
message ServiceDiscoveryResponse { message ServiceDiscoveryResponse {
/// Other shards urls /// Other shards urls
repeated string urls = 1; repeated string urls = 1;
} }
message ClearCacheRequest { message ClearCacheRequest {
/// Optional batch id /// Optional batch id
optional uint64 id = 1; optional uint64 id = 1;
} }
/// Empty response /// Empty response
message ClearCacheResponse {} message ClearCacheResponse {}
message Image { message Image {
/// Binary image data. /// Binary image data.
bytes data = 1; bytes data = 1;
/// Image MIME type. /// Image MIME type.
string mimetype = 2; string mimetype = 2;
} }
message InputChunk { message InputChunk {
oneof chunk { oneof chunk {
/// Plain text data /// Plain text data
string text = 1; string text = 1;
/// Image data /// Image data
Image image = 2; Image image = 2;
} }
} }
message Input { message Input { repeated InputChunk chunks = 1; }
repeated InputChunk chunks = 1;
}
enum GrammarType { enum GrammarType {
GRAMMAR_TYPE_NONE = 0; GRAMMAR_TYPE_NONE = 0;
GRAMMAR_TYPE_JSON = 1; GRAMMAR_TYPE_JSON = 1;
GRAMMAR_TYPE_REGEX = 2; GRAMMAR_TYPE_REGEX = 2;
} }
message NextTokenChooserParameters { message NextTokenChooserParameters {
/// exponential scaling output probability distribution /// exponential scaling output probability distribution
float temperature = 1; float temperature = 1;
/// restricting to the k highest probability elements /// restricting to the k highest probability elements
uint32 top_k = 2; uint32 top_k = 2;
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off /// restricting to top tokens summing to prob_cut_off <= prob_cut_off
float top_p = 3; float top_p = 3;
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off /// restricting to top tokens summing to prob_cut_off <= prob_cut_off
float typical_p = 4; float typical_p = 4;
/// apply sampling on the logits /// apply sampling on the logits
bool do_sample = 5; bool do_sample = 5;
/// random seed for sampling /// random seed for sampling
uint64 seed = 6; uint64 seed = 6;
/// repetition penalty /// repetition penalty
float repetition_penalty = 7; float repetition_penalty = 7;
/// frequency penalty /// frequency penalty
float frequency_penalty = 9; float frequency_penalty = 9;
/// token watermarking using "A Watermark for Large Language Models" /// token watermarking using "A Watermark for Large Language Models"
bool watermark = 8; bool watermark = 8;
/// grammar (applied if not empty) /// grammar (applied if not empty)
string grammar = 10; string grammar = 10;
/// grammar type /// grammar type
GrammarType grammar_type = 11; GrammarType grammar_type = 11;
} }
message StoppingCriteriaParameters { message StoppingCriteriaParameters {
/// Maximum number of generated tokens /// Maximum number of generated tokens
uint32 max_new_tokens = 1; uint32 max_new_tokens = 1;
/// Optional stopping sequences /// Optional stopping sequences
repeated string stop_sequences = 2; repeated string stop_sequences = 2;
/// Ignore end of sequence token /// Ignore end of sequence token
/// used for benchmarking /// used for benchmarking
bool ignore_eos_token = 3; bool ignore_eos_token = 3;
} }
message Request { message Request {
/// Request ID /// Request ID
uint64 id = 1; uint64 id = 1;
/// The generation context as chunks /// The generation context as chunks
Input input_chunks = 8; Input input_chunks = 8;
/// The generation context, stringified input_chunks /// The generation context, stringified input_chunks
string inputs = 2; string inputs = 2;
/// Context truncation /// Context truncation
uint32 truncate = 3; uint32 truncate = 3;
/// Next Token Chooser Parameters /// Next Token Chooser Parameters
NextTokenChooserParameters parameters = 4; NextTokenChooserParameters parameters = 4;
/// Stopping Criteria Parameters /// Stopping Criteria Parameters
StoppingCriteriaParameters stopping_parameters = 5; StoppingCriteriaParameters stopping_parameters = 5;
/// Return prefill logprobs /// Return prefill logprobs
bool prefill_logprobs = 6; bool prefill_logprobs = 6;
/// Return most likely n tokens /// Return most likely n tokens
uint32 top_n_tokens = 7; uint32 top_n_tokens = 7;
/// Paged attention blocks /// Paged attention blocks
repeated uint32 blocks = 9; repeated uint32 blocks = 9;
/// Paged attention slots /// Paged attention slots
repeated uint32 slots = 10; repeated uint32 slots = 10;
/// LORA adapter index /// LORA adapter index
optional string adapter_id = 11; optional string adapter_id = 11;
/// Prefix length that can be retrieved from the KV cache.
uint32 prefix_len = 12;
} }
message Batch { message Batch {
/// Batch ID /// Batch ID
uint64 id = 1; uint64 id = 1;
/// Individual requests /// Individual requests
repeated Request requests = 2; repeated Request requests = 2;
/// Batch size (==len(requests)) /// Batch size (==len(requests))
uint32 size = 3; uint32 size = 3;
/// Maximum number of tokens this batch will grow to /// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4; uint32 max_tokens = 4;
/// Maximum number of Paged Attention blocks /// Maximum number of Paged Attention blocks
uint32 max_blocks = 5; uint32 max_blocks = 5;
} }
message CachedBatch { message CachedBatch {
/// Batch ID /// Batch ID
uint64 id = 1; uint64 id = 1;
/// Individual requests ids /// Individual requests ids
repeated uint64 request_ids = 2; repeated uint64 request_ids = 2;
/// Batch size (==len(requests)) /// Batch size (==len(requests))
uint32 size = 3; uint32 size = 3;
/// Maximum number of tokens this batch will grow to /// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4; uint32 max_tokens = 4;
} }
enum FinishReason { enum FinishReason {
FINISH_REASON_LENGTH = 0; FINISH_REASON_LENGTH = 0;
FINISH_REASON_EOS_TOKEN = 1; FINISH_REASON_EOS_TOKEN = 1;
FINISH_REASON_STOP_SEQUENCE = 2; FINISH_REASON_STOP_SEQUENCE = 2;
} }
message GeneratedText { message GeneratedText {
/// Output /// Output
string text = 1; string text = 1;
/// Number of generated tokens /// Number of generated tokens
uint32 generated_tokens = 2; uint32 generated_tokens = 2;
/// Finish reason /// Finish reason
FinishReason finish_reason = 3; FinishReason finish_reason = 3;
/// Seed /// Seed
optional uint64 seed = 4; optional uint64 seed = 4;
} }
message Tokens { message Tokens {
/// Token IDs /// Token IDs
repeated uint32 ids = 1; repeated uint32 ids = 1;
/// Logprobs /// Logprobs
repeated float logprobs = 2; repeated float logprobs = 2;
/// tokens /// tokens
repeated string texts = 3; repeated string texts = 3;
/// special /// special
repeated bool is_special = 4; repeated bool is_special = 4;
} }
message Generation { message Generation {
/// Request ID /// Request ID
uint64 request_id = 1; uint64 request_id = 1;
/// Prefill tokens (optional) /// Prefill tokens (optional)
Tokens prefill_tokens = 2; Tokens prefill_tokens = 2;
Tokens tokens = 3; Tokens tokens = 3;
/// Complete generated text /// Complete generated text
optional GeneratedText generated_text = 4; optional GeneratedText generated_text = 4;
/// Top tokens /// Top tokens
repeated Tokens top_tokens = 5; repeated Tokens top_tokens = 5;
} }
message FilterBatchRequest { message FilterBatchRequest {
/// Batch ID /// Batch ID
uint64 batch_id = 1; uint64 batch_id = 1;
/// Requests to keep /// Requests to keep
repeated uint64 request_ids = 2; repeated uint64 request_ids = 2;
} }
message FilterBatchResponse { message FilterBatchResponse {
/// Filtered Batch (cached) /// Filtered Batch (cached)
CachedBatch batch = 1; CachedBatch batch = 1;
} }
message PrefillRequest { message PrefillRequest {
/// Batch /// Batch
Batch batch = 1; Batch batch = 1;
} }
message PrefillResponse { message PrefillResponse {
/// Generation /// Generation
repeated Generation generations = 1; repeated Generation generations = 1;
/// Next batch (cached) /// Next batch (cached)
optional CachedBatch batch = 2; optional CachedBatch batch = 2;
/// Forward elapsed time in nanoseconds /// Forward elapsed time in nanoseconds
uint64 forward_ns = 3; uint64 forward_ns = 3;
/// Decode elapsed time in nanoseconds /// Decode elapsed time in nanoseconds
uint64 decode_ns = 4; uint64 decode_ns = 4;
/// Total elapsed time in nanoseconds /// Total elapsed time in nanoseconds
uint64 total_ns = 5; uint64 total_ns = 5;
} }
message DecodeRequest { message DecodeRequest {
/// Cached batches /// Cached batches
repeated CachedBatch batches = 1; repeated CachedBatch batches = 1;
} }
message DecodeResponse { message DecodeResponse {
/// Decodes /// Decodes
repeated Generation generations = 1; repeated Generation generations = 1;
/// Next batch (cached) /// Next batch (cached)
optional CachedBatch batch = 2; optional CachedBatch batch = 2;
/// Forward elapsed time in nanoseconds /// Forward elapsed time in nanoseconds
uint64 forward_ns = 3; uint64 forward_ns = 3;
/// Decode elapsed time in nanoseconds /// Decode elapsed time in nanoseconds
uint64 decode_ns = 4; uint64 decode_ns = 4;
/// Total elapsed time in nanoseconds /// Total elapsed time in nanoseconds
uint64 total_ns = 5; uint64 total_ns = 5;
/// Concatenate elapsed time in nanoseconds /// Concatenate elapsed time in nanoseconds
optional uint64 concat_ns = 6; optional uint64 concat_ns = 6;
} }
message WarmupRequest { message WarmupRequest {
/// Batch to warmup on /// Batch to warmup on
Batch batch = 1; Batch batch = 1;
uint32 max_input_length = 2; uint32 max_input_length = 2;
uint32 max_prefill_tokens = 3; uint32 max_prefill_tokens = 3;
uint32 max_total_tokens = 4; uint32 max_total_tokens = 4;
} }
message WarmupResponse { message WarmupResponse {
/// Maximum number of tokens supported by the model /// Maximum number of tokens supported by the model
optional uint32 max_supported_total_tokens = 1; optional uint32 max_supported_total_tokens = 1;
} }

View File

@ -11,6 +11,7 @@ use rand::{thread_rng, Rng};
use serde_json::Value; use serde_json::Value;
use std::io::Cursor; use std::io::Cursor;
use std::iter; use std::iter;
use std::sync::Arc;
use thiserror::Error; use thiserror::Error;
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
use tokio::sync::mpsc; use tokio::sync::mpsc;
@ -115,13 +116,14 @@ impl Validation {
} }
} }
#[allow(clippy::type_complexity)]
#[instrument(skip(self, inputs))] #[instrument(skip(self, inputs))]
async fn validate_input( async fn validate_input(
&self, &self,
inputs: String, inputs: String,
truncate: Option<usize>, truncate: Option<usize>,
max_new_tokens: Option<u32>, max_new_tokens: Option<u32>,
) -> Result<(Vec<Chunk>, usize, u32), ValidationError> { ) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
// If we have a fast tokenizer // If we have a fast tokenizer
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? { if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
// Create response channel // Create response channel
@ -156,8 +158,10 @@ impl Validation {
)); ));
} }
let input_ids = encoding.get_ids()[..input_length].to_owned();
metrics::histogram!("tgi_request_input_length").record(input_length as f64); metrics::histogram!("tgi_request_input_length").record(input_length as f64);
Ok((inputs, input_length, max_new_tokens)) Ok((inputs, Some(input_ids), input_length, max_new_tokens))
} }
// Return inputs without validation // Return inputs without validation
else { else {
@ -180,7 +184,12 @@ impl Validation {
input_length = input_length.saturating_sub(max_new_tokens as usize); input_length = input_length.saturating_sub(max_new_tokens as usize);
} }
Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens)) Ok((
vec![Chunk::Text(inputs)],
None,
input_length,
max_new_tokens,
))
} }
} }
@ -314,7 +323,7 @@ impl Validation {
.unwrap_or(Ok(None))?; .unwrap_or(Ok(None))?;
// Validate inputs // Validate inputs
let (inputs, input_length, max_new_tokens) = self let (inputs, input_ids, input_length, max_new_tokens) = self
.validate_input(request.inputs, truncate, max_new_tokens) .validate_input(request.inputs, truncate, max_new_tokens)
.await?; .await?;
@ -391,6 +400,7 @@ impl Validation {
Ok(ValidGenerateRequest { Ok(ValidGenerateRequest {
inputs, inputs,
input_ids: input_ids.map(Arc::new),
decoder_input_details, decoder_input_details,
input_length: input_length as u32, input_length: input_length as u32,
truncate: truncate.unwrap_or(self.max_input_length) as u32, truncate: truncate.unwrap_or(self.max_input_length) as u32,
@ -707,6 +717,7 @@ pub struct ValidStoppingParameters {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ValidGenerateRequest { pub struct ValidGenerateRequest {
pub inputs: Vec<Chunk>, pub inputs: Vec<Chunk>,
pub input_ids: Option<Arc<Vec<u32>>>,
pub input_length: u32, pub input_length: u32,
pub truncate: u32, pub truncate: u32,
pub decoder_input_details: bool, pub decoder_input_details: bool,

View File

@ -6,7 +6,12 @@ from .common import Seqlen
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.") raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda": if SYSTEM == "cuda":
from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING from .cuda import (
attention,
paged_attention,
reshape_and_cache,
SUPPORTS_WINDOWING,
)
elif SYSTEM == "rocm": elif SYSTEM == "rocm":
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "ipex": elif SYSTEM == "ipex":

View File

@ -221,9 +221,11 @@ SUPPORTS_WINDOWING = V2
if ATTENTION == "flashinfer": if ATTENTION == "flashinfer":
def attention( def attention(
q, q: torch.Tensor,
k, k: torch.Tensor,
v, v: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cu_seqlens, cu_seqlens,
max_s, max_s,
softmax_scale, softmax_scale,
@ -231,14 +233,15 @@ if ATTENTION == "flashinfer":
causal=True, causal=True,
softcap=0.0, softcap=0.0,
): ):
from text_generation_server.layers.attention.flash_infer import prefill_state assert window_size_left == -1, "Windowing is not supported with flash infer"
from text_generation_server.layers.attention.flash_infer import (
prefill_with_paged_kv_state,
)
return prefill_state.get().forward( return prefill_with_paged_kv_state.get().forward(
q, q.contiguous(),
k,
v,
causal=causal, causal=causal,
window_left=window_size_left, paged_kv_cache=(key_cache, value_cache),
logits_soft_cap=softcap, logits_soft_cap=softcap,
sm_scale=softmax_scale, sm_scale=softmax_scale,
) )
@ -249,6 +252,8 @@ elif V2:
q, q,
k, k,
v, v,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cu_seqlens, cu_seqlens,
max_s, max_s,
softmax_scale, softmax_scale,
@ -289,6 +294,8 @@ else:
q, q,
k, k,
v, v,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cu_seqlens, cu_seqlens,
max_s, max_s,
softmax_scale, softmax_scale,

View File

@ -9,6 +9,10 @@ prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = Con
"prefill_state" "prefill_state"
) )
prefill_with_paged_kv_state: ContextVar[
flashinfer.BatchPrefillWithPagedKVCacheWrapper
] = ContextVar("prefill_with_paged_kv_state")
decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar( decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
"decode_state" "decode_state"
) )
@ -24,6 +28,78 @@ def get_workspace(device):
return workspace return workspace
def create_prefill_with_paged_kv_state(
    *,
    device: torch.device,
):
    """
    Build a flashinfer prefill wrapper that reads keys and values from the
    paged KV cache.

    The wrapper uses the workspace buffer returned by `get_workspace` for
    `device`, the "NHD" KV layout and has CUDA graph support turned off.
    """
    return flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        get_workspace(device), kv_layout="NHD", use_cuda_graph=False
    )
@contextmanager
def use_prefill_with_paged_kv_state(
    *,
    state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
    block_tables: torch.Tensor,
    cu_seqlens: torch.Tensor,
    input_lengths: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    page_size: int,
    query_dtype: str = "float16",
):
    """
    Activate `state` as the flashinfer paged-KV prefill state for the
    duration of the `with` block.

    The page index pointers and the length of each sequence's last page are
    derived from `input_lengths` and `page_size` and handed to
    `state.begin_forward` together with the remaining arguments. On exit,
    `state.end_forward` is called and the previously active state is
    restored, also when the body raises.
    """
    batch_size = input_lengths.shape[0]
    device = input_lengths.device

    # indptr[0] stays zero; the tail is filled in place with the cumulative
    # number of pages: ceil(length / page_size) per sequence, then summed up.
    indptr = torch.zeros(batch_size + 1, device=device, dtype=torch.int32)
    page_counts = indptr[1:]
    torch.add(input_lengths, page_size - 1, out=page_counts)
    page_counts.div_(page_size, rounding_mode="floor")
    page_counts.cumsum_(-1)

    # Number of entries used in the last page of every sequence.
    if page_size == 1:
        # Every page holds exactly one entry.
        last_page_len = torch.ones(batch_size, dtype=torch.int32, device=device)
    else:
        # ((length - 1) mod page_size) + 1, computed in place.
        last_page_len = torch.empty(batch_size, dtype=torch.int32, device=device)
        torch.sub(input_lengths, 1, out=last_page_len)
        last_page_len.remainder_(page_size)
        last_page_len += 1

    token = prefill_with_paged_kv_state.set(state)
    try:
        state.begin_forward(
            qo_indptr=cu_seqlens,
            paged_kv_indptr=indptr,
            paged_kv_indices=block_tables,
            paged_kv_last_page_len=last_page_len,
            num_qo_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            q_data_type=query_dtype,
            page_size=page_size,
        )
        yield
    finally:
        state.end_forward()
        if token is not None:
            prefill_with_paged_kv_state.reset(token)
def create_prefill_state( def create_prefill_state(
*, *,
device: torch.device, device: torch.device,

View File

@ -298,6 +298,8 @@ class FlashCohereAttention(torch.nn.Module):
query, query,
key, key,
value, value,
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,

View File

@ -337,6 +337,8 @@ class DbrxAttention(torch.nn.Module):
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,

View File

@ -365,6 +365,8 @@ class DeepseekV2Attention(torch.nn.Module):
query, query,
key, key,
value, value,
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,

View File

@ -238,6 +238,8 @@ class FlashGemma2Attention(torch.nn.Module):
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,

View File

@ -232,6 +232,8 @@ class FlashGemmaAttention(torch.nn.Module):
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,

View File

@ -232,6 +232,8 @@ class FlashGPT2Attention(torch.nn.Module):
query, query,
key, key,
value, value,
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,

View File

@ -21,6 +21,7 @@
from contextlib import contextmanager from contextlib import contextmanager
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from loguru import logger
import torch import torch
import torch.distributed import torch.distributed
@ -220,6 +221,8 @@ class FlashLlamaAttention(torch.nn.Module):
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,

View File

@ -219,6 +219,8 @@ class MistralAttention(torch.nn.Module):
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,

View File

@ -276,6 +276,8 @@ class MixtralAttention(torch.nn.Module):
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,

View File

@ -173,6 +173,8 @@ class FlashNeoxAttention(torch.nn.Module):
qkv[:, 0], qkv[:, 0],
qkv[:, 1], qkv[:, 1],
qkv[:, 2], qkv[:, 2],
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,

View File

@ -194,6 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,

View File

@ -137,6 +137,8 @@ class Qwen2Attention(torch.nn.Module):
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,

View File

@ -208,6 +208,8 @@ class FlashRWAttention(torch.nn.Module):
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -326,6 +328,8 @@ class FlashRWLargeAttention(torch.nn.Module):
query, query,
torch.select(kv, dim=2, index=0), torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1), torch.select(kv, dim=2, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,

View File

@ -293,6 +293,8 @@ class FlashMQAttention(torch.nn.Module):
query, query,
torch.select(key_value, dim=1, index=0), torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1), torch.select(key_value, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,

View File

@ -242,6 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,

View File

@ -20,6 +20,9 @@ from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, D
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from text_generation_server.layers.attention.flash_infer import (
create_prefill_with_paged_kv_state,
)
from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models import Model from text_generation_server.models import Model
@ -43,6 +46,7 @@ from text_generation_server.models.globals import (
ATTENTION, ATTENTION,
BLOCK_SIZE, BLOCK_SIZE,
CUDA_GRAPHS, CUDA_GRAPHS,
PREFIX_CACHING,
get_adapter_to_index, get_adapter_to_index,
) )
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
@ -108,6 +112,9 @@ class FlashCausalLMBatch(Batch):
block_tables_tensor: torch.Tensor block_tables_tensor: torch.Tensor
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
slots: torch.Tensor slots: torch.Tensor
# size [b], containing the number of blocks that can be retrieved from the cache
prefix_lens: List[int]
prefix_lens_tensor: torch.Tensor
max_seqlen: int max_seqlen: int
@ -116,6 +123,9 @@ class FlashCausalLMBatch(Batch):
prefill_next_token_indices: Optional[torch.tensor] prefill_next_token_indices: Optional[torch.tensor]
prefill_cu_outlens: Optional[List[int]] prefill_cu_outlens: Optional[List[int]]
# Prefixes
prefix_ids: List[List[int]]
# All tokens # All tokens
all_input_ids: List[List[int]] all_input_ids: List[List[int]]
all_input_ids_tensor: torch.Tensor all_input_ids_tensor: torch.Tensor
@ -183,6 +193,7 @@ class FlashCausalLMBatch(Batch):
prefix_offsets = [] prefix_offsets = []
read_offsets = [] read_offsets = []
all_input_ids = [] all_input_ids = []
prefix_ids = []
requests_idx_mapping = {} requests_idx_mapping = {}
all_prefill_logprobs = True all_prefill_logprobs = True
@ -200,7 +211,7 @@ class FlashCausalLMBatch(Batch):
# Cumulative length # Cumulative length
cumulative_length = 0 cumulative_length = 0
cumulative_max_length = 0 cumulative_slot_tokens = 0
prefill_out_cumulative_length = 0 prefill_out_cumulative_length = 0
num_blocks = 0 num_blocks = 0
@ -210,6 +221,7 @@ class FlashCausalLMBatch(Batch):
block_tables = [] block_tables = []
slots = [] slots = []
prefix_lens = []
# Parse batch # Parse batch
for i, (r, tokenized_input) in enumerate( for i, (r, tokenized_input) in enumerate(
@ -225,6 +237,19 @@ class FlashCausalLMBatch(Batch):
): ):
tokenized_input = tokenized_input[1:] tokenized_input = tokenized_input[1:]
orig_input_length = len(tokenized_input)
if PREFIX_CACHING:
prefix_len = r.prefix_len
if prefix_len == orig_input_length:
assert prefix_len > 0
prefix_len -= 1
else:
prefix_len = 0
prefix_ids.append(tokenized_input[:prefix_len])
tokenized_input = tokenized_input[prefix_len:]
input_length = len(tokenized_input) input_length = len(tokenized_input)
input_lengths.append(input_length) input_lengths.append(input_length)
@ -234,7 +259,9 @@ class FlashCausalLMBatch(Batch):
all_input_ids.append(tokenized_input) all_input_ids.append(tokenized_input)
# Position ids # Position ids
request_position_ids = torch.arange(0, input_length, dtype=torch.int32) request_position_ids = torch.arange(
prefix_len, orig_input_length, dtype=torch.int32
)
position_ids.append(request_position_ids) position_ids.append(request_position_ids)
# Add cumulative lengths of all previous inputs # Add cumulative lengths of all previous inputs
@ -258,11 +285,17 @@ class FlashCausalLMBatch(Batch):
# Remove one as the first token does not have a past # Remove one as the first token does not have a past
speculative_length = get_speculate() speculative_length = get_speculate()
speculative_length = 0 if speculative_length is None else speculative_length speculative_length = 0 if speculative_length is None else speculative_length
total_tokens = input_length + max_new_tokens - 1 + speculative_length
# Tokens that need to be mapped to blocks.
block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length
# Tokens that need to be mapped to slots. We don't need slots for the
# cached prefix (if present).
slot_tokens = input_length + max_new_tokens - 1 + speculative_length
# blocks and slots can be empty (for example in warmup) # blocks and slots can be empty (for example in warmup)
if not r.blocks: if not r.blocks:
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)
request_blocks = [ request_blocks = [
b for b in range(num_blocks, num_blocks + needed_blocks) b for b in range(num_blocks, num_blocks + needed_blocks)
] ]
@ -273,16 +306,20 @@ class FlashCausalLMBatch(Batch):
] ]
else: else:
request_blocks = r.blocks request_blocks = r.blocks
request_slots = r.slots request_slots = r.slots[
prefix_len: #: orig_input_length + max_new_tokens + speculative_length
]
block_tables.append(request_blocks) block_tables.append(request_blocks)
slots.extend(request_slots[:total_tokens])
slots.extend(request_slots)
prefix_lens.append(prefix_len)
num_blocks += len(request_blocks) num_blocks += len(request_blocks)
start_slots.append(cumulative_max_length) start_slots.append(cumulative_slot_tokens)
request_slot_indices = torch.arange( request_slot_indices = torch.arange(
cumulative_max_length, cumulative_slot_tokens,
cumulative_max_length + input_length, cumulative_slot_tokens + input_length,
dtype=torch.int64, dtype=torch.int64,
) )
slot_indices.append(request_slot_indices) slot_indices.append(request_slot_indices)
@ -318,7 +355,7 @@ class FlashCausalLMBatch(Batch):
# Update # Update
cumulative_length += input_length cumulative_length += input_length
cumulative_max_length += total_tokens cumulative_slot_tokens += slot_tokens
max_seqlen = max(max_seqlen, input_length) max_seqlen = max(max_seqlen, input_length)
max_blocks = max(max_blocks, len(request_blocks)) max_blocks = max(max_blocks, len(request_blocks))
max_length = max( max_length = max(
@ -395,12 +432,14 @@ class FlashCausalLMBatch(Batch):
) )
slots = torch.tensor(slots, dtype=torch.int64, device=device) slots = torch.tensor(slots, dtype=torch.int64, device=device)
block_tables_tensor = torch.zeros( block_tables_tensor = torch.zeros(
(len(block_tables), max_blocks), dtype=torch.int32, device="cpu" (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
) )
for i, request_blocks in enumerate(block_tables): for i, request_blocks in enumerate(block_tables):
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
block_tables_tensor = block_tables_tensor.to(device) block_tables_tensor = block_tables_tensor.to(device)
prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
return cls( return cls(
batch_id=pb.id, batch_id=pb.id,
@ -415,6 +454,8 @@ class FlashCausalLMBatch(Batch):
block_tables=block_tables, block_tables=block_tables,
block_tables_tensor=block_tables_tensor, block_tables_tensor=block_tables_tensor,
slots=slots, slots=slots,
prefix_lens=prefix_lens,
prefix_lens_tensor=prefix_lens_tensor,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
prefill_head_indices=prefill_head_indices, prefill_head_indices=prefill_head_indices,
prefill_next_token_indices=prefill_next_token_indices, prefill_next_token_indices=prefill_next_token_indices,
@ -425,6 +466,7 @@ class FlashCausalLMBatch(Batch):
read_offsets=read_offsets, read_offsets=read_offsets,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
prefix_ids=prefix_ids,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens, top_n_tokens=top_n_tokens,
@ -480,6 +522,7 @@ class FlashCausalLMBatch(Batch):
start_slots = [] start_slots = []
block_tables = [] block_tables = []
all_input_ids = [] all_input_ids = []
prefix_ids = []
input_lengths = [] input_lengths = []
prefix_offsets = [] prefix_offsets = []
@ -506,6 +549,7 @@ class FlashCausalLMBatch(Batch):
max_seqlen = max(max_seqlen, request_input_length) max_seqlen = max(max_seqlen, request_input_length)
all_input_ids.append(self.all_input_ids[idx]) all_input_ids.append(self.all_input_ids[idx])
prefix_ids.append(self.prefix_ids[idx])
input_lengths.append(request_input_length) input_lengths.append(request_input_length)
prefix_offsets.append(self.prefix_offsets[idx]) prefix_offsets.append(self.prefix_offsets[idx])
@ -591,6 +635,7 @@ class FlashCausalLMBatch(Batch):
read_offsets=read_offsets, read_offsets=read_offsets,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
prefix_ids=prefix_ids,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens, top_n_tokens=top_n_tokens,
@ -651,6 +696,7 @@ class FlashCausalLMBatch(Batch):
block_tables_tensor = batches[0].block_tables_tensor.new_zeros( block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
(total_batch_size, max_blocks) (total_batch_size, max_blocks)
) )
prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size)
all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
(total_batch_size, max_length) (total_batch_size, max_length)
) )
@ -668,7 +714,9 @@ class FlashCausalLMBatch(Batch):
start_slots = [] start_slots = []
block_tables = [] block_tables = []
prefix_lens = []
all_input_ids = [] all_input_ids = []
prefix_ids = []
input_lengths = [] input_lengths = []
prefix_offsets = [] prefix_offsets = []
@ -730,10 +778,14 @@ class FlashCausalLMBatch(Batch):
start_index:end_index, : batch.block_tables_tensor.shape[1] start_index:end_index, : batch.block_tables_tensor.shape[1]
] = batch.block_tables_tensor[:, :max_blocks] ] = batch.block_tables_tensor[:, :max_blocks]
prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor
start_slots.append(batch.start_slots + cumulative_slots) start_slots.append(batch.start_slots + cumulative_slots)
block_tables.extend(batch.block_tables) block_tables.extend(batch.block_tables)
prefix_lens.extend(batch.prefix_lens)
all_input_ids.extend(batch.all_input_ids) all_input_ids.extend(batch.all_input_ids)
prefix_ids.extend(batch.prefix_ids)
input_lengths.extend(batch.input_lengths) input_lengths.extend(batch.input_lengths)
prefix_offsets.extend(batch.prefix_offsets) prefix_offsets.extend(batch.prefix_offsets)
@ -779,6 +831,8 @@ class FlashCausalLMBatch(Batch):
slot_indices=slot_indices, slot_indices=slot_indices,
block_tables=block_tables, block_tables=block_tables,
block_tables_tensor=block_tables_tensor, block_tables_tensor=block_tables_tensor,
prefix_lens=prefix_lens,
prefix_lens_tensor=prefix_lens_tensor,
slots=slots, slots=slots,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
prefill_head_indices=None, prefill_head_indices=None,
@ -790,6 +844,7 @@ class FlashCausalLMBatch(Batch):
read_offsets=read_offsets, read_offsets=read_offsets,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
prefix_ids=prefix_ids,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens, top_n_tokens=top_n_tokens,
@ -944,6 +999,9 @@ class FlashCausalLM(Model):
) )
self.prefill_state = create_prefill_state(device=device) self.prefill_state = create_prefill_state(device=device)
self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
device=device
)
if not CUDA_GRAPHS: if not CUDA_GRAPHS:
self.decode_state = create_decode_state( self.decode_state = create_decode_state(
@ -1042,12 +1100,23 @@ class FlashCausalLM(Model):
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
slots = torch.arange(bs, dtype=torch.int64, device=self.device) slots = torch.arange(bs, dtype=torch.int64, device=self.device)
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s input_lengths = [max_s] * bs
block_tables = ( prefix_lengths = [0] * bs
torch.arange(max_bt, dtype=torch.int32, device=self.device) input_lengths_tensor = (
.repeat(bs) torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
.reshape((bs, max_bt))
) )
prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
block_tables = torch.arange(
max_bt, dtype=torch.int32, device=self.device
).repeat(bs)
block_tables = block_tables.reshape((bs, max_bt))
if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=input_lengths,
prefix_lens=prefix_lengths,
)
self.cuda_graphs[bs] = { self.cuda_graphs[bs] = {
"input_ids": input_ids, "input_ids": input_ids,
@ -1055,9 +1124,9 @@ class FlashCausalLM(Model):
"kv_cache": self.kv_cache, "kv_cache": self.kv_cache,
"block_tables": block_tables, "block_tables": block_tables,
"slots": slots, "slots": slots,
"input_lengths": input_lengths, "input_lengths": input_lengths_tensor,
} }
input_lengths_ = Seqlen(input_lengths=input_lengths) input_lengths_ = Seqlen(input_lengths=input_lengths_tensor)
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs]["graph"] = graph self.cuda_graphs[bs]["graph"] = graph
@ -1072,7 +1141,7 @@ class FlashCausalLM(Model):
last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device) last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
state = create_decode_state_cuda_graphs( state = create_decode_state_cuda_graphs(
device=input_ids.device, device=input_ids.device,
block_tables=block_tables.view(-1), block_tables=block_tables,
block_tables_ptr=block_tables_ptr, block_tables_ptr=block_tables_ptr,
last_page_len=last_page_len, last_page_len=last_page_len,
num_heads=self.num_heads, num_heads=self.num_heads,
@ -1088,7 +1157,10 @@ class FlashCausalLM(Model):
block_tables=block_tables, block_tables=block_tables,
cu_seqlen_prefill=None, cu_seqlen_prefill=None,
input_lengths=input_lengths, input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor,
state=state, state=state,
prefix_lens=prefix_lengths,
prefix_lens_tensor=prefix_lengths_tensor,
): ):
self.model.forward( self.model.forward(
input_ids=input_ids, input_ids=input_ids,
@ -1106,7 +1178,7 @@ class FlashCausalLM(Model):
torch.cuda.synchronize() torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL): with torch.cuda.graph(graph, pool=MEM_POOL):
input_lengths = Seqlen(input_lengths=input_lengths) input_lengths_tensor = Seqlen(input_lengths=input_lengths_tensor)
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
@ -1114,7 +1186,7 @@ class FlashCausalLM(Model):
kv_cache=self.kv_cache, kv_cache=self.kv_cache,
block_tables=block_tables, block_tables=block_tables,
slots=slots, slots=slots,
input_lengths=input_lengths, input_lengths=input_lengths_tensor,
max_s=max_s, max_s=max_s,
prefill_cache_indices=None, prefill_cache_indices=None,
lm_head_indices=None, lm_head_indices=None,
@ -1343,7 +1415,10 @@ class FlashCausalLM(Model):
with self._forward_context( with self._forward_context(
block_tables=block_tables, block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths=input_lengths, input_lengths=batch.input_lengths,
input_lengths_tensor=input_lengths,
prefix_lens=batch.prefix_lens,
prefix_lens_tensor=batch.prefix_lens_tensor,
): ):
input_lengths = Seqlen(input_lengths=input_lengths) input_lengths = Seqlen(input_lengths=input_lengths)
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
@ -1367,19 +1442,32 @@ class FlashCausalLM(Model):
# Static inputs are potentially padded # Static inputs are potentially padded
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
cuda_graph["block_tables"][ if ATTENTION == "flashinfer":
: block_tables.shape[0], : block_tables.shape[1] block_tables = block_tables_to_ragged(
] = block_tables block_tables=block_tables,
input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens,
)
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
else:
cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1]
] = block_tables
cuda_graph["slots"].fill_(-1) cuda_graph["slots"].fill_(-1)
cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_() cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
input_lengths + batch.prefix_lens_tensor
)
state = cuda_graph.get("state") state = cuda_graph.get("state")
with self._forward_context( with self._forward_context(
block_tables=block_tables, block_tables=block_tables,
cu_seqlen_prefill=None, cu_seqlen_prefill=None,
input_lengths=input_lengths, input_lengths=batch.input_lengths,
input_lengths_tensor=input_lengths,
prefix_lens=batch.prefix_lens,
prefix_lens_tensor=batch.prefix_lens_tensor,
state=state, state=state,
): ):
# Replay the graph # Replay the graph
@ -1578,6 +1666,7 @@ class FlashCausalLM(Model):
batch.read_offsets, batch.read_offsets,
batch.stopping_criterias, batch.stopping_criterias,
batch.all_input_ids, batch.all_input_ids,
batch.prefix_ids,
batch.next_token_chooser.do_sample, batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds, batch.next_token_chooser.seeds,
batch.top_n_tokens, batch.top_n_tokens,
@ -1595,6 +1684,7 @@ class FlashCausalLM(Model):
read_offset, read_offset,
stopping_criteria, stopping_criteria,
all_input_ids, all_input_ids,
prefix_ids,
do_sample, do_sample,
seed, seed,
top_n_tokens, top_n_tokens,
@ -1669,18 +1759,18 @@ class FlashCausalLM(Model):
out_end_index = batch.prefill_cu_outlens[i + 1] out_end_index = batch.prefill_cu_outlens[i + 1]
# Remove generated token to only have prefill and add nan for first prompt token # Remove generated token to only have prefill and add nan for first prompt token
request_prefill_logprobs = [float("nan")] + prefill_logprobs[ request_prefill_logprobs = (
out_start_index : out_end_index - 1 [float("nan")] * (len(prefix_ids) + 1)
] ) + prefill_logprobs[out_start_index : out_end_index - 1]
prefill_token_ids = all_input_ids[:-1] prefill_token_ids = all_input_ids[:-1]
prefill_texts = self.tokenizer.batch_decode( prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids, prefix_ids + prefill_token_ids,
clean_up_tokenization_spaces=False, clean_up_tokenization_spaces=False,
skip_special_tokens=False, skip_special_tokens=False,
) )
prefill_tokens = Tokens( prefill_tokens = Tokens(
prefill_token_ids, prefix_ids + prefill_token_ids,
request_prefill_logprobs, request_prefill_logprobs,
prefill_texts, prefill_texts,
is_special=[], is_special=[],
@ -1762,7 +1852,10 @@ class FlashCausalLM(Model):
*, *,
block_tables: torch.Tensor, block_tables: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
input_lengths: torch.Tensor, input_lengths: List[int],
input_lengths_tensor: torch.Tensor,
prefix_lens: List[int],
prefix_lens_tensor: torch.Tensor,
state: Optional[Any] = None, state: Optional[Any] = None,
) -> ContextManager: ) -> ContextManager:
if ATTENTION != "flashinfer": if ATTENTION != "flashinfer":
@ -1771,24 +1864,65 @@ class FlashCausalLM(Model):
from text_generation_server.layers.attention.flash_infer import ( from text_generation_server.layers.attention.flash_infer import (
use_decode_state, use_decode_state,
use_prefill_state, use_prefill_state,
use_prefill_with_paged_kv_state,
) )
# has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens)
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
return use_prefill_state( if True: # has_prefix_lens:
state=state if state is not None else self.prefill_state, return use_prefill_with_paged_kv_state(
cu_seqlens=cu_seqlen_prefill, state=(
num_heads=self.num_heads, state if state is not None else self.prefill_with_paged_kv_state
num_kv_heads=self.num_kv_heads, ),
head_size=self.head_size, block_tables=block_tables_to_ragged(
) block_tables=block_tables,
input_lengths=input_lengths,
prefix_lens=prefix_lens,
),
cu_seqlens=cu_seqlen_prefill,
input_lengths=input_lengths_tensor + prefix_lens_tensor,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
page_size=BLOCK_SIZE,
)
else:
return use_prefill_state(
state=state if state is not None else self.prefill_state,
cu_seqlens=cu_seqlen_prefill,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
)
else: else:
assert input_lengths is not None assert input_lengths_tensor is not None
return use_decode_state( return use_decode_state(
state=state if state is not None else self.decode_state, state=state if state is not None else self.decode_state,
input_lengths=input_lengths, input_lengths=input_lengths_tensor + prefix_lens_tensor,
block_tables=block_tables.view(-1), block_tables=block_tables,
num_heads=self.num_heads, num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
head_size=self.head_size, head_size=self.head_size,
page_size=BLOCK_SIZE, page_size=BLOCK_SIZE,
) )
def block_tables_to_ragged(
    *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int]
) -> torch.Tensor:
    """Flatten a padded block table into the ragged 1D layout used by FlashInfer.

    For every sequence ``i`` the first ``prefix_lens[i] + input_lengths[i]``
    entries of ``block_tables[i]`` are copied, back to back, into a single
    ``int32`` tensor allocated on the same device as ``block_tables``.

    Args:
        block_tables: 2D tensor with one row of block ids per sequence.
        input_lengths: per-sequence input lengths.
        prefix_lens: per-sequence prefix lengths; must have the same number
            of entries as ``input_lengths``.

    Returns:
        A 1D tensor of length ``sum(input_lengths) + sum(prefix_lens)``.

    Raises:
        AssertionError: if the two length lists differ in size.
    """
    assert len(input_lengths) == len(prefix_lens)

    # Number of block-table entries taken from each row.
    seq_lens = [prefix + new for new, prefix in zip(input_lengths, prefix_lens)]
    ragged = torch.empty(
        sum(seq_lens), dtype=torch.int32, device=block_tables.device
    )

    start = 0
    for row_idx, seq_len in enumerate(seq_lens):
        end = start + seq_len
        # Slice assignment requires the row to provide exactly `seq_len` entries.
        ragged[start:end] = block_tables[row_idx][:seq_len]
        start = end

    return ragged

View File

@ -5,17 +5,29 @@ from typing import Dict, Optional
from text_generation_server.utils.log import log_master from text_generation_server.utils.log import log_master
# Prefix caching is opt-in through the USE_PREFIX_CACHING environment variable.
# os.getenv returns a string whenever the variable is set, so explicit "off"
# spellings such as "0" or "false" have to be parsed; relying on truthiness
# would treat them as enabled. Any other non-empty value still enables it.
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "").strip().lower() not in {
    "",
    "0",
    "false",
    "no",
    "off",
}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")

# flashinfer becomes the default backend when prefix caching is enabled,
# since it is the only backend accepted together with it (checked below).
ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged")
_expected = {"paged", "flashdecoding", "flashinfer"}
assert (
    ATTENTION in _expected
), f"Attention is not valid {ATTENTION}, expected {_expected}"
log_master(logger.info, f"Using Attention = {ATTENTION}")

if PREFIX_CACHING and ATTENTION != "flashinfer":
    raise RuntimeError("Prefix caching is only supported with flashinfer")

MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None

# This is overridden by the cli
BLOCK_SIZE: int
if ATTENTION == "flashdecoding":
    BLOCK_SIZE = 256
elif ATTENTION == "flashinfer":
    BLOCK_SIZE = 1
else:
    BLOCK_SIZE = 16
cuda_graphs = os.getenv("CUDA_GRAPHS") cuda_graphs = os.getenv("CUDA_GRAPHS")
if cuda_graphs is not None: if cuda_graphs is not None: