Prefix caching WIP
This commit is contained in:
parent
7a48a84784
commit
7735b385dc
|
@ -4045,6 +4045,7 @@ dependencies = [
|
|||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"slotmap",
|
||||
"text-generation-router",
|
||||
"thiserror",
|
||||
"tokenizers",
|
||||
|
|
|
@ -156,6 +156,7 @@ impl Client {
|
|||
// Blocks and slots will be set on the server side if we use paged attention
|
||||
blocks: vec![],
|
||||
slots: vec![],
|
||||
prefix_len: 0,
|
||||
// Set sampling parameters to also take these ops into account in the max memory
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 0.9,
|
||||
|
|
|
@ -244,6 +244,7 @@ impl Health for ShardedClient {
|
|||
// Block 0 is reserved for health checks
|
||||
blocks: vec![0],
|
||||
slots: (0..16).collect(),
|
||||
prefix_len: 0,
|
||||
adapter_id: None,
|
||||
};
|
||||
let batch = Batch {
|
||||
|
|
|
@ -33,6 +33,7 @@ rand = "0.8.5"
|
|||
reqwest = { version = "0.11.20", features = [] }
|
||||
serde = "1.0.188"
|
||||
serde_json = "1.0.107"
|
||||
slotmap = "1.0.7"
|
||||
thiserror = "1.0.48"
|
||||
tokenizers = { workspace = true}
|
||||
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||
|
|
|
@ -35,15 +35,24 @@ impl BackendV3 {
|
|||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
) -> Self {
|
||||
let prefix_caching = if let Ok(prefix_caching) = std::env::var("USE_PREFIX_CACHING") {
|
||||
matches!(prefix_caching.as_str(), "true" | "1")
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let attention = if let Ok(attention) = std::env::var("ATTENTION") {
|
||||
attention
|
||||
.parse()
|
||||
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
|
||||
} else if prefix_caching {
|
||||
Attention::FlashInfer
|
||||
} else {
|
||||
Attention::Paged
|
||||
};
|
||||
let block_size = if attention == Attention::FlashDecoding {
|
||||
256
|
||||
} else if attention == Attention::FlashInfer {
|
||||
1
|
||||
} else {
|
||||
16
|
||||
};
|
||||
|
@ -51,6 +60,7 @@ impl BackendV3 {
|
|||
let queue = Queue::new(
|
||||
requires_padding,
|
||||
block_size,
|
||||
prefix_caching,
|
||||
window_size,
|
||||
speculate,
|
||||
max_batch_total_tokens,
|
||||
|
|
|
@ -1,16 +1,26 @@
|
|||
use std::cmp::min;
|
||||
use std::{cmp::min, sync::Arc};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
use crate::radix::RadixAllocator;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct BlockAllocation {
|
||||
pub allocation_id: u64,
|
||||
pub blocks: Vec<u32>,
|
||||
pub slots: Vec<u32>,
|
||||
block_allocator: BlockAllocator,
|
||||
|
||||
/// Prefix that was cached and for which the KV does not have to
|
||||
/// be recomputed.
|
||||
pub prefix_len: u32,
|
||||
|
||||
pub(crate) block_allocator: Option<BlockAllocator>,
|
||||
}
|
||||
|
||||
impl Drop for BlockAllocation {
|
||||
fn drop(&mut self) {
|
||||
self.block_allocator.free(self.blocks.clone())
|
||||
if let Some(block_allocator) = self.block_allocator.as_mut() {
|
||||
block_allocator.free(self.blocks.clone(), self.allocation_id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -24,6 +34,7 @@ impl BlockAllocator {
|
|||
pub(crate) fn new(
|
||||
max_batch_total_tokens: u32,
|
||||
block_size: u32,
|
||||
prefix_caching: bool,
|
||||
window_size: Option<u32>,
|
||||
) -> Self {
|
||||
// Create channel
|
||||
|
@ -33,6 +44,7 @@ impl BlockAllocator {
|
|||
tokio::spawn(block_allocator_task(
|
||||
max_batch_total_tokens / block_size,
|
||||
block_size,
|
||||
prefix_caching,
|
||||
window_size,
|
||||
receiver,
|
||||
));
|
||||
|
@ -42,28 +54,32 @@ impl BlockAllocator {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
|
||||
pub(crate) async fn allocate(
|
||||
&self,
|
||||
tokens: u32,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<BlockAllocation> {
|
||||
let (response_sender, response_receiver) = oneshot::channel();
|
||||
self.block_allocator
|
||||
.send(BlockAllocatorCommand::Allocate {
|
||||
tokens,
|
||||
prefill_tokens,
|
||||
response_sender,
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
response_receiver
|
||||
.await
|
||||
.unwrap()
|
||||
.map(|(blocks, slots)| BlockAllocation {
|
||||
blocks,
|
||||
slots,
|
||||
block_allocator: self.clone(),
|
||||
response_receiver.await.unwrap().map(|mut allocation| {
|
||||
allocation.block_allocator = Some(self.clone());
|
||||
allocation
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn free(&self, blocks: Vec<u32>) {
|
||||
pub(crate) fn free(&self, blocks: Vec<u32>, allocation_id: u64) {
|
||||
self.block_allocator
|
||||
.send(BlockAllocatorCommand::Free { blocks })
|
||||
.send(BlockAllocatorCommand::Free {
|
||||
allocation_id,
|
||||
blocks,
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
@ -71,54 +87,29 @@ impl BlockAllocator {
|
|||
async fn block_allocator_task(
|
||||
blocks: u32,
|
||||
block_size: u32,
|
||||
prefix_caching: bool,
|
||||
window_size: Option<u32>,
|
||||
mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
|
||||
) {
|
||||
// Block 0 is reserved for health checks
|
||||
let mut free_blocks: Vec<u32> = (1..blocks).collect();
|
||||
let mut allocator: Box<dyn Allocator + Send> = if prefix_caching {
|
||||
Box::new(RadixAllocator::new(block_size, blocks, window_size))
|
||||
} else {
|
||||
Box::new(SimpleAllocator::new(blocks, block_size, window_size))
|
||||
};
|
||||
while let Some(cmd) = receiver.recv().await {
|
||||
match cmd {
|
||||
BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
|
||||
BlockAllocatorCommand::Free {
|
||||
blocks,
|
||||
allocation_id,
|
||||
} => allocator.free(blocks, allocation_id),
|
||||
BlockAllocatorCommand::Allocate {
|
||||
tokens,
|
||||
prefill_tokens,
|
||||
response_sender,
|
||||
} => {
|
||||
// Apply window size
|
||||
let (required_blocks, repeats) = {
|
||||
let (tokens, repeats) = match window_size {
|
||||
None => (tokens, 1),
|
||||
Some(window_size) => {
|
||||
let repeats = (tokens + window_size - 1) / window_size;
|
||||
let tokens = min(tokens, window_size);
|
||||
(tokens, repeats as usize)
|
||||
}
|
||||
};
|
||||
// Pad to a multiple of block size
|
||||
let required_blocks = (tokens + block_size - 1) / block_size;
|
||||
(required_blocks, repeats)
|
||||
};
|
||||
|
||||
let tokens = tokens as usize;
|
||||
let allocation = if required_blocks > free_blocks.len() as u32 {
|
||||
None
|
||||
} else {
|
||||
let blocks =
|
||||
free_blocks.split_off(free_blocks.len() - required_blocks as usize);
|
||||
let mut slots = Vec::with_capacity(
|
||||
(required_blocks * block_size * repeats as u32) as usize,
|
||||
);
|
||||
|
||||
'slots: for block_id in blocks.repeat(repeats).iter() {
|
||||
for s in (block_id * block_size)..((block_id + 1) * block_size) {
|
||||
slots.push(s);
|
||||
if slots.len() == tokens {
|
||||
break 'slots;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some((blocks, slots))
|
||||
};
|
||||
response_sender.send(allocation).unwrap();
|
||||
response_sender
|
||||
.send(allocator.allocate(tokens, prefill_tokens))
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -128,9 +119,92 @@ async fn block_allocator_task(
|
|||
enum BlockAllocatorCommand {
|
||||
Free {
|
||||
blocks: Vec<u32>,
|
||||
allocation_id: u64,
|
||||
},
|
||||
Allocate {
|
||||
tokens: u32,
|
||||
response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
response_sender: oneshot::Sender<Option<BlockAllocation>>,
|
||||
},
|
||||
}
|
||||
|
||||
pub(crate) trait Allocator {
|
||||
fn allocate(
|
||||
&mut self,
|
||||
tokens: u32,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<BlockAllocation>;
|
||||
|
||||
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
|
||||
}
|
||||
|
||||
pub struct SimpleAllocator {
|
||||
free_blocks: Vec<u32>,
|
||||
block_size: u32,
|
||||
window_size: Option<u32>,
|
||||
}
|
||||
|
||||
impl SimpleAllocator {
|
||||
fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
|
||||
SimpleAllocator {
|
||||
block_size,
|
||||
// Block 0 is reserved for health checks
|
||||
free_blocks: (1..blocks).collect(),
|
||||
window_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Allocator for SimpleAllocator {
|
||||
fn allocate(
|
||||
&mut self,
|
||||
tokens: u32,
|
||||
_prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<BlockAllocation> {
|
||||
// Apply window size
|
||||
let (required_blocks, repeats) = {
|
||||
let (tokens, repeats) = match self.window_size {
|
||||
None => (tokens, 1),
|
||||
Some(window_size) => {
|
||||
let repeats = (tokens + window_size - 1) / window_size;
|
||||
let tokens = min(tokens, window_size);
|
||||
(tokens, repeats as usize)
|
||||
}
|
||||
};
|
||||
// Pad to a multiple of block size
|
||||
let required_blocks = (tokens + self.block_size - 1) / self.block_size;
|
||||
(required_blocks, repeats)
|
||||
};
|
||||
|
||||
let tokens = tokens as usize;
|
||||
if required_blocks > self.free_blocks.len() as u32 {
|
||||
None
|
||||
} else {
|
||||
let blocks = self
|
||||
.free_blocks
|
||||
.split_off(self.free_blocks.len() - required_blocks as usize);
|
||||
let mut slots =
|
||||
Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
|
||||
|
||||
'slots: for block_id in blocks.repeat(repeats).iter() {
|
||||
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
|
||||
slots.push(s);
|
||||
if slots.len() == tokens {
|
||||
break 'slots;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(BlockAllocation {
|
||||
allocation_id: 0,
|
||||
blocks,
|
||||
slots,
|
||||
prefix_len: 0,
|
||||
block_allocator: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {
|
||||
self.free_blocks.extend(blocks)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -157,6 +157,7 @@ impl Client {
|
|||
// Blocks and slots will be set on the server side if we use paged attention
|
||||
blocks: vec![],
|
||||
slots: vec![],
|
||||
prefix_len: 0,
|
||||
// Set sampling parameters to also take these ops into account in the max memory
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 0.9,
|
||||
|
|
|
@ -245,6 +245,7 @@ impl Health for ShardedClient {
|
|||
// Block 0 is reserved for health checks
|
||||
blocks: vec![0],
|
||||
slots: (0..16).collect(),
|
||||
prefix_len: 0,
|
||||
adapter_id: None,
|
||||
};
|
||||
let batch = Batch {
|
||||
|
|
|
@ -2,6 +2,7 @@ mod backend;
|
|||
mod block_allocator;
|
||||
mod client;
|
||||
mod queue;
|
||||
mod radix;
|
||||
|
||||
use crate::client::{ClientError, ShardedClient};
|
||||
pub(crate) use backend::BackendV3;
|
||||
|
|
|
@ -46,6 +46,7 @@ impl Queue {
|
|||
pub(crate) fn new(
|
||||
requires_padding: bool,
|
||||
block_size: u32,
|
||||
prefix_caching: bool,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
|
@ -57,6 +58,7 @@ impl Queue {
|
|||
tokio::spawn(queue_task(
|
||||
requires_padding,
|
||||
block_size,
|
||||
prefix_caching,
|
||||
window_size,
|
||||
speculate,
|
||||
max_batch_total_tokens,
|
||||
|
@ -109,6 +111,7 @@ impl Queue {
|
|||
async fn queue_task(
|
||||
requires_padding: bool,
|
||||
block_size: u32,
|
||||
prefix_caching: bool,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
|
@ -117,6 +120,7 @@ async fn queue_task(
|
|||
let mut state = State::new(
|
||||
requires_padding,
|
||||
block_size,
|
||||
prefix_caching,
|
||||
window_size,
|
||||
speculate,
|
||||
max_batch_total_tokens,
|
||||
|
@ -176,12 +180,19 @@ impl State {
|
|||
fn new(
|
||||
requires_padding: bool,
|
||||
block_size: u32,
|
||||
prefix_caching: bool,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
) -> Self {
|
||||
let block_allocator = (!requires_padding)
|
||||
.then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size));
|
||||
let block_allocator = (!requires_padding).then(|| {
|
||||
BlockAllocator::new(
|
||||
max_batch_total_tokens,
|
||||
block_size,
|
||||
prefix_caching,
|
||||
window_size,
|
||||
)
|
||||
});
|
||||
|
||||
Self {
|
||||
entries: VecDeque::with_capacity(128),
|
||||
|
@ -305,7 +316,10 @@ impl State {
|
|||
+ self.speculate
|
||||
- 1;
|
||||
|
||||
match block_allocator.allocate(tokens).await {
|
||||
match block_allocator
|
||||
.allocate(tokens, entry.request.input_ids.clone())
|
||||
.await
|
||||
{
|
||||
None => {
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
|
@ -331,11 +345,12 @@ impl State {
|
|||
// Update entry
|
||||
entry.temp_span = Some(entry_batch_span);
|
||||
|
||||
let (blocks, slots) = match &block_allocation {
|
||||
None => (Vec::new(), Vec::new()),
|
||||
let (blocks, slots, prefix_len) = match &block_allocation {
|
||||
None => (Vec::new(), Vec::new(), 0),
|
||||
Some(block_allocation) => (
|
||||
block_allocation.blocks.clone(),
|
||||
block_allocation.slots.clone(),
|
||||
block_allocation.prefix_len,
|
||||
),
|
||||
};
|
||||
|
||||
|
@ -372,6 +387,7 @@ impl State {
|
|||
top_n_tokens: entry.request.top_n_tokens,
|
||||
blocks,
|
||||
slots,
|
||||
prefix_len,
|
||||
adapter_id: entry.request.adapter_id.clone(),
|
||||
});
|
||||
// Set batch_time
|
||||
|
@ -480,6 +496,8 @@ impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
use tracing::info_span;
|
||||
|
||||
|
@ -492,6 +510,7 @@ mod tests {
|
|||
let entry = Entry {
|
||||
request: ValidGenerateRequest {
|
||||
inputs: vec![],
|
||||
input_ids: Some(Arc::new(vec![])),
|
||||
input_length: 0,
|
||||
truncate: 0,
|
||||
decoder_input_details: false,
|
||||
|
@ -527,7 +546,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_append() {
|
||||
let mut state = State::new(false, 1, None, 0, 16);
|
||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
||||
let (entry, _guard) = default_entry();
|
||||
|
||||
assert_eq!(state.next_id, 0);
|
||||
|
@ -543,7 +562,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_next_batch_empty() {
|
||||
let mut state = State::new(false, 1, None, 0, 16);
|
||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
||||
|
||||
assert!(state.next_batch(None, None, 1, 1).await.is_none());
|
||||
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
|
||||
|
@ -551,7 +570,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_next_batch_min_size() {
|
||||
let mut state = State::new(false, 1, None, 0, 16);
|
||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
|
@ -583,7 +602,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_next_batch_max_size() {
|
||||
let mut state = State::new(false, 1, None, 0, 16);
|
||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
|
@ -603,7 +622,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_next_batch_token_budget() {
|
||||
let mut state = State::new(false, 1, None, 0, 2);
|
||||
let mut state = State::new(false, 1, false, None, 0, 2);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
|
@ -636,14 +655,14 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_append() {
|
||||
let queue = Queue::new(false, 1, None, 0, 16);
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||
let (entry, _guard) = default_entry();
|
||||
queue.append(entry);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_empty() {
|
||||
let queue = Queue::new(false, 1, None, 0, 16);
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||
|
||||
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
||||
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
|
||||
|
@ -651,7 +670,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_min_size() {
|
||||
let queue = Queue::new(false, 1, None, 0, 16);
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
|
@ -684,7 +703,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_max_size() {
|
||||
let queue = Queue::new(false, 1, None, 0, 16);
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
|
@ -700,7 +719,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_token_budget() {
|
||||
let queue = Queue::new(false, 1, None, 0, 16);
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
|
@ -725,7 +744,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_token_speculate() {
|
||||
let queue = Queue::new(false, 1, None, 2, 16);
|
||||
let queue = Queue::new(false, 1, false, None, 2, 16);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
|
@ -744,7 +763,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_dropped_receiver() {
|
||||
let queue = Queue::new(false, 1, None, 0, 16);
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||
let (entry, _) = default_entry();
|
||||
queue.append(entry);
|
||||
|
||||
|
|
|
@ -0,0 +1,755 @@
|
|||
use std::{
|
||||
collections::{BTreeSet, HashMap},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use slotmap::{DefaultKey, SlotMap};
|
||||
|
||||
use crate::block_allocator::{Allocator, BlockAllocation};
|
||||
|
||||
pub struct RadixAllocator {
|
||||
allocation_id: u64,
|
||||
|
||||
allocations: HashMap<u64, RadixAllocation>,
|
||||
|
||||
cache_blocks: RadixTrie,
|
||||
|
||||
/// Blocks that are immediately available for allocation.
|
||||
free_blocks: Vec<u32>,
|
||||
}
|
||||
|
||||
impl RadixAllocator {
|
||||
pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self {
|
||||
assert_eq!(
|
||||
block_size, 1,
|
||||
"Radix tree allocator only works with block_size=1, was: {}",
|
||||
block_size
|
||||
);
|
||||
if window_size.is_some() {
|
||||
unimplemented!("Window size not supported in the prefix-caching block allocator yet");
|
||||
}
|
||||
|
||||
RadixAllocator {
|
||||
allocation_id: 0,
|
||||
allocations: HashMap::new(),
|
||||
cache_blocks: RadixTrie::new(),
|
||||
|
||||
// Block 0 is reserved for health checks.
|
||||
free_blocks: (1..n_blocks).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
fn alloc_or_reclaim(&mut self, n_blocks_needed: usize) -> Option<Vec<u32>> {
|
||||
if self.free_blocks.len() < n_blocks_needed {
|
||||
// This is a bit annoying, we first extend the free list and then
|
||||
// split it off again below. This is because we need to put it on
|
||||
// the free list if we cannot allocate enough blocks. This is only
|
||||
// temporary, the trie needs to be able to report whether it can
|
||||
// allocate the requested amount. Just not implemented yet.
|
||||
self.free_blocks.extend(
|
||||
self.cache_blocks
|
||||
.evict(n_blocks_needed - self.free_blocks.len()),
|
||||
);
|
||||
}
|
||||
|
||||
if self.free_blocks.len() >= n_blocks_needed {
|
||||
Some(
|
||||
self.free_blocks
|
||||
.split_off(self.free_blocks.len() - n_blocks_needed),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Allocator for RadixAllocator {
|
||||
fn allocate(
|
||||
&mut self,
|
||||
tokens: u32,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<BlockAllocation> {
|
||||
let mut blocks = vec![];
|
||||
let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {
|
||||
let node_id = self
|
||||
.cache_blocks
|
||||
.find(prefill_tokens.as_slice(), &mut blocks);
|
||||
// Even if this allocation fails below, we need to increase he
|
||||
// refcount to ensure that the prefix that was found is not evicted.
|
||||
|
||||
node_id
|
||||
} else {
|
||||
self.cache_blocks.root_id()
|
||||
};
|
||||
|
||||
self.cache_blocks
|
||||
.incref(prefix_node)
|
||||
.expect("Failed to increment refcount");
|
||||
|
||||
let prefix_len = blocks.len();
|
||||
let suffix_len = tokens - prefix_len as u32;
|
||||
|
||||
match self.alloc_or_reclaim(suffix_len as usize) {
|
||||
Some(suffix_blocks) => blocks.extend(suffix_blocks),
|
||||
None => {
|
||||
self.cache_blocks
|
||||
.decref(prefix_node)
|
||||
.expect("Failed to decrement refcount");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
// 1:1 mapping of blocks and slots.
|
||||
let slots = blocks.clone();
|
||||
|
||||
let allocation = RadixAllocation {
|
||||
prefix_node,
|
||||
cached_prefix_len: prefix_len,
|
||||
prefill_tokens: prefill_tokens.clone(),
|
||||
};
|
||||
|
||||
self.allocation_id += 1;
|
||||
self.allocations.insert(self.allocation_id, allocation);
|
||||
|
||||
Some(BlockAllocation {
|
||||
allocation_id: self.allocation_id,
|
||||
block_allocator: None,
|
||||
blocks,
|
||||
slots,
|
||||
prefix_len: prefix_len as u32,
|
||||
})
|
||||
}
|
||||
|
||||
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
|
||||
let allocation = match self.allocations.remove(&allocation_id) {
|
||||
Some(allocation) => allocation,
|
||||
None => unreachable!("Tried to free an unknown allocation."),
|
||||
};
|
||||
|
||||
self.cache_blocks
|
||||
.decref(allocation.prefix_node)
|
||||
.expect("Failed to decrement refcount");
|
||||
|
||||
if let Some(prefill_tokens) = allocation.prefill_tokens {
|
||||
let prefill_tokens = prefill_tokens.as_slice();
|
||||
|
||||
// If there are prefill tokens that did not come from the cache,
|
||||
// add them to the cache.
|
||||
if prefill_tokens.len() > allocation.cached_prefix_len {
|
||||
let prefix_len = self
|
||||
.cache_blocks
|
||||
.insert(prefill_tokens, &blocks[..prefill_tokens.len()])
|
||||
// Unwrap, failing is a programming error.
|
||||
.expect("Failed to store prefill tokens");
|
||||
|
||||
// We can have a prefill with the following structure:
|
||||
//
|
||||
// |---| From the prefix cache.
|
||||
// A B C D E F G
|
||||
//|--------| Found in the trie during insertion.
|
||||
//
|
||||
// This means that while processing this request there was a
|
||||
// partially overlapping request that had A..=E in its
|
||||
// prefill. In this case we need to free the blocks D E.
|
||||
self.free_blocks
|
||||
.extend(&blocks[allocation.cached_prefix_len..prefix_len]);
|
||||
}
|
||||
|
||||
// Free non-prefill blocks.
|
||||
self.free_blocks.extend(&blocks[prefill_tokens.len()..]);
|
||||
} else {
|
||||
self.free_blocks.extend(blocks);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct RadixAllocation {
|
||||
prefix_node: NodeId,
|
||||
cached_prefix_len: usize,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
}
|
||||
|
||||
// Radix trie that is heavily inspired by radix attention from sglang.
|
||||
//
|
||||
// The trie is optimized for prefix caching:
|
||||
//
|
||||
// - A normal radix trie stores discrete values. In this radix trie,
|
||||
// inserting *abc* with value *xyz* will also enable lookup for
|
||||
// *a* (*x*) and *ab* (*xy*).
|
||||
// - As a result, every value is required to have the same length as
|
||||
// the key.
|
||||
// - We store additional information in each node, such as last access
|
||||
// time and a reference count.
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum TrieError {
|
||||
InvalidNodeId,
|
||||
RefCountUnderflow,
|
||||
BlockTokenCountMismatch,
|
||||
}
|
||||
|
||||
pub type NodeId = DefaultKey;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RadixTrie {
|
||||
/// Identifier of the root nod.
|
||||
root: DefaultKey,
|
||||
|
||||
/// Leave node identifiers ordered by increasing recency.
|
||||
leaves: BTreeSet<(u64, NodeId)>,
|
||||
|
||||
/// All trie nodes.
|
||||
nodes: SlotMap<NodeId, TrieNode>,
|
||||
|
||||
/// Time as a monotonically increating counter to avoid the system
|
||||
/// call that a real time lookup would require.
|
||||
time: u64,
|
||||
}
|
||||
|
||||
impl RadixTrie {
|
||||
/// Construct a new radix trie.
|
||||
pub fn new() -> Self {
|
||||
let root = TrieNode::new(vec![], vec![], 0, None);
|
||||
let mut nodes = SlotMap::new();
|
||||
let root = nodes.insert(root);
|
||||
RadixTrie {
|
||||
leaves: BTreeSet::new(),
|
||||
nodes,
|
||||
root,
|
||||
time: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Find the prefix of the given tokens.
|
||||
///
|
||||
/// The blocks corresponding to the part of the prefix that could be found
|
||||
/// are writteng to `blocks`. The number of blocks is in `0..=tokens.len()`.
|
||||
/// Returns the identifier of the trie node that contains the longest
|
||||
/// prefix. The node identifier can be used by callers to e.g. increase its
|
||||
/// reference count.
|
||||
///
|
||||
/// Using this method will update the access time of the traversed nodes.
|
||||
pub fn find(&mut self, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
|
||||
self.time += 1;
|
||||
self.find_(self.root, key, blocks)
|
||||
}
|
||||
|
||||
/// Find worker.
|
||||
fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
|
||||
let node = &self.nodes[node_id];
|
||||
|
||||
if let Some(&child_id) = node.children.get(&key[0]) {
|
||||
self.update_access_time(child_id);
|
||||
let child = self.nodes.get(child_id).expect("Invalid child identifier");
|
||||
let shared_prefix_len = child.key.shared_prefix_len(key);
|
||||
blocks.extend(&child.blocks[..shared_prefix_len]);
|
||||
|
||||
let key = &key[shared_prefix_len..];
|
||||
if !key.is_empty() {
|
||||
node_id = self.find_(child_id, key, blocks);
|
||||
}
|
||||
}
|
||||
|
||||
node_id
|
||||
}
|
||||
|
||||
/// Decrease the reference count of a node.
|
||||
pub fn decref(&mut self, node_id: NodeId) -> Result<(), TrieError> {
|
||||
// We don't care about refcounting for root, since it will never
|
||||
// be evicted.
|
||||
if node_id == self.root {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let node = self
|
||||
.nodes
|
||||
.get_mut(node_id)
|
||||
.ok_or(TrieError::InvalidNodeId)?;
|
||||
if node.ref_count == 0 {
|
||||
return Err(TrieError::RefCountUnderflow);
|
||||
}
|
||||
|
||||
node.ref_count -= 1;
|
||||
if node.ref_count == 0 {
|
||||
self.leaves.insert((node.last_accessed, node_id));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Increase the reference count of a node.
|
||||
pub fn incref(&mut self, node_id: NodeId) -> Result<(), TrieError> {
|
||||
if node_id == self.root {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let node = self
|
||||
.nodes
|
||||
.get_mut(node_id)
|
||||
.ok_or(TrieError::InvalidNodeId)?;
|
||||
if node.ref_count == 0 {
|
||||
self.leaves.remove(&(node.last_accessed, node_id));
|
||||
}
|
||||
node.ref_count += 1;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Evict `n_blocks` from the trie.
|
||||
///
|
||||
/// Returns the evicted blocks. When the length is less than `n_blocks`,
|
||||
/// not enough blocks could beevicted.
|
||||
pub fn evict(&mut self, n_blocks: usize) -> Vec<u32> {
|
||||
// NOTE: we don't return Result here. If any of the unwrapping fails,
|
||||
// it's a programming error in the trie implementation, not a user
|
||||
// error caused by e.g. an invalid argument.
|
||||
|
||||
// TODO: add some bookkeeping in the future to check whether we can
|
||||
// evict n_blocks and return `None` if we can't. We are now needlessly
|
||||
// evicting prefixes from the cache in such a case.
|
||||
let mut evicted = Vec::new();
|
||||
|
||||
while let Some((last_access, node_id)) = self.leaves.pop_first() {
|
||||
let blocks_needed = n_blocks - evicted.len();
|
||||
|
||||
let node = self.nodes.get(node_id).expect("Leave does not exist");
|
||||
if blocks_needed >= node.blocks.len() {
|
||||
// We need to evict the whole node if we need more blocks than it has.
|
||||
let node = self.remove_node(node_id);
|
||||
evicted.extend(node.blocks);
|
||||
|
||||
if evicted.len() >= n_blocks {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
// The node has more blocks than needed, so we'll just remove
|
||||
// the required number of blocks and leave the remaining blocks
|
||||
// untouched.
|
||||
let node = self.nodes.get_mut(node_id).expect("Leave does not exist");
|
||||
node.key.truncate(node.blocks.len() - blocks_needed);
|
||||
evicted.extend(node.blocks.split_off(node.blocks.len() - blocks_needed));
|
||||
self.leaves.insert((last_access, node_id));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
evicted
|
||||
}
|
||||
|
||||
/// Insert a prefill along with its blocks.
|
||||
///
|
||||
/// This method returns the length of the prefix that was already
|
||||
/// in the trie. E.g. if the length is 10, this means that for
|
||||
/// the first 10 elements of the tree **the blocks are not updated**.
|
||||
pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result<usize, TrieError> {
|
||||
self.time += 1;
|
||||
self.insert_(self.root, tokens, blocks)
|
||||
}
|
||||
|
||||
/// Insertion worker.
|
||||
fn insert_(
|
||||
&mut self,
|
||||
node_id: NodeId,
|
||||
tokens: &[u32],
|
||||
blocks: &[u32],
|
||||
) -> Result<usize, TrieError> {
|
||||
// TODO: in the future we may want to check that the blocks match for
|
||||
// the part of the prefix that is already in the trie to detect
|
||||
// mismatches.
|
||||
|
||||
if tokens.len() != blocks.len() {
|
||||
return Err(TrieError::BlockTokenCountMismatch);
|
||||
}
|
||||
|
||||
if let Some(&child_id) = self.nodes[node_id].children.get(&tokens[0]) {
|
||||
self.update_access_time(child_id);
|
||||
let child = self
|
||||
.nodes
|
||||
.get_mut(child_id)
|
||||
// Unwrap here, since failure is a bug.
|
||||
.expect("Child node does not exist");
|
||||
let shared_prefix_len = child.key.shared_prefix_len(tokens);
|
||||
|
||||
// We are done, the prefix is already in the trie.
|
||||
if shared_prefix_len == tokens.len() {
|
||||
return Ok(shared_prefix_len);
|
||||
}
|
||||
|
||||
// The node's prefix is a prefix of the insertion prefix.
|
||||
if shared_prefix_len == child.key.len() {
|
||||
return Ok(shared_prefix_len
|
||||
+ self.insert_(
|
||||
child_id,
|
||||
&tokens[shared_prefix_len..],
|
||||
&blocks[shared_prefix_len..],
|
||||
)?);
|
||||
}
|
||||
|
||||
// The node's prefix and the insertion prefix only match partially,
|
||||
// split the node to just contain the matching part. Then insert the
|
||||
// remainder of the prefix into the node again
|
||||
let child_id = self.split_node(child_id, shared_prefix_len);
|
||||
let key = &tokens[shared_prefix_len..];
|
||||
let blocks = &blocks[shared_prefix_len..];
|
||||
Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?)
|
||||
} else {
|
||||
self.add_node(node_id, tokens, blocks);
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
fn split_node(&mut self, node_id: NodeId, prefix_len: usize) -> NodeId {
|
||||
// We have to make the current node a child to ensure that its
|
||||
// properties and node id stay the same.
|
||||
|
||||
// This funcion unwraps, an invalid node_id is a programming error.
|
||||
|
||||
let node = self
|
||||
.nodes
|
||||
.get_mut(node_id)
|
||||
.expect("Node to-be split does not exist");
|
||||
let mut parent_key = node.key.split_off(prefix_len);
|
||||
let mut parent_blocks = node.blocks.split_off(prefix_len);
|
||||
|
||||
// Move first part of the prefix to the parent. We swap to avoid
|
||||
// an allocation + copy for both splits of the key/blocks.
|
||||
std::mem::swap(&mut node.key, &mut parent_key);
|
||||
std::mem::swap(&mut node.blocks, &mut parent_blocks);
|
||||
|
||||
let node_key = node.key[0];
|
||||
|
||||
let grandparent_id = node.parent.expect("Node does not have a parent");
|
||||
let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks);
|
||||
self.add_node_to_parent(parent_id, node_key, node_id);
|
||||
|
||||
// Reborrow to make the borrow checker happy.
|
||||
let node = self
|
||||
.nodes
|
||||
.get_mut(node_id)
|
||||
.expect("Node to-be split does not exist");
|
||||
node.parent = Some(parent_id);
|
||||
|
||||
parent_id
|
||||
}
|
||||
|
||||
/// Create a node and add it to the parent.
|
||||
fn add_node(
|
||||
&mut self,
|
||||
parent_id: NodeId,
|
||||
key: impl Into<Vec<u32>>,
|
||||
blocks: impl Into<Vec<u32>>,
|
||||
) -> NodeId {
|
||||
let key = key.into();
|
||||
let blocks = blocks.into();
|
||||
let first = key[0];
|
||||
|
||||
let child = TrieNode::new(key, blocks, self.time, Some(parent_id));
|
||||
let child_id = self.nodes.insert(child);
|
||||
|
||||
self.add_node_to_parent(parent_id, first, child_id);
|
||||
self.leaves.insert((self.time, child_id));
|
||||
|
||||
child_id
|
||||
}
|
||||
|
||||
/// Add a node to the parent.
|
||||
fn add_node_to_parent(&mut self, parent_id: NodeId, first: u32, child_id: NodeId) {
|
||||
// Unwrap here, passing in an unknown id is a programming error.
|
||||
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
|
||||
if parent.children.insert(first, child_id).is_none() {
|
||||
// Only increase reference count if child does not replace another child.
|
||||
self.incref(parent_id)
|
||||
.expect("Failed to increase parent refcount");
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a node from the trie.
|
||||
fn remove_node(&mut self, node_id: NodeId) -> TrieNode {
|
||||
// Unwrap here, passing in an unknown id is a programming error.
|
||||
let node = self.nodes.remove(node_id).expect("Unknown node");
|
||||
let parent_id = node.parent.expect("Attempted to remove root node");
|
||||
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
|
||||
parent.children.remove(&node.key[0]);
|
||||
self.decref(parent_id)
|
||||
.expect("Failed to decrease parent refcount");
|
||||
self.nodes.remove(node_id);
|
||||
node
|
||||
}
|
||||
|
||||
fn update_access_time(&mut self, node_id: NodeId) {
|
||||
// Unwrap here, passing in an unknown id is a programming error.
|
||||
let node = self.nodes.get_mut(node_id).expect("Unknown node");
|
||||
|
||||
// Update the ordered leaves set if the node is a leave.
|
||||
if self.leaves.remove(&(node.last_accessed, node_id)) {
|
||||
self.leaves.insert((self.time, node_id));
|
||||
}
|
||||
|
||||
node.last_accessed = self.time;
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[doc(hidden)]
|
||||
/// Print debugging output for the trie.
|
||||
///
|
||||
/// In contrast to `Debug` nicely formatted.
|
||||
pub fn print_debug(&self) {
|
||||
self.print_debug_(self.root, 0);
|
||||
}
|
||||
|
||||
fn print_debug_(&self, node_id: NodeId, indent: usize) {
|
||||
let node = &self.nodes[node_id];
|
||||
eprintln!(
|
||||
"{}{:?}, key: {:?}, blocks: {:?}, ref_count: {}, last_accessed: {}, parent: {:?}, children: {:?}",
|
||||
" ".repeat(indent),
|
||||
node_id,
|
||||
node.key,
|
||||
node.blocks,
|
||||
node.ref_count,
|
||||
node.last_accessed,
|
||||
node.parent,
|
||||
node.children
|
||||
);
|
||||
for child_id in self.nodes[node_id].children.values() {
|
||||
self.print_debug_(*child_id, indent + 2);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn root_id(&self) -> DefaultKey {
|
||||
self.root
|
||||
}
|
||||
}
|
||||
|
||||
/// Trie node.
|
||||
#[derive(Debug)]
|
||||
struct TrieNode {
|
||||
blocks: Vec<u32>,
|
||||
children: HashMap<u32, NodeId>,
|
||||
key: Vec<u32>,
|
||||
last_accessed: u64,
|
||||
parent: Option<NodeId>,
|
||||
ref_count: usize,
|
||||
}
|
||||
|
||||
impl TrieNode {
|
||||
fn new(key: Vec<u32>, blocks: Vec<u32>, last_accessed: u64, parent: Option<NodeId>) -> Self {
|
||||
TrieNode {
|
||||
children: HashMap::new(),
|
||||
key,
|
||||
blocks,
|
||||
last_accessed,
|
||||
parent,
|
||||
ref_count: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper trait to get the length of the shared prefix of two sequences.
|
||||
trait SharedPrefixLen {
|
||||
fn shared_prefix_len(&self, other: &Self) -> usize;
|
||||
}
|
||||
|
||||
impl<T> SharedPrefixLen for [T]
|
||||
where
|
||||
T: PartialEq,
|
||||
{
|
||||
fn shared_prefix_len(&self, other: &Self) -> usize {
|
||||
self.iter().zip(other).take_while(|(a, b)| a == b).count()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::block_allocator::Allocator;
|
||||
|
||||
use super::RadixAllocator;
|
||||
|
||||
#[test]
|
||||
fn allocator_reuses_prefixes() {
|
||||
let mut cache = RadixAllocator::new(1, 12, None);
|
||||
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
|
||||
assert_eq!(allocation.slots, allocation.slots);
|
||||
assert_eq!(allocation.prefix_len, 0);
|
||||
cache.free(allocation.blocks.clone(), allocation.allocation_id);
|
||||
|
||||
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
|
||||
assert_eq!(allocation.prefix_len, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allocator_collects_older_prefixes_first() {
|
||||
let mut cache = RadixAllocator::new(1, 7, None);
|
||||
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]);
|
||||
assert_eq!(allocation1.prefix_len, 0);
|
||||
|
||||
let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap();
|
||||
assert_eq!(allocation2.blocks, vec![1, 2]);
|
||||
assert_eq!(allocation2.prefix_len, 0);
|
||||
|
||||
cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
|
||||
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
|
||||
|
||||
// We should get the blocks of the first allocation, since they are more recent.
|
||||
let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap();
|
||||
assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]);
|
||||
assert_eq!(allocation3.prefix_len, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allocator_frees_fully_overlapping_prefills() {
|
||||
let mut cache = RadixAllocator::new(1, 10, None);
|
||||
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
|
||||
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
|
||||
cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
|
||||
|
||||
let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation3.prefix_len, 4);
|
||||
|
||||
// 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks.
|
||||
assert_eq!(cache.free_blocks.len(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allocator_frees_partially_overlapping_prefills() {
|
||||
let mut cache = RadixAllocator::new(1, 20, None);
|
||||
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap();
|
||||
assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]);
|
||||
assert_eq!(allocation1.prefix_len, 0);
|
||||
|
||||
cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
|
||||
|
||||
let allocation2 = cache
|
||||
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
|
||||
.unwrap();
|
||||
assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]);
|
||||
assert_eq!(allocation2.prefix_len, 2);
|
||||
|
||||
let allocation3 = cache
|
||||
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
|
||||
.unwrap();
|
||||
assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]);
|
||||
assert_eq!(allocation3.prefix_len, 2);
|
||||
|
||||
cache.free(allocation3.blocks.clone(), allocation3.allocation_id);
|
||||
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
|
||||
|
||||
// 20 blocks, of which 1 reserved for health checks, 6 for allocation3, 2 for allocation2.
|
||||
assert_eq!(cache.free_blocks.len(), 11);
|
||||
|
||||
let allocation4 = cache
|
||||
.allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
|
||||
.unwrap();
|
||||
assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]);
|
||||
assert_eq!(allocation4.prefix_len, 6);
|
||||
assert_eq!(cache.free_blocks.len(), 11);
|
||||
|
||||
let allocation5 = cache
|
||||
.allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
|
||||
.unwrap();
|
||||
assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]);
|
||||
assert_eq!(allocation5.prefix_len, 6);
|
||||
assert_eq!(cache.free_blocks.len(), 11);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn trie_insertions_have_correct_prefix_len() {
|
||||
let mut trie = super::RadixTrie::new();
|
||||
|
||||
assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0);
|
||||
|
||||
// Already exists.
|
||||
assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 3);
|
||||
|
||||
// Completely new at root-level
|
||||
assert_eq!(trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(), 0);
|
||||
|
||||
// Contains full prefix, but longer.
|
||||
assert_eq!(trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(), 3);
|
||||
|
||||
// Shares partial prefix, we need a split.
|
||||
assert_eq!(
|
||||
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
|
||||
.unwrap(),
|
||||
4
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn trie_get_returns_correct_blocks() {
|
||||
let mut trie = super::RadixTrie::new();
|
||||
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
|
||||
trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
|
||||
trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
|
||||
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
|
||||
.unwrap();
|
||||
|
||||
let mut blocks = Vec::new();
|
||||
trie.find(&[0], &mut blocks);
|
||||
assert_eq!(blocks, vec![0]);
|
||||
|
||||
blocks.clear();
|
||||
trie.find(&[0, 1, 2], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1, 2]);
|
||||
|
||||
blocks.clear();
|
||||
trie.find(&[1, 2, 3], &mut blocks);
|
||||
assert_eq!(blocks, vec![1, 2, 3]);
|
||||
|
||||
blocks.clear();
|
||||
trie.find(&[0, 1, 2, 3], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1, 2, 3]);
|
||||
|
||||
blocks.clear();
|
||||
trie.find(&[0, 1, 2, 3, 4], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1, 2, 3, 4]);
|
||||
|
||||
blocks.clear();
|
||||
trie.find(&[0, 1, 2, 3, 5], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1, 2, 3, 5]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn trie_evict_removes_correct_blocks() {
|
||||
let mut trie = super::RadixTrie::new();
|
||||
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
|
||||
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
|
||||
.unwrap();
|
||||
trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
|
||||
trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
|
||||
|
||||
let mut blocks = Vec::new();
|
||||
|
||||
// Remove less than the leave blocks.
|
||||
assert_eq!(trie.evict(1), vec![7]);
|
||||
trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1, 2, 3, 5, 6]);
|
||||
|
||||
// Refresh other leaf.
|
||||
trie.find(&[0, 1, 2, 3, 4], &mut blocks);
|
||||
trie.find(&[1, 2, 3], &mut blocks);
|
||||
|
||||
// Remove the leave blocks exactly.
|
||||
assert_eq!(trie.evict(2), vec![5, 6]);
|
||||
blocks.clear();
|
||||
trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1, 2, 3]);
|
||||
|
||||
trie.find(&[1, 2, 3], &mut blocks);
|
||||
|
||||
// Remove more than the leave blocks.
|
||||
assert_eq!(trie.evict(3), vec![4, 3, 2]);
|
||||
blocks.clear();
|
||||
trie.find(&[0, 1, 2, 3, 4], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1]);
|
||||
|
||||
// Clear out the whole trie.
|
||||
assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]);
|
||||
}
|
||||
}
|
|
@ -157,6 +157,7 @@ async fn prefill(
|
|||
top_n_tokens: top_n_tokens.unwrap_or(0),
|
||||
blocks: vec![],
|
||||
slots: vec![],
|
||||
prefix_len: 0,
|
||||
adapter_id: None,
|
||||
})
|
||||
.collect();
|
||||
|
|
|
@ -4,21 +4,22 @@ package generate.v3;
|
|||
|
||||
service TextGenerationService {
|
||||
/// Model Info
|
||||
rpc Info (InfoRequest) returns (InfoResponse) {}
|
||||
rpc Info(InfoRequest) returns (InfoResponse) {}
|
||||
/// Service discovery
|
||||
rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
|
||||
rpc ServiceDiscovery(ServiceDiscoveryRequest)
|
||||
returns (ServiceDiscoveryResponse) {}
|
||||
/// Empties batch cache
|
||||
rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
|
||||
rpc ClearCache(ClearCacheRequest) returns (ClearCacheResponse);
|
||||
/// Remove requests from a cached batch
|
||||
rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
|
||||
rpc FilterBatch(FilterBatchRequest) returns (FilterBatchResponse);
|
||||
/// Warmup the model and compute max cache size
|
||||
rpc Warmup (WarmupRequest) returns (WarmupResponse);
|
||||
rpc Warmup(WarmupRequest) returns (WarmupResponse);
|
||||
/// Prefill batch and decode first token
|
||||
rpc Prefill (PrefillRequest) returns (PrefillResponse);
|
||||
rpc Prefill(PrefillRequest) returns (PrefillResponse);
|
||||
/// Decode token for a list of prefilled batches
|
||||
rpc Decode (DecodeRequest) returns (DecodeResponse);
|
||||
rpc Decode(DecodeRequest) returns (DecodeResponse);
|
||||
/// Health check
|
||||
rpc Health (HealthRequest) returns (HealthResponse);
|
||||
rpc Health(HealthRequest) returns (HealthResponse);
|
||||
}
|
||||
|
||||
message HealthRequest {}
|
||||
|
@ -68,9 +69,7 @@ message InputChunk {
|
|||
}
|
||||
}
|
||||
|
||||
message Input {
|
||||
repeated InputChunk chunks = 1;
|
||||
}
|
||||
message Input { repeated InputChunk chunks = 1; }
|
||||
|
||||
enum GrammarType {
|
||||
GRAMMAR_TYPE_NONE = 0;
|
||||
|
@ -136,6 +135,8 @@ message Request {
|
|||
repeated uint32 slots = 10;
|
||||
/// LORA adapter index
|
||||
optional string adapter_id = 11;
|
||||
/// Prefix length that can be retrieved from the KV cache.
|
||||
uint32 prefix_len = 12;
|
||||
}
|
||||
|
||||
message Batch {
|
||||
|
@ -214,7 +215,6 @@ message FilterBatchResponse {
|
|||
CachedBatch batch = 1;
|
||||
}
|
||||
|
||||
|
||||
message PrefillRequest {
|
||||
/// Batch
|
||||
Batch batch = 1;
|
||||
|
|
|
@ -11,6 +11,7 @@ use rand::{thread_rng, Rng};
|
|||
use serde_json::Value;
|
||||
use std::io::Cursor;
|
||||
use std::iter;
|
||||
use std::sync::Arc;
|
||||
use thiserror::Error;
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
use tokio::sync::mpsc;
|
||||
|
@ -115,13 +116,14 @@ impl Validation {
|
|||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
#[instrument(skip(self, inputs))]
|
||||
async fn validate_input(
|
||||
&self,
|
||||
inputs: String,
|
||||
truncate: Option<usize>,
|
||||
max_new_tokens: Option<u32>,
|
||||
) -> Result<(Vec<Chunk>, usize, u32), ValidationError> {
|
||||
) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
|
||||
// If we have a fast tokenizer
|
||||
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
|
||||
// Create response channel
|
||||
|
@ -156,8 +158,10 @@ impl Validation {
|
|||
));
|
||||
}
|
||||
|
||||
let input_ids = encoding.get_ids()[..input_length].to_owned();
|
||||
|
||||
metrics::histogram!("tgi_request_input_length").record(input_length as f64);
|
||||
Ok((inputs, input_length, max_new_tokens))
|
||||
Ok((inputs, Some(input_ids), input_length, max_new_tokens))
|
||||
}
|
||||
// Return inputs without validation
|
||||
else {
|
||||
|
@ -180,7 +184,12 @@ impl Validation {
|
|||
input_length = input_length.saturating_sub(max_new_tokens as usize);
|
||||
}
|
||||
|
||||
Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens))
|
||||
Ok((
|
||||
vec![Chunk::Text(inputs)],
|
||||
None,
|
||||
input_length,
|
||||
max_new_tokens,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -314,7 +323,7 @@ impl Validation {
|
|||
.unwrap_or(Ok(None))?;
|
||||
|
||||
// Validate inputs
|
||||
let (inputs, input_length, max_new_tokens) = self
|
||||
let (inputs, input_ids, input_length, max_new_tokens) = self
|
||||
.validate_input(request.inputs, truncate, max_new_tokens)
|
||||
.await?;
|
||||
|
||||
|
@ -391,6 +400,7 @@ impl Validation {
|
|||
|
||||
Ok(ValidGenerateRequest {
|
||||
inputs,
|
||||
input_ids: input_ids.map(Arc::new),
|
||||
decoder_input_details,
|
||||
input_length: input_length as u32,
|
||||
truncate: truncate.unwrap_or(self.max_input_length) as u32,
|
||||
|
@ -707,6 +717,7 @@ pub struct ValidStoppingParameters {
|
|||
#[derive(Debug, Clone)]
|
||||
pub struct ValidGenerateRequest {
|
||||
pub inputs: Vec<Chunk>,
|
||||
pub input_ids: Option<Arc<Vec<u32>>>,
|
||||
pub input_length: u32,
|
||||
pub truncate: u32,
|
||||
pub decoder_input_details: bool,
|
||||
|
|
|
@ -6,7 +6,12 @@ from .common import Seqlen
|
|||
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
||||
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
||||
if SYSTEM == "cuda":
|
||||
from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
|
||||
from .cuda import (
|
||||
attention,
|
||||
paged_attention,
|
||||
reshape_and_cache,
|
||||
SUPPORTS_WINDOWING,
|
||||
)
|
||||
elif SYSTEM == "rocm":
|
||||
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
|
||||
elif SYSTEM == "ipex":
|
||||
|
|
|
@ -221,9 +221,11 @@ SUPPORTS_WINDOWING = V2
|
|||
if ATTENTION == "flashinfer":
|
||||
|
||||
def attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
softmax_scale,
|
||||
|
@ -231,14 +233,15 @@ if ATTENTION == "flashinfer":
|
|||
causal=True,
|
||||
softcap=0.0,
|
||||
):
|
||||
from text_generation_server.layers.attention.flash_infer import prefill_state
|
||||
assert window_size_left == -1, "Windowing is not supported with flash infer"
|
||||
from text_generation_server.layers.attention.flash_infer import (
|
||||
prefill_with_paged_kv_state,
|
||||
)
|
||||
|
||||
return prefill_state.get().forward(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
return prefill_with_paged_kv_state.get().forward(
|
||||
q.contiguous(),
|
||||
causal=causal,
|
||||
window_left=window_size_left,
|
||||
paged_kv_cache=(key_cache, value_cache),
|
||||
logits_soft_cap=softcap,
|
||||
sm_scale=softmax_scale,
|
||||
)
|
||||
|
@ -249,6 +252,8 @@ elif V2:
|
|||
q,
|
||||
k,
|
||||
v,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
softmax_scale,
|
||||
|
@ -289,6 +294,8 @@ else:
|
|||
q,
|
||||
k,
|
||||
v,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
softmax_scale,
|
||||
|
|
|
@ -9,6 +9,10 @@ prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = Con
|
|||
"prefill_state"
|
||||
)
|
||||
|
||||
prefill_with_paged_kv_state: ContextVar[
|
||||
flashinfer.BatchPrefillWithPagedKVCacheWrapper
|
||||
] = ContextVar("prefill_with_paged_kv_state")
|
||||
|
||||
decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
|
||||
"decode_state"
|
||||
)
|
||||
|
@ -24,6 +28,78 @@ def get_workspace(device):
|
|||
return workspace
|
||||
|
||||
|
||||
def create_prefill_with_paged_kv_state(
|
||||
*,
|
||||
device: torch.device,
|
||||
):
|
||||
"""Create a prefill state that uses the KV cache."""
|
||||
workspace_buffer = get_workspace(device)
|
||||
return flashinfer.BatchPrefillWithPagedKVCacheWrapper(
|
||||
workspace_buffer, kv_layout="NHD", use_cuda_graph=False
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def use_prefill_with_paged_kv_state(
|
||||
*,
|
||||
state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
|
||||
block_tables: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
page_size: int,
|
||||
query_dtype: str = "float16",
|
||||
):
|
||||
"""
|
||||
Context manager to set the active flashinfer prefill state to the given
|
||||
`state` and parameters. This state will be used by all calls to the
|
||||
`attention` function while the context manager is active.
|
||||
"""
|
||||
|
||||
indptr = torch.zeros(
|
||||
input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
|
||||
)
|
||||
# Round up to page size and then calculate the cumulative sum to get
|
||||
# the indices into the block table.
|
||||
torch.add(input_lengths, page_size - 1, out=indptr[1:])
|
||||
indptr[1:].div_(page_size, rounding_mode="floor")
|
||||
indptr[1:].cumsum_(-1)
|
||||
|
||||
# Get the lengths of the last page in a block.
|
||||
if page_size == 1:
|
||||
last_page_len = torch.ones(
|
||||
input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
|
||||
)
|
||||
else:
|
||||
last_page_len = torch.empty(
|
||||
input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
|
||||
)
|
||||
torch.sub(input_lengths, 1, out=last_page_len)
|
||||
last_page_len.remainder_(page_size)
|
||||
last_page_len += 1
|
||||
|
||||
token = prefill_with_paged_kv_state.set(state)
|
||||
try:
|
||||
state.begin_forward(
|
||||
qo_indptr=cu_seqlens,
|
||||
paged_kv_indptr=indptr,
|
||||
paged_kv_indices=block_tables,
|
||||
paged_kv_last_page_len=last_page_len,
|
||||
num_qo_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_size,
|
||||
q_data_type=query_dtype,
|
||||
page_size=page_size,
|
||||
)
|
||||
yield
|
||||
finally:
|
||||
state.end_forward()
|
||||
if token is not None:
|
||||
prefill_with_paged_kv_state.reset(token)
|
||||
|
||||
|
||||
def create_prefill_state(
|
||||
*,
|
||||
device: torch.device,
|
||||
|
|
|
@ -298,6 +298,8 @@ class FlashCohereAttention(torch.nn.Module):
|
|||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
|
|
|
@ -337,6 +337,8 @@ class DbrxAttention(torch.nn.Module):
|
|||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
|
|
|
@ -365,6 +365,8 @@ class DeepseekV2Attention(torch.nn.Module):
|
|||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
|
|
|
@ -238,6 +238,8 @@ class FlashGemma2Attention(torch.nn.Module):
|
|||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
|
|
|
@ -232,6 +232,8 @@ class FlashGemmaAttention(torch.nn.Module):
|
|||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
|
|
|
@ -232,6 +232,8 @@ class FlashGPT2Attention(torch.nn.Module):
|
|||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
from contextlib import contextmanager
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
|
@ -220,6 +221,8 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
|
|
|
@ -219,6 +219,8 @@ class MistralAttention(torch.nn.Module):
|
|||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
|
|
|
@ -276,6 +276,8 @@ class MixtralAttention(torch.nn.Module):
|
|||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
|
|
|
@ -173,6 +173,8 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||
qkv[:, 0],
|
||||
qkv[:, 1],
|
||||
qkv[:, 2],
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
|
|
|
@ -194,6 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
|
|||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
|
|
|
@ -137,6 +137,8 @@ class Qwen2Attention(torch.nn.Module):
|
|||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
|
|
|
@ -208,6 +208,8 @@ class FlashRWAttention(torch.nn.Module):
|
|||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
|
@ -326,6 +328,8 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||
query,
|
||||
torch.select(kv, dim=2, index=0),
|
||||
torch.select(kv, dim=2, index=1),
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
|
|
|
@ -293,6 +293,8 @@ class FlashMQAttention(torch.nn.Module):
|
|||
query,
|
||||
torch.select(key_value, dim=1, index=0),
|
||||
torch.select(key_value, dim=1, index=1),
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
|
|
|
@ -242,6 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
|
|||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
|
|
|
@ -20,6 +20,9 @@ from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, D
|
|||
|
||||
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
|
||||
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
|
||||
from text_generation_server.layers.attention.flash_infer import (
|
||||
create_prefill_with_paged_kv_state,
|
||||
)
|
||||
from text_generation_server.utils.chunks import concat_text_chunks
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.models import Model
|
||||
|
@ -43,6 +46,7 @@ from text_generation_server.models.globals import (
|
|||
ATTENTION,
|
||||
BLOCK_SIZE,
|
||||
CUDA_GRAPHS,
|
||||
PREFIX_CACHING,
|
||||
get_adapter_to_index,
|
||||
)
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
|
@ -108,6 +112,9 @@ class FlashCausalLMBatch(Batch):
|
|||
block_tables_tensor: torch.Tensor
|
||||
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
|
||||
slots: torch.Tensor
|
||||
# size [b], containing the number of blocks that can be retrieved from the cache
|
||||
prefix_lens: List[int]
|
||||
prefix_lens_tensor: torch.Tensor
|
||||
|
||||
max_seqlen: int
|
||||
|
||||
|
@ -116,6 +123,9 @@ class FlashCausalLMBatch(Batch):
|
|||
prefill_next_token_indices: Optional[torch.tensor]
|
||||
prefill_cu_outlens: Optional[List[int]]
|
||||
|
||||
# Prefixes
|
||||
prefix_ids: List[List[int]]
|
||||
|
||||
# All tokens
|
||||
all_input_ids: List[List[int]]
|
||||
all_input_ids_tensor: torch.Tensor
|
||||
|
@ -183,6 +193,7 @@ class FlashCausalLMBatch(Batch):
|
|||
prefix_offsets = []
|
||||
read_offsets = []
|
||||
all_input_ids = []
|
||||
prefix_ids = []
|
||||
requests_idx_mapping = {}
|
||||
|
||||
all_prefill_logprobs = True
|
||||
|
@ -200,7 +211,7 @@ class FlashCausalLMBatch(Batch):
|
|||
|
||||
# Cumulative length
|
||||
cumulative_length = 0
|
||||
cumulative_max_length = 0
|
||||
cumulative_slot_tokens = 0
|
||||
prefill_out_cumulative_length = 0
|
||||
|
||||
num_blocks = 0
|
||||
|
@ -210,6 +221,7 @@ class FlashCausalLMBatch(Batch):
|
|||
|
||||
block_tables = []
|
||||
slots = []
|
||||
prefix_lens = []
|
||||
|
||||
# Parse batch
|
||||
for i, (r, tokenized_input) in enumerate(
|
||||
|
@ -225,6 +237,19 @@ class FlashCausalLMBatch(Batch):
|
|||
):
|
||||
tokenized_input = tokenized_input[1:]
|
||||
|
||||
orig_input_length = len(tokenized_input)
|
||||
|
||||
if PREFIX_CACHING:
|
||||
prefix_len = r.prefix_len
|
||||
if prefix_len == orig_input_length:
|
||||
assert prefix_len > 0
|
||||
prefix_len -= 1
|
||||
else:
|
||||
prefix_len = 0
|
||||
|
||||
prefix_ids.append(tokenized_input[:prefix_len])
|
||||
tokenized_input = tokenized_input[prefix_len:]
|
||||
|
||||
input_length = len(tokenized_input)
|
||||
input_lengths.append(input_length)
|
||||
|
||||
|
@ -234,7 +259,9 @@ class FlashCausalLMBatch(Batch):
|
|||
all_input_ids.append(tokenized_input)
|
||||
|
||||
# Position ids
|
||||
request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
|
||||
request_position_ids = torch.arange(
|
||||
prefix_len, orig_input_length, dtype=torch.int32
|
||||
)
|
||||
position_ids.append(request_position_ids)
|
||||
|
||||
# Add cumulative lengths of all previous inputs
|
||||
|
@ -258,11 +285,17 @@ class FlashCausalLMBatch(Batch):
|
|||
# Remove one as the first token des not have a past
|
||||
speculative_length = get_speculate()
|
||||
speculative_length = 0 if speculative_length is None else speculative_length
|
||||
total_tokens = input_length + max_new_tokens - 1 + speculative_length
|
||||
|
||||
# Tokens that need to be mapped to blocks.
|
||||
block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length
|
||||
|
||||
# Tokens that need to be mapped to slots. We don't need slots for the
|
||||
# cached prefix (if present).
|
||||
slot_tokens = input_length + max_new_tokens - 1 + speculative_length
|
||||
|
||||
# blocks and slots can be empty (for example in warmup)
|
||||
if not r.blocks:
|
||||
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
|
||||
needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)
|
||||
request_blocks = [
|
||||
b for b in range(num_blocks, num_blocks + needed_blocks)
|
||||
]
|
||||
|
@ -273,16 +306,20 @@ class FlashCausalLMBatch(Batch):
|
|||
]
|
||||
else:
|
||||
request_blocks = r.blocks
|
||||
request_slots = r.slots
|
||||
request_slots = r.slots[
|
||||
prefix_len: #: orig_input_length + max_new_tokens + speculative_length
|
||||
]
|
||||
|
||||
block_tables.append(request_blocks)
|
||||
slots.extend(request_slots[:total_tokens])
|
||||
|
||||
slots.extend(request_slots)
|
||||
prefix_lens.append(prefix_len)
|
||||
num_blocks += len(request_blocks)
|
||||
start_slots.append(cumulative_max_length)
|
||||
start_slots.append(cumulative_slot_tokens)
|
||||
|
||||
request_slot_indices = torch.arange(
|
||||
cumulative_max_length,
|
||||
cumulative_max_length + input_length,
|
||||
cumulative_slot_tokens,
|
||||
cumulative_slot_tokens + input_length,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
slot_indices.append(request_slot_indices)
|
||||
|
@ -318,7 +355,7 @@ class FlashCausalLMBatch(Batch):
|
|||
|
||||
# Update
|
||||
cumulative_length += input_length
|
||||
cumulative_max_length += total_tokens
|
||||
cumulative_slot_tokens += slot_tokens
|
||||
max_seqlen = max(max_seqlen, input_length)
|
||||
max_blocks = max(max_blocks, len(request_blocks))
|
||||
max_length = max(
|
||||
|
@ -395,12 +432,14 @@ class FlashCausalLMBatch(Batch):
|
|||
)
|
||||
|
||||
slots = torch.tensor(slots, dtype=torch.int64, device=device)
|
||||
|
||||
block_tables_tensor = torch.zeros(
|
||||
(len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
|
||||
)
|
||||
for i, request_blocks in enumerate(block_tables):
|
||||
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
|
||||
block_tables_tensor = block_tables_tensor.to(device)
|
||||
prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
|
||||
|
||||
return cls(
|
||||
batch_id=pb.id,
|
||||
|
@ -415,6 +454,8 @@ class FlashCausalLMBatch(Batch):
|
|||
block_tables=block_tables,
|
||||
block_tables_tensor=block_tables_tensor,
|
||||
slots=slots,
|
||||
prefix_lens=prefix_lens,
|
||||
prefix_lens_tensor=prefix_lens_tensor,
|
||||
max_seqlen=max_seqlen,
|
||||
prefill_head_indices=prefill_head_indices,
|
||||
prefill_next_token_indices=prefill_next_token_indices,
|
||||
|
@ -425,6 +466,7 @@ class FlashCausalLMBatch(Batch):
|
|||
read_offsets=read_offsets,
|
||||
all_input_ids=all_input_ids,
|
||||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
prefix_ids=prefix_ids,
|
||||
next_token_chooser=next_token_chooser,
|
||||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
|
@ -480,6 +522,7 @@ class FlashCausalLMBatch(Batch):
|
|||
start_slots = []
|
||||
block_tables = []
|
||||
all_input_ids = []
|
||||
prefix_ids = []
|
||||
|
||||
input_lengths = []
|
||||
prefix_offsets = []
|
||||
|
@ -506,6 +549,7 @@ class FlashCausalLMBatch(Batch):
|
|||
max_seqlen = max(max_seqlen, request_input_length)
|
||||
|
||||
all_input_ids.append(self.all_input_ids[idx])
|
||||
prefix_ids.append(self.prefix_ids[idx])
|
||||
|
||||
input_lengths.append(request_input_length)
|
||||
prefix_offsets.append(self.prefix_offsets[idx])
|
||||
|
@ -591,6 +635,7 @@ class FlashCausalLMBatch(Batch):
|
|||
read_offsets=read_offsets,
|
||||
all_input_ids=all_input_ids,
|
||||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
prefix_ids=prefix_ids,
|
||||
next_token_chooser=next_token_chooser,
|
||||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
|
@ -651,6 +696,7 @@ class FlashCausalLMBatch(Batch):
|
|||
block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
|
||||
(total_batch_size, max_blocks)
|
||||
)
|
||||
prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size)
|
||||
all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
|
||||
(total_batch_size, max_length)
|
||||
)
|
||||
|
@ -668,7 +714,9 @@ class FlashCausalLMBatch(Batch):
|
|||
|
||||
start_slots = []
|
||||
block_tables = []
|
||||
prefix_lens = []
|
||||
all_input_ids = []
|
||||
prefix_ids = []
|
||||
|
||||
input_lengths = []
|
||||
prefix_offsets = []
|
||||
|
@ -730,10 +778,14 @@ class FlashCausalLMBatch(Batch):
|
|||
start_index:end_index, : batch.block_tables_tensor.shape[1]
|
||||
] = batch.block_tables_tensor[:, :max_blocks]
|
||||
|
||||
prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor
|
||||
|
||||
start_slots.append(batch.start_slots + cumulative_slots)
|
||||
|
||||
block_tables.extend(batch.block_tables)
|
||||
prefix_lens.extend(batch.prefix_lens)
|
||||
all_input_ids.extend(batch.all_input_ids)
|
||||
prefix_ids.extend(batch.prefix_ids)
|
||||
|
||||
input_lengths.extend(batch.input_lengths)
|
||||
prefix_offsets.extend(batch.prefix_offsets)
|
||||
|
@ -779,6 +831,8 @@ class FlashCausalLMBatch(Batch):
|
|||
slot_indices=slot_indices,
|
||||
block_tables=block_tables,
|
||||
block_tables_tensor=block_tables_tensor,
|
||||
prefix_lens=prefix_lens,
|
||||
prefix_lens_tensor=prefix_lens_tensor,
|
||||
slots=slots,
|
||||
max_seqlen=max_seqlen,
|
||||
prefill_head_indices=None,
|
||||
|
@ -790,6 +844,7 @@ class FlashCausalLMBatch(Batch):
|
|||
read_offsets=read_offsets,
|
||||
all_input_ids=all_input_ids,
|
||||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
prefix_ids=prefix_ids,
|
||||
next_token_chooser=next_token_chooser,
|
||||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
|
@ -944,6 +999,9 @@ class FlashCausalLM(Model):
|
|||
)
|
||||
|
||||
self.prefill_state = create_prefill_state(device=device)
|
||||
self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
|
||||
device=device
|
||||
)
|
||||
|
||||
if not CUDA_GRAPHS:
|
||||
self.decode_state = create_decode_state(
|
||||
|
@ -1042,11 +1100,22 @@ class FlashCausalLM(Model):
|
|||
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
|
||||
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
|
||||
slots = torch.arange(bs, dtype=torch.int64, device=self.device)
|
||||
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
|
||||
block_tables = (
|
||||
torch.arange(max_bt, dtype=torch.int32, device=self.device)
|
||||
.repeat(bs)
|
||||
.reshape((bs, max_bt))
|
||||
input_lengths = [max_s] * bs
|
||||
prefix_lengths = [0] * bs
|
||||
input_lengths_tensor = (
|
||||
torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
|
||||
)
|
||||
prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
|
||||
block_tables = torch.arange(
|
||||
max_bt, dtype=torch.int32, device=self.device
|
||||
).repeat(bs)
|
||||
block_tables = block_tables.reshape((bs, max_bt))
|
||||
|
||||
if ATTENTION == "flashinfer":
|
||||
block_tables = block_tables_to_ragged(
|
||||
block_tables=block_tables,
|
||||
input_lengths=input_lengths,
|
||||
prefix_lens=prefix_lengths,
|
||||
)
|
||||
|
||||
self.cuda_graphs[bs] = {
|
||||
|
@ -1055,9 +1124,9 @@ class FlashCausalLM(Model):
|
|||
"kv_cache": self.kv_cache,
|
||||
"block_tables": block_tables,
|
||||
"slots": slots,
|
||||
"input_lengths": input_lengths,
|
||||
"input_lengths": input_lengths_tensor,
|
||||
}
|
||||
input_lengths_ = Seqlen(input_lengths=input_lengths)
|
||||
input_lengths_ = Seqlen(input_lengths=input_lengths_tensor)
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
self.cuda_graphs[bs]["graph"] = graph
|
||||
|
||||
|
@ -1072,7 +1141,7 @@ class FlashCausalLM(Model):
|
|||
last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
|
||||
state = create_decode_state_cuda_graphs(
|
||||
device=input_ids.device,
|
||||
block_tables=block_tables.view(-1),
|
||||
block_tables=block_tables,
|
||||
block_tables_ptr=block_tables_ptr,
|
||||
last_page_len=last_page_len,
|
||||
num_heads=self.num_heads,
|
||||
|
@ -1088,7 +1157,10 @@ class FlashCausalLM(Model):
|
|||
block_tables=block_tables,
|
||||
cu_seqlen_prefill=None,
|
||||
input_lengths=input_lengths,
|
||||
input_lengths_tensor=input_lengths_tensor,
|
||||
state=state,
|
||||
prefix_lens=prefix_lengths,
|
||||
prefix_lens_tensor=prefix_lengths_tensor,
|
||||
):
|
||||
self.model.forward(
|
||||
input_ids=input_ids,
|
||||
|
@ -1106,7 +1178,7 @@ class FlashCausalLM(Model):
|
|||
torch.cuda.synchronize()
|
||||
|
||||
with torch.cuda.graph(graph, pool=MEM_POOL):
|
||||
input_lengths = Seqlen(input_lengths=input_lengths)
|
||||
input_lengths_tensor = Seqlen(input_lengths=input_lengths_tensor)
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
|
@ -1114,7 +1186,7 @@ class FlashCausalLM(Model):
|
|||
kv_cache=self.kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
input_lengths=input_lengths,
|
||||
input_lengths=input_lengths_tensor,
|
||||
max_s=max_s,
|
||||
prefill_cache_indices=None,
|
||||
lm_head_indices=None,
|
||||
|
@ -1343,7 +1415,10 @@ class FlashCausalLM(Model):
|
|||
with self._forward_context(
|
||||
block_tables=block_tables,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
input_lengths=input_lengths,
|
||||
input_lengths=batch.input_lengths,
|
||||
input_lengths_tensor=input_lengths,
|
||||
prefix_lens=batch.prefix_lens,
|
||||
prefix_lens_tensor=batch.prefix_lens_tensor,
|
||||
):
|
||||
input_lengths = Seqlen(input_lengths=input_lengths)
|
||||
logits, speculative_logits = self.model.forward(
|
||||
|
@ -1367,19 +1442,32 @@ class FlashCausalLM(Model):
|
|||
# Static inputs are potentially padded
|
||||
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
|
||||
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
|
||||
if ATTENTION == "flashinfer":
|
||||
block_tables = block_tables_to_ragged(
|
||||
block_tables=block_tables,
|
||||
input_lengths=batch.input_lengths,
|
||||
prefix_lens=batch.prefix_lens,
|
||||
)
|
||||
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
|
||||
else:
|
||||
cuda_graph["block_tables"][
|
||||
: block_tables.shape[0], : block_tables.shape[1]
|
||||
] = block_tables
|
||||
cuda_graph["slots"].fill_(-1)
|
||||
cuda_graph["slots"][: slots.shape[0]] = slots
|
||||
cuda_graph["input_lengths"].zero_()
|
||||
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
|
||||
cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
|
||||
input_lengths + batch.prefix_lens_tensor
|
||||
)
|
||||
|
||||
state = cuda_graph.get("state")
|
||||
with self._forward_context(
|
||||
block_tables=block_tables,
|
||||
cu_seqlen_prefill=None,
|
||||
input_lengths=input_lengths,
|
||||
input_lengths=batch.input_lengths,
|
||||
input_lengths_tensor=input_lengths,
|
||||
prefix_lens=batch.prefix_lens,
|
||||
prefix_lens_tensor=batch.prefix_lens_tensor,
|
||||
state=state,
|
||||
):
|
||||
# Replay the graph
|
||||
|
@ -1578,6 +1666,7 @@ class FlashCausalLM(Model):
|
|||
batch.read_offsets,
|
||||
batch.stopping_criterias,
|
||||
batch.all_input_ids,
|
||||
batch.prefix_ids,
|
||||
batch.next_token_chooser.do_sample,
|
||||
batch.next_token_chooser.seeds,
|
||||
batch.top_n_tokens,
|
||||
|
@ -1595,6 +1684,7 @@ class FlashCausalLM(Model):
|
|||
read_offset,
|
||||
stopping_criteria,
|
||||
all_input_ids,
|
||||
prefix_ids,
|
||||
do_sample,
|
||||
seed,
|
||||
top_n_tokens,
|
||||
|
@ -1669,18 +1759,18 @@ class FlashCausalLM(Model):
|
|||
out_end_index = batch.prefill_cu_outlens[i + 1]
|
||||
|
||||
# Remove generated token to only have prefill and add nan for first prompt token
|
||||
request_prefill_logprobs = [float("nan")] + prefill_logprobs[
|
||||
out_start_index : out_end_index - 1
|
||||
]
|
||||
request_prefill_logprobs = (
|
||||
[float("nan")] * (len(prefix_ids) + 1)
|
||||
) + prefill_logprobs[out_start_index : out_end_index - 1]
|
||||
prefill_token_ids = all_input_ids[:-1]
|
||||
prefill_texts = self.tokenizer.batch_decode(
|
||||
prefill_token_ids,
|
||||
prefix_ids + prefill_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
|
||||
prefill_tokens = Tokens(
|
||||
prefill_token_ids,
|
||||
prefix_ids + prefill_token_ids,
|
||||
request_prefill_logprobs,
|
||||
prefill_texts,
|
||||
is_special=[],
|
||||
|
@ -1762,7 +1852,10 @@ class FlashCausalLM(Model):
|
|||
*,
|
||||
block_tables: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
input_lengths: torch.Tensor,
|
||||
input_lengths: List[int],
|
||||
input_lengths_tensor: torch.Tensor,
|
||||
prefix_lens: List[int],
|
||||
prefix_lens_tensor: torch.Tensor,
|
||||
state: Optional[Any] = None,
|
||||
) -> ContextManager:
|
||||
if ATTENTION != "flashinfer":
|
||||
|
@ -1771,9 +1864,30 @@ class FlashCausalLM(Model):
|
|||
from text_generation_server.layers.attention.flash_infer import (
|
||||
use_decode_state,
|
||||
use_prefill_state,
|
||||
use_prefill_with_paged_kv_state,
|
||||
)
|
||||
|
||||
# has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens)
|
||||
|
||||
if cu_seqlen_prefill is not None:
|
||||
if True: # has_prefix_lens:
|
||||
return use_prefill_with_paged_kv_state(
|
||||
state=(
|
||||
state if state is not None else self.prefill_with_paged_kv_state
|
||||
),
|
||||
block_tables=block_tables_to_ragged(
|
||||
block_tables=block_tables,
|
||||
input_lengths=input_lengths,
|
||||
prefix_lens=prefix_lens,
|
||||
),
|
||||
cu_seqlens=cu_seqlen_prefill,
|
||||
input_lengths=input_lengths_tensor + prefix_lens_tensor,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_size=self.head_size,
|
||||
page_size=BLOCK_SIZE,
|
||||
)
|
||||
else:
|
||||
return use_prefill_state(
|
||||
state=state if state is not None else self.prefill_state,
|
||||
cu_seqlens=cu_seqlen_prefill,
|
||||
|
@ -1782,13 +1896,33 @@ class FlashCausalLM(Model):
|
|||
head_size=self.head_size,
|
||||
)
|
||||
else:
|
||||
assert input_lengths is not None
|
||||
assert input_lengths_tensor is not None
|
||||
return use_decode_state(
|
||||
state=state if state is not None else self.decode_state,
|
||||
input_lengths=input_lengths,
|
||||
block_tables=block_tables.view(-1),
|
||||
input_lengths=input_lengths_tensor + prefix_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_size=self.head_size,
|
||||
page_size=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
|
||||
def block_tables_to_ragged(
|
||||
*, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int]
|
||||
) -> torch.Tensor:
|
||||
"""Convert block table to ragged format compatible with FlashInfer."""
|
||||
assert len(input_lengths) == len(prefix_lens)
|
||||
|
||||
total_len = sum(input_lengths) + sum(prefix_lens)
|
||||
block_tables_ragged = torch.empty(
|
||||
total_len, dtype=torch.int32, device=block_tables.device
|
||||
)
|
||||
|
||||
offset = 0
|
||||
for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)):
|
||||
seq_len = prefix_len + input_length
|
||||
block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
|
||||
offset += seq_len
|
||||
|
||||
return block_tables_ragged
|
||||
|
|
|
@ -5,17 +5,29 @@ from typing import Dict, Optional
|
|||
|
||||
from text_generation_server.utils.log import log_master
|
||||
|
||||
ATTENTION = os.getenv("ATTENTION", "paged")
|
||||
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", False)
|
||||
log_master(logger.info, f"Using Attention = {PREFIX_CACHING}")
|
||||
|
||||
ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged")
|
||||
_expected = {"paged", "flashdecoding", "flashinfer"}
|
||||
assert (
|
||||
ATTENTION in _expected
|
||||
), f"Attention is not valid {ATTENTION}, expected {_expected}"
|
||||
log_master(logger.info, f"Using Attention = {ATTENTION}")
|
||||
|
||||
if PREFIX_CACHING and ATTENTION != "flashinfer":
|
||||
raise RuntimeError("Prefix caching is only supported with flashinfer")
|
||||
|
||||
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||
# This is overridden by the cli
|
||||
BLOCK_SIZE: int = 256 if ATTENTION == "flashdecoding" else 16
|
||||
|
||||
BLOCK_SIZE: int
|
||||
if ATTENTION == "flashdecoding":
|
||||
BLOCK_SIZE = 256
|
||||
elif ATTENTION == "flashinfer":
|
||||
BLOCK_SIZE = 1
|
||||
else:
|
||||
BLOCK_SIZE = 16
|
||||
|
||||
cuda_graphs = os.getenv("CUDA_GRAPHS")
|
||||
if cuda_graphs is not None:
|
||||
|
|
Loading…
Reference in New Issue