commit 7735b385dc
parent 7a48a84784

    Prefix caching WIP
```diff
@@ -4045,6 +4045,7 @@ dependencies = [
  "reqwest",
  "serde",
  "serde_json",
+ "slotmap",
  "text-generation-router",
  "thiserror",
  "tokenizers",
```
```diff
@@ -156,6 +156,7 @@ impl Client {
             // Blocks and slots will be set on the server side if we use paged attention
             blocks: vec![],
             slots: vec![],
+            prefix_len: 0,
             // Set sampling parameters to also take these ops into account in the max memory
             parameters: Some(NextTokenChooserParameters {
                 temperature: 0.9,
```
```diff
@@ -244,6 +244,7 @@ impl Health for ShardedClient {
             // Block 0 is reserved for health checks
             blocks: vec![0],
             slots: (0..16).collect(),
+            prefix_len: 0,
             adapter_id: None,
         };
         let batch = Batch {
```
```diff
@@ -33,6 +33,7 @@ rand = "0.8.5"
 reqwest = { version = "0.11.20", features = [] }
 serde = "1.0.188"
 serde_json = "1.0.107"
+slotmap = "1.0.7"
 thiserror = "1.0.48"
 tokenizers = { workspace = true}
 tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
```
```diff
@@ -35,15 +35,24 @@ impl BackendV3 {
         window_size: Option<u32>,
         speculate: u32,
     ) -> Self {
+        let prefix_caching = if let Ok(prefix_caching) = std::env::var("USE_PREFIX_CACHING") {
+            matches!(prefix_caching.as_str(), "true" | "1")
+        } else {
+            false
+        };
         let attention = if let Ok(attention) = std::env::var("ATTENTION") {
             attention
                 .parse()
                 .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
+        } else if prefix_caching {
+            Attention::FlashInfer
         } else {
             Attention::Paged
         };
         let block_size = if attention == Attention::FlashDecoding {
             256
+        } else if attention == Attention::FlashInfer {
+            1
         } else {
             16
         };
```
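Note: paging behavior is now driven by two environment variables. `USE_PREFIX_CACHING` ("true"/"1") switches the backend to the radix allocator, and when `ATTENTION` is left unset, prefix caching implies flashinfer with block size 1, i.e. one token per block, so a cached prefix can end on any token boundary. A sketch of the resulting combinations (the helper function is ours, not part of the commit; the env-var parsing is elided):

```rust
/// Hypothetical helper summarizing the selection above: given the
/// env-derived inputs, which attention variant and block size result.
fn paging_config(prefix_caching: bool, attention: Option<Attention>) -> (Attention, u32) {
    let attention = match attention {
        Some(attention) => attention,
        None if prefix_caching => Attention::FlashInfer,
        None => Attention::Paged,
    };
    let block_size = match attention {
        Attention::FlashDecoding => 256,
        Attention::FlashInfer => 1, // block == token, so prefixes can end anywhere
        _ => 16,
    };
    (attention, block_size)
}
```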
```diff
@@ -51,6 +60,7 @@ impl BackendV3 {
         let queue = Queue::new(
             requires_padding,
             block_size,
+            prefix_caching,
             window_size,
             speculate,
             max_batch_total_tokens,
```
```diff
@@ -1,16 +1,26 @@
-use std::cmp::min;
+use std::{cmp::min, sync::Arc};
 use tokio::sync::{mpsc, oneshot};

+use crate::radix::RadixAllocator;
+
 #[derive(Debug, Clone)]
 pub(crate) struct BlockAllocation {
+    pub allocation_id: u64,
     pub blocks: Vec<u32>,
     pub slots: Vec<u32>,
-    block_allocator: BlockAllocator,
+
+    /// Prefix that was cached and for which the KV does not have to
+    /// be recomputed.
+    pub prefix_len: u32,
+
+    pub(crate) block_allocator: Option<BlockAllocator>,
 }

 impl Drop for BlockAllocation {
     fn drop(&mut self) {
-        self.block_allocator.free(self.blocks.clone())
+        if let Some(block_allocator) = self.block_allocator.as_mut() {
+            block_allocator.free(self.blocks.clone(), self.allocation_id)
+        }
     }
 }
```
```diff
@@ -24,6 +34,7 @@ impl BlockAllocator {
     pub(crate) fn new(
         max_batch_total_tokens: u32,
         block_size: u32,
+        prefix_caching: bool,
         window_size: Option<u32>,
     ) -> Self {
         // Create channel
```
```diff
@@ -33,6 +44,7 @@ impl BlockAllocator {
         tokio::spawn(block_allocator_task(
             max_batch_total_tokens / block_size,
             block_size,
+            prefix_caching,
             window_size,
             receiver,
         ));
```
```diff
@@ -42,28 +54,32 @@ impl BlockAllocator {
         }
     }

-    pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
+    pub(crate) async fn allocate(
+        &self,
+        tokens: u32,
+        prefill_tokens: Option<Arc<Vec<u32>>>,
+    ) -> Option<BlockAllocation> {
         let (response_sender, response_receiver) = oneshot::channel();
         self.block_allocator
             .send(BlockAllocatorCommand::Allocate {
                 tokens,
+                prefill_tokens,
                 response_sender,
             })
             .unwrap();

-        response_receiver
-            .await
-            .unwrap()
-            .map(|(blocks, slots)| BlockAllocation {
-                blocks,
-                slots,
-                block_allocator: self.clone(),
-            })
+        response_receiver.await.unwrap().map(|mut allocation| {
+            allocation.block_allocator = Some(self.clone());
+            allocation
+        })
     }

-    pub(crate) fn free(&self, blocks: Vec<u32>) {
+    pub(crate) fn free(&self, blocks: Vec<u32>, allocation_id: u64) {
         self.block_allocator
-            .send(BlockAllocatorCommand::Free { blocks })
+            .send(BlockAllocatorCommand::Free {
+                allocation_id,
+                blocks,
+            })
             .unwrap();
     }
 }
```
```diff
@@ -71,54 +87,29 @@ impl BlockAllocator {
 async fn block_allocator_task(
     blocks: u32,
     block_size: u32,
+    prefix_caching: bool,
     window_size: Option<u32>,
     mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
 ) {
-    // Block 0 is reserved for health checks
-    let mut free_blocks: Vec<u32> = (1..blocks).collect();
+    let mut allocator: Box<dyn Allocator + Send> = if prefix_caching {
+        Box::new(RadixAllocator::new(block_size, blocks, window_size))
+    } else {
+        Box::new(SimpleAllocator::new(blocks, block_size, window_size))
+    };
     while let Some(cmd) = receiver.recv().await {
         match cmd {
-            BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
+            BlockAllocatorCommand::Free {
+                blocks,
+                allocation_id,
+            } => allocator.free(blocks, allocation_id),
             BlockAllocatorCommand::Allocate {
                 tokens,
+                prefill_tokens,
                 response_sender,
             } => {
-                // Apply window size
-                let (required_blocks, repeats) = {
-                    let (tokens, repeats) = match window_size {
-                        None => (tokens, 1),
-                        Some(window_size) => {
-                            let repeats = (tokens + window_size - 1) / window_size;
-                            let tokens = min(tokens, window_size);
-                            (tokens, repeats as usize)
-                        }
-                    };
-                    // Pad to a multiple of block size
-                    let required_blocks = (tokens + block_size - 1) / block_size;
-                    (required_blocks, repeats)
-                };
-
-                let tokens = tokens as usize;
-                let allocation = if required_blocks > free_blocks.len() as u32 {
-                    None
-                } else {
-                    let blocks =
-                        free_blocks.split_off(free_blocks.len() - required_blocks as usize);
-                    let mut slots = Vec::with_capacity(
-                        (required_blocks * block_size * repeats as u32) as usize,
-                    );
-
-                    'slots: for block_id in blocks.repeat(repeats).iter() {
-                        for s in (block_id * block_size)..((block_id + 1) * block_size) {
-                            slots.push(s);
-                            if slots.len() == tokens {
-                                break 'slots;
-                            }
-                        }
-                    }
-                    Some((blocks, slots))
-                };
-                response_sender.send(allocation).unwrap();
+                response_sender
+                    .send(allocator.allocate(tokens, prefill_tokens))
+                    .unwrap();
             }
         }
     }
 }
```
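Note: the allocator keeps running as a dedicated task behind an mpsc channel, with a oneshot reply channel per request, which is what lets the `Allocator` implementations stay `&mut self` without any locking. A minimal caller sketch (the function shape here is ours; `queue.rs` below holds the real call site):

```rust
// Minimal usage sketch: all allocator state lives on the spawned task, so
// callers only ever hold a cheap clonable handle and await the oneshot reply.
async fn reserve(allocator: &BlockAllocator, prompt: Arc<Vec<u32>>) -> Option<BlockAllocation> {
    let tokens = prompt.len() as u32 + 1; // prompt plus one decode step (illustrative)
    let allocation = allocator.allocate(tokens, Some(prompt)).await?;
    // `allocation.prefix_len` tokens already have KV entries; dropping the
    // allocation sends `Free` back to the task via the `Drop` impl above.
    Some(allocation)
}
```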
```diff
@@ -128,9 +119,92 @@ async fn block_allocator_task(
 enum BlockAllocatorCommand {
     Free {
         blocks: Vec<u32>,
+        allocation_id: u64,
     },
     Allocate {
         tokens: u32,
-        response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
+        prefill_tokens: Option<Arc<Vec<u32>>>,
+        response_sender: oneshot::Sender<Option<BlockAllocation>>,
     },
 }
+
+pub(crate) trait Allocator {
+    fn allocate(
+        &mut self,
+        tokens: u32,
+        prefill_tokens: Option<Arc<Vec<u32>>>,
+    ) -> Option<BlockAllocation>;
+
+    fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
+}
+
+pub struct SimpleAllocator {
+    free_blocks: Vec<u32>,
+    block_size: u32,
+    window_size: Option<u32>,
+}
+
+impl SimpleAllocator {
+    fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
+        SimpleAllocator {
+            block_size,
+            // Block 0 is reserved for health checks
+            free_blocks: (1..blocks).collect(),
+            window_size,
+        }
+    }
+}
+
+impl Allocator for SimpleAllocator {
+    fn allocate(
+        &mut self,
+        tokens: u32,
+        _prefill_tokens: Option<Arc<Vec<u32>>>,
+    ) -> Option<BlockAllocation> {
+        // Apply window size
+        let (required_blocks, repeats) = {
+            let (tokens, repeats) = match self.window_size {
+                None => (tokens, 1),
+                Some(window_size) => {
+                    let repeats = (tokens + window_size - 1) / window_size;
+                    let tokens = min(tokens, window_size);
+                    (tokens, repeats as usize)
+                }
+            };
+            // Pad to a multiple of block size
+            let required_blocks = (tokens + self.block_size - 1) / self.block_size;
+            (required_blocks, repeats)
+        };
+
+        let tokens = tokens as usize;
+        if required_blocks > self.free_blocks.len() as u32 {
+            None
+        } else {
+            let blocks = self
+                .free_blocks
+                .split_off(self.free_blocks.len() - required_blocks as usize);
+            let mut slots =
+                Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
+
+            'slots: for block_id in blocks.repeat(repeats).iter() {
+                for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
+                    slots.push(s);
+                    if slots.len() == tokens {
+                        break 'slots;
+                    }
+                }
+            }
+            Some(BlockAllocation {
+                allocation_id: 0,
+                blocks,
+                slots,
+                prefix_len: 0,
+                block_allocator: None,
+            })
+        }
+    }
+
+    fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {
+        self.free_blocks.extend(blocks)
+    }
+}
```
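Note: `Allocator` is the seam that lets `block_allocator_task` swap `SimpleAllocator` for `RadixAllocator` behind a `Box<dyn Allocator + Send>`. Any further strategy only has to provide the two methods; a hypothetical stub (not part of the commit) that could back a unit test:

```rust
/// Hypothetical stub allocator: hands out sequential block ids from a fixed
/// budget and never reclaims. Useful only to exercise callers of `Allocator`.
struct CountingAllocator {
    next_block: u32,
    budget: u32,
}

impl Allocator for CountingAllocator {
    fn allocate(
        &mut self,
        tokens: u32,
        _prefill_tokens: Option<Arc<Vec<u32>>>,
    ) -> Option<BlockAllocation> {
        if tokens > self.budget {
            return None;
        }
        self.budget -= tokens;
        let blocks: Vec<u32> = (self.next_block..self.next_block + tokens).collect();
        self.next_block += tokens;
        Some(BlockAllocation {
            allocation_id: 0,
            slots: blocks.clone(), // 1:1 mapping, like the radix allocator
            blocks,
            prefix_len: 0,
            block_allocator: None,
        })
    }

    fn free(&mut self, _blocks: Vec<u32>, _allocation_id: u64) {}
}
```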
```diff
@@ -157,6 +157,7 @@ impl Client {
             // Blocks and slots will be set on the server side if we use paged attention
             blocks: vec![],
             slots: vec![],
+            prefix_len: 0,
             // Set sampling parameters to also take these ops into account in the max memory
             parameters: Some(NextTokenChooserParameters {
                 temperature: 0.9,
```
```diff
@@ -245,6 +245,7 @@ impl Health for ShardedClient {
             // Block 0 is reserved for health checks
             blocks: vec![0],
             slots: (0..16).collect(),
+            prefix_len: 0,
             adapter_id: None,
         };
         let batch = Batch {
```
```diff
@@ -2,6 +2,7 @@ mod backend;
 mod block_allocator;
 mod client;
 mod queue;
+mod radix;

 use crate::client::{ClientError, ShardedClient};
 pub(crate) use backend::BackendV3;
```
```diff
@@ -46,6 +46,7 @@ impl Queue {
     pub(crate) fn new(
         requires_padding: bool,
         block_size: u32,
+        prefix_caching: bool,
         window_size: Option<u32>,
         speculate: u32,
         max_batch_total_tokens: u32,
```
```diff
@@ -57,6 +58,7 @@ impl Queue {
         tokio::spawn(queue_task(
             requires_padding,
             block_size,
+            prefix_caching,
             window_size,
             speculate,
             max_batch_total_tokens,
```
```diff
@@ -109,6 +111,7 @@ impl Queue {
 async fn queue_task(
     requires_padding: bool,
     block_size: u32,
+    prefix_caching: bool,
     window_size: Option<u32>,
     speculate: u32,
     max_batch_total_tokens: u32,
```
```diff
@@ -117,6 +120,7 @@ async fn queue_task(
     let mut state = State::new(
         requires_padding,
         block_size,
+        prefix_caching,
         window_size,
         speculate,
         max_batch_total_tokens,
```
```diff
@@ -176,12 +180,19 @@ impl State {
     fn new(
         requires_padding: bool,
         block_size: u32,
+        prefix_caching: bool,
         window_size: Option<u32>,
         speculate: u32,
         max_batch_total_tokens: u32,
     ) -> Self {
-        let block_allocator = (!requires_padding)
-            .then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size));
+        let block_allocator = (!requires_padding).then(|| {
+            BlockAllocator::new(
+                max_batch_total_tokens,
+                block_size,
+                prefix_caching,
+                window_size,
+            )
+        });

         Self {
             entries: VecDeque::with_capacity(128),
```
```diff
@@ -305,7 +316,10 @@ impl State {
                     + self.speculate
                     - 1;

-                match block_allocator.allocate(tokens).await {
+                match block_allocator
+                    .allocate(tokens, entry.request.input_ids.clone())
+                    .await
+                {
                     None => {
                         // Entry is over budget
                         // Add it back to the front
```
```diff
@@ -331,11 +345,12 @@ impl State {
             // Update entry
             entry.temp_span = Some(entry_batch_span);

-            let (blocks, slots) = match &block_allocation {
-                None => (Vec::new(), Vec::new()),
+            let (blocks, slots, prefix_len) = match &block_allocation {
+                None => (Vec::new(), Vec::new(), 0),
                 Some(block_allocation) => (
                     block_allocation.blocks.clone(),
                     block_allocation.slots.clone(),
+                    block_allocation.prefix_len,
                 ),
             };

```
```diff
@@ -372,6 +387,7 @@ impl State {
                 top_n_tokens: entry.request.top_n_tokens,
                 blocks,
                 slots,
+                prefix_len,
                 adapter_id: entry.request.adapter_id.clone(),
             });
             // Set batch_time
```
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use tracing::info_span;
|
use tracing::info_span;
|
||||||
|
|
||||||
|
```diff
@@ -492,6 +510,7 @@ mod tests {
         let entry = Entry {
             request: ValidGenerateRequest {
                 inputs: vec![],
+                input_ids: Some(Arc::new(vec![])),
                 input_length: 0,
                 truncate: 0,
                 decoder_input_details: false,
```
```diff
@@ -527,7 +546,7 @@ mod tests {

     #[tokio::test]
     async fn test_append() {
-        let mut state = State::new(false, 1, None, 0, 16);
+        let mut state = State::new(false, 1, false, None, 0, 16);
         let (entry, _guard) = default_entry();

         assert_eq!(state.next_id, 0);
```
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_next_batch_empty() {
|
async fn test_next_batch_empty() {
|
||||||
let mut state = State::new(false, 1, None, 0, 16);
|
let mut state = State::new(false, 1, false, None, 0, 16);
|
||||||
|
|
||||||
assert!(state.next_batch(None, None, 1, 1).await.is_none());
|
assert!(state.next_batch(None, None, 1, 1).await.is_none());
|
||||||
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
|
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
|
||||||
|
@ -551,7 +570,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_next_batch_min_size() {
|
async fn test_next_batch_min_size() {
|
||||||
let mut state = State::new(false, 1, None, 0, 16);
|
let mut state = State::new(false, 1, false, None, 0, 16);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
state.append(entry1);
|
state.append(entry1);
|
||||||
|
@ -583,7 +602,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_next_batch_max_size() {
|
async fn test_next_batch_max_size() {
|
||||||
let mut state = State::new(false, 1, None, 0, 16);
|
let mut state = State::new(false, 1, false, None, 0, 16);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
state.append(entry1);
|
state.append(entry1);
|
||||||
|
@ -603,7 +622,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_next_batch_token_budget() {
|
async fn test_next_batch_token_budget() {
|
||||||
let mut state = State::new(false, 1, None, 0, 2);
|
let mut state = State::new(false, 1, false, None, 0, 2);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
state.append(entry1);
|
state.append(entry1);
|
||||||
|
@ -636,14 +655,14 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_append() {
|
async fn test_queue_append() {
|
||||||
let queue = Queue::new(false, 1, None, 0, 16);
|
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||||
let (entry, _guard) = default_entry();
|
let (entry, _guard) = default_entry();
|
||||||
queue.append(entry);
|
queue.append(entry);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_empty() {
|
async fn test_queue_next_batch_empty() {
|
||||||
let queue = Queue::new(false, 1, None, 0, 16);
|
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||||
|
|
||||||
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
||||||
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
|
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
|
||||||
|
```diff
@@ -651,7 +670,7 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_min_size() {
-        let queue = Queue::new(false, 1, None, 0, 16);
+        let queue = Queue::new(false, 1, false, None, 0, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
```
```diff
@@ -684,7 +703,7 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_max_size() {
-        let queue = Queue::new(false, 1, None, 0, 16);
+        let queue = Queue::new(false, 1, false, None, 0, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
```
```diff
@@ -700,7 +719,7 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_token_budget() {
-        let queue = Queue::new(false, 1, None, 0, 16);
+        let queue = Queue::new(false, 1, false, None, 0, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
```
```diff
@@ -725,7 +744,7 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_token_speculate() {
-        let queue = Queue::new(false, 1, None, 2, 16);
+        let queue = Queue::new(false, 1, false, None, 2, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
```
```diff
@@ -744,7 +763,7 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_dropped_receiver() {
-        let queue = Queue::new(false, 1, None, 0, 16);
+        let queue = Queue::new(false, 1, false, None, 0, 16);
         let (entry, _) = default_entry();
         queue.append(entry);

```
```diff
@@ -0,0 +1,755 @@
+use std::{
+    collections::{BTreeSet, HashMap},
+    sync::Arc,
+};
+
+use slotmap::{DefaultKey, SlotMap};
+
+use crate::block_allocator::{Allocator, BlockAllocation};
+
+pub struct RadixAllocator {
+    allocation_id: u64,
+
+    allocations: HashMap<u64, RadixAllocation>,
+
+    cache_blocks: RadixTrie,
+
+    /// Blocks that are immediately available for allocation.
+    free_blocks: Vec<u32>,
+}
+
+impl RadixAllocator {
+    pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self {
+        assert_eq!(
+            block_size, 1,
+            "Radix tree allocator only works with block_size=1, was: {}",
+            block_size
+        );
+        if window_size.is_some() {
+            unimplemented!("Window size not supported in the prefix-caching block allocator yet");
+        }
+
+        RadixAllocator {
+            allocation_id: 0,
+            allocations: HashMap::new(),
+            cache_blocks: RadixTrie::new(),
+
+            // Block 0 is reserved for health checks.
+            free_blocks: (1..n_blocks).collect(),
+        }
+    }
+
+    fn alloc_or_reclaim(&mut self, n_blocks_needed: usize) -> Option<Vec<u32>> {
+        if self.free_blocks.len() < n_blocks_needed {
+            // This is a bit annoying, we first extend the free list and then
+            // split it off again below. This is because we need to put it on
+            // the free list if we cannot allocate enough blocks. This is only
+            // temporary, the trie needs to be able to report whether it can
+            // allocate the requested amount. Just not implemented yet.
+            self.free_blocks.extend(
+                self.cache_blocks
+                    .evict(n_blocks_needed - self.free_blocks.len()),
+            );
+        }
+
+        if self.free_blocks.len() >= n_blocks_needed {
+            Some(
+                self.free_blocks
+                    .split_off(self.free_blocks.len() - n_blocks_needed),
+            )
+        } else {
+            None
+        }
+    }
+}
```
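Worked example for `alloc_or_reclaim`, with assumed numbers: starting from 2 blocks on the free list and `n_blocks_needed = 5`, the trie is asked to evict 3 blocks. If eviction only recovers 2 (everything else is pinned by refcounts), the free list holds 4 < 5 and the method returns `None`; the prematurely evicted blocks nevertheless stay on the free list for the next attempt, which is the conservative behavior the comment above flags as temporary.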
```diff
+impl Allocator for RadixAllocator {
+    fn allocate(
+        &mut self,
+        tokens: u32,
+        prefill_tokens: Option<Arc<Vec<u32>>>,
+    ) -> Option<BlockAllocation> {
+        let mut blocks = vec![];
+        let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {
+            let node_id = self
+                .cache_blocks
+                .find(prefill_tokens.as_slice(), &mut blocks);
+            // Even if this allocation fails below, we need to increase the
+            // refcount to ensure that the prefix that was found is not evicted.
+            node_id
+        } else {
+            self.cache_blocks.root_id()
+        };
+
+        self.cache_blocks
+            .incref(prefix_node)
+            .expect("Failed to increment refcount");
+
+        let prefix_len = blocks.len();
+        let suffix_len = tokens - prefix_len as u32;
+
+        match self.alloc_or_reclaim(suffix_len as usize) {
+            Some(suffix_blocks) => blocks.extend(suffix_blocks),
+            None => {
+                self.cache_blocks
+                    .decref(prefix_node)
+                    .expect("Failed to decrement refcount");
+                return None;
+            }
+        }
+
+        // 1:1 mapping of blocks and slots.
+        let slots = blocks.clone();
+
+        let allocation = RadixAllocation {
+            prefix_node,
+            cached_prefix_len: prefix_len,
+            prefill_tokens: prefill_tokens.clone(),
+        };
+
+        self.allocation_id += 1;
+        self.allocations.insert(self.allocation_id, allocation);
+
+        Some(BlockAllocation {
+            allocation_id: self.allocation_id,
+            block_allocator: None,
+            blocks,
+            slots,
+            prefix_len: prefix_len as u32,
+        })
+    }
+
+    fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
+        let allocation = match self.allocations.remove(&allocation_id) {
+            Some(allocation) => allocation,
+            None => unreachable!("Tried to free an unknown allocation."),
+        };
+
+        self.cache_blocks
+            .decref(allocation.prefix_node)
+            .expect("Failed to decrement refcount");
+
+        if let Some(prefill_tokens) = allocation.prefill_tokens {
+            let prefill_tokens = prefill_tokens.as_slice();
+
+            // If there are prefill tokens that did not come from the cache,
+            // add them to the cache.
+            if prefill_tokens.len() > allocation.cached_prefix_len {
+                let prefix_len = self
+                    .cache_blocks
+                    .insert(prefill_tokens, &blocks[..prefill_tokens.len()])
+                    // Unwrap, failing is a programming error.
+                    .expect("Failed to store prefill tokens");
+
+                // We can have a prefill with the following structure:
+                //
+                // |---| From the prefix cache.
+                // A B C D E F G
+                //|--------| Found in the trie during insertion.
+                //
+                // This means that while processing this request there was a
+                // partially overlapping request that had A..=E in its
+                // prefill. In this case we need to free the blocks D E.
+                self.free_blocks
+                    .extend(&blocks[allocation.cached_prefix_len..prefix_len]);
+            }
+
+            // Free non-prefill blocks.
+            self.free_blocks.extend(&blocks[prefill_tokens.len()..]);
+        } else {
+            self.free_blocks.extend(blocks);
+        }
+    }
+}
```
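Note: the round trip that makes prefix caching pay off is allocate → free → allocate. `free` inserts the prefill tokens into the trie, so a later request with the same prompt reports a non-zero `prefix_len`. A sketch mirroring the `allocator_reuses_prefixes` test below (`Allocator` must be in scope for the method calls):

```rust
fn prefix_reuse_roundtrip() {
    let mut cache = RadixAllocator::new(1, 12, None);

    // Cold request: nothing cached yet, prefix_len is 0.
    let first = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
    assert_eq!(first.prefix_len, 0);

    // Freeing inserts the prefill tokens (and their blocks) into the trie.
    cache.free(first.blocks.clone(), first.allocation_id);

    // Warm request with the same prompt: four tokens come from the cache,
    // so prefill only has to compute KV for the suffix.
    let second = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
    assert_eq!(second.prefix_len, 4);
}
```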
```diff
+struct RadixAllocation {
+    prefix_node: NodeId,
+    cached_prefix_len: usize,
+    prefill_tokens: Option<Arc<Vec<u32>>>,
+}
+
+// Radix trie that is heavily inspired by radix attention from sglang.
+//
+// The trie is optimized for prefix caching:
+//
+// - A normal radix trie stores discrete values. In this radix trie,
+//   inserting *abc* with value *xyz* will also enable lookup for
+//   *a* (*x*) and *ab* (*xy*).
+// - As a result, every value is required to have the same length as
+//   the key.
+// - We store additional information in each node, such as last access
+//   time and a reference count.
+
+#[derive(Debug)]
+pub enum TrieError {
+    InvalidNodeId,
+    RefCountUnderflow,
+    BlockTokenCountMismatch,
+}
+
+pub type NodeId = DefaultKey;
+
+#[derive(Debug)]
+pub struct RadixTrie {
+    /// Identifier of the root node.
+    root: DefaultKey,
+
+    /// Leaf node identifiers ordered by increasing recency.
+    leaves: BTreeSet<(u64, NodeId)>,
+
+    /// All trie nodes.
+    nodes: SlotMap<NodeId, TrieNode>,
+
+    /// Time as a monotonically increasing counter to avoid the system
+    /// call that a real time lookup would require.
+    time: u64,
+}
```
```diff
+impl RadixTrie {
+    /// Construct a new radix trie.
+    pub fn new() -> Self {
+        let root = TrieNode::new(vec![], vec![], 0, None);
+        let mut nodes = SlotMap::new();
+        let root = nodes.insert(root);
+        RadixTrie {
+            leaves: BTreeSet::new(),
+            nodes,
+            root,
+            time: 0,
+        }
+    }
+
+    /// Find the prefix of the given tokens.
+    ///
+    /// The blocks corresponding to the part of the prefix that could be found
+    /// are written to `blocks`. The number of blocks is in `0..=tokens.len()`.
+    /// Returns the identifier of the trie node that contains the longest
+    /// prefix. The node identifier can be used by callers to e.g. increase its
+    /// reference count.
+    ///
+    /// Using this method will update the access time of the traversed nodes.
+    pub fn find(&mut self, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
+        self.time += 1;
+        self.find_(self.root, key, blocks)
+    }
+
+    /// Find worker.
+    fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
+        let node = &self.nodes[node_id];
+
+        if let Some(&child_id) = node.children.get(&key[0]) {
+            self.update_access_time(child_id);
+            let child = self.nodes.get(child_id).expect("Invalid child identifier");
+            let shared_prefix_len = child.key.shared_prefix_len(key);
+            blocks.extend(&child.blocks[..shared_prefix_len]);
+
+            let key = &key[shared_prefix_len..];
+            if !key.is_empty() {
+                node_id = self.find_(child_id, key, blocks);
+            }
+        }
+
+        node_id
+    }
```
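Note: `find_` descends by indexing `children` with the first token of the remaining key, then compares only against that child's key, so lookup cost is proportional to the matched length rather than the trie size. A small illustration (token values assumed):

```rust
#[test]
fn find_returns_longest_matching_prefix() {
    let mut trie = RadixTrie::new();
    trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();

    let mut blocks = Vec::new();
    // Token 9 diverges after two matches, so only blocks [0, 1] are returned.
    let node = trie.find(&[0, 1, 9], &mut blocks);
    assert_eq!(blocks, vec![0, 1]);

    // The returned node id can pin the matched prefix against eviction.
    trie.incref(node).unwrap();
}
```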
```diff
+    /// Decrease the reference count of a node.
+    pub fn decref(&mut self, node_id: NodeId) -> Result<(), TrieError> {
+        // We don't care about refcounting for root, since it will never
+        // be evicted.
+        if node_id == self.root {
+            return Ok(());
+        }
+
+        let node = self
+            .nodes
+            .get_mut(node_id)
+            .ok_or(TrieError::InvalidNodeId)?;
+        if node.ref_count == 0 {
+            return Err(TrieError::RefCountUnderflow);
+        }
+
+        node.ref_count -= 1;
+        if node.ref_count == 0 {
+            self.leaves.insert((node.last_accessed, node_id));
+        }
+
+        Ok(())
+    }
+
+    /// Increase the reference count of a node.
+    pub fn incref(&mut self, node_id: NodeId) -> Result<(), TrieError> {
+        if node_id == self.root {
+            return Ok(());
+        }
+
+        let node = self
+            .nodes
+            .get_mut(node_id)
+            .ok_or(TrieError::InvalidNodeId)?;
+        if node.ref_count == 0 {
+            self.leaves.remove(&(node.last_accessed, node_id));
+        }
+        node.ref_count += 1;
+
+        Ok(())
+    }
```
```diff
+    /// Evict `n_blocks` from the trie.
+    ///
+    /// Returns the evicted blocks. When the length is less than `n_blocks`,
+    /// not enough blocks could be evicted.
+    pub fn evict(&mut self, n_blocks: usize) -> Vec<u32> {
+        // NOTE: we don't return Result here. If any of the unwrapping fails,
+        // it's a programming error in the trie implementation, not a user
+        // error caused by e.g. an invalid argument.
+
+        // TODO: add some bookkeeping in the future to check whether we can
+        // evict n_blocks and return `None` if we can't. We are now needlessly
+        // evicting prefixes from the cache in such a case.
+        let mut evicted = Vec::new();
+
+        while let Some((last_access, node_id)) = self.leaves.pop_first() {
+            let blocks_needed = n_blocks - evicted.len();
+
+            let node = self.nodes.get(node_id).expect("Leaf does not exist");
+            if blocks_needed >= node.blocks.len() {
+                // We need to evict the whole node if we need more blocks than it has.
+                let node = self.remove_node(node_id);
+                evicted.extend(node.blocks);
+
+                if evicted.len() >= n_blocks {
+                    break;
+                }
+            } else {
+                // The node has more blocks than needed, so we'll just remove
+                // the required number of blocks and leave the remaining blocks
+                // untouched.
+                let node = self.nodes.get_mut(node_id).expect("Leaf does not exist");
+                node.key.truncate(node.blocks.len() - blocks_needed);
+                evicted.extend(node.blocks.split_off(node.blocks.len() - blocks_needed));
+                self.leaves.insert((last_access, node_id));
+                break;
+            }
+        }
+
+        evicted
+    }
```
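Note: because `leaves` is a `BTreeSet<(last_accessed, NodeId)>`, `pop_first()` always yields the least recently used unreferenced leaf, so eviction trims cold suffixes before touching anything warm. A condensed walk-through of the partial-eviction branch:

```rust
#[test]
fn evict_trims_the_coldest_leaf_from_its_tail() {
    let mut trie = RadixTrie::new();
    trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
    trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
        .unwrap();

    // The only unreferenced leaf holds blocks [3, 5, 6, 7]; evicting one
    // block truncates its key and blocks from the tail, so block 7 goes and
    // the leaf keeps serving the shorter prefix [0, 1, 2, 3, 5, 6].
    assert_eq!(trie.evict(1), vec![7]);
}
```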
```diff
+    /// Insert a prefill along with its blocks.
+    ///
+    /// This method returns the length of the prefix that was already
+    /// in the trie. E.g. if the length is 10, this means that for
+    /// the first 10 elements of the tree **the blocks are not updated**.
+    pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result<usize, TrieError> {
+        self.time += 1;
+        self.insert_(self.root, tokens, blocks)
+    }
+
+    /// Insertion worker.
+    fn insert_(
+        &mut self,
+        node_id: NodeId,
+        tokens: &[u32],
+        blocks: &[u32],
+    ) -> Result<usize, TrieError> {
+        // TODO: in the future we may want to check that the blocks match for
+        // the part of the prefix that is already in the trie to detect
+        // mismatches.
+
+        if tokens.len() != blocks.len() {
+            return Err(TrieError::BlockTokenCountMismatch);
+        }
+
+        if let Some(&child_id) = self.nodes[node_id].children.get(&tokens[0]) {
+            self.update_access_time(child_id);
+            let child = self
+                .nodes
+                .get_mut(child_id)
+                // Unwrap here, since failure is a bug.
+                .expect("Child node does not exist");
+            let shared_prefix_len = child.key.shared_prefix_len(tokens);
+
+            // We are done, the prefix is already in the trie.
+            if shared_prefix_len == tokens.len() {
+                return Ok(shared_prefix_len);
+            }
+
+            // The node's prefix is a prefix of the insertion prefix.
+            if shared_prefix_len == child.key.len() {
+                return Ok(shared_prefix_len
+                    + self.insert_(
+                        child_id,
+                        &tokens[shared_prefix_len..],
+                        &blocks[shared_prefix_len..],
+                    )?);
+            }
+
+            // The node's prefix and the insertion prefix only match partially,
+            // split the node to just contain the matching part. Then insert the
+            // remainder of the prefix into the node again.
+            let child_id = self.split_node(child_id, shared_prefix_len);
+            let key = &tokens[shared_prefix_len..];
+            let blocks = &blocks[shared_prefix_len..];
+            Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?)
+        } else {
+            self.add_node(node_id, tokens, blocks);
+            Ok(0)
+        }
+    }
```
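Note: the returned count covers only tokens whose blocks were already present; `RadixAllocator::free` relies on it to spot freshly computed blocks that duplicate cached ones. A sketch of the split case (compare `trie_insertions_have_correct_prefix_len` below):

```rust
#[test]
fn insert_splits_on_partial_overlap() {
    let mut trie = RadixTrie::new();
    trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();

    // [0, 1, 2, 3] is shared and token 5 diverges: the existing node is
    // split after four tokens and the new suffix hangs off the split point,
    // so four of the inserted blocks were already cached.
    assert_eq!(
        trie.insert(&[0, 1, 2, 3, 5, 6], &[0, 1, 2, 3, 5, 6]).unwrap(),
        4
    );
}
```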
```diff
+    fn split_node(&mut self, node_id: NodeId, prefix_len: usize) -> NodeId {
+        // We have to make the current node a child to ensure that its
+        // properties and node id stay the same.
+
+        // This function unwraps, an invalid node_id is a programming error.
+
+        let node = self
+            .nodes
+            .get_mut(node_id)
+            .expect("Node to-be split does not exist");
+        let mut parent_key = node.key.split_off(prefix_len);
+        let mut parent_blocks = node.blocks.split_off(prefix_len);
+
+        // Move first part of the prefix to the parent. We swap to avoid
+        // an allocation + copy for both splits of the key/blocks.
+        std::mem::swap(&mut node.key, &mut parent_key);
+        std::mem::swap(&mut node.blocks, &mut parent_blocks);
+
+        let node_key = node.key[0];
+
+        let grandparent_id = node.parent.expect("Node does not have a parent");
+        let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks);
+        self.add_node_to_parent(parent_id, node_key, node_id);
+
+        // Reborrow to make the borrow checker happy.
+        let node = self
+            .nodes
+            .get_mut(node_id)
+            .expect("Node to-be split does not exist");
+        node.parent = Some(parent_id);
+
+        parent_id
+    }
+
+    /// Create a node and add it to the parent.
+    fn add_node(
+        &mut self,
+        parent_id: NodeId,
+        key: impl Into<Vec<u32>>,
+        blocks: impl Into<Vec<u32>>,
+    ) -> NodeId {
+        let key = key.into();
+        let blocks = blocks.into();
+        let first = key[0];
+
+        let child = TrieNode::new(key, blocks, self.time, Some(parent_id));
+        let child_id = self.nodes.insert(child);
+
+        self.add_node_to_parent(parent_id, first, child_id);
+        self.leaves.insert((self.time, child_id));
+
+        child_id
+    }
+
+    /// Add a node to the parent.
+    fn add_node_to_parent(&mut self, parent_id: NodeId, first: u32, child_id: NodeId) {
+        // Unwrap here, passing in an unknown id is a programming error.
+        let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
+        if parent.children.insert(first, child_id).is_none() {
+            // Only increase reference count if child does not replace another child.
+            self.incref(parent_id)
+                .expect("Failed to increase parent refcount");
+        }
+    }
+
+    /// Remove a node from the trie.
+    fn remove_node(&mut self, node_id: NodeId) -> TrieNode {
+        // Unwrap here, passing in an unknown id is a programming error.
+        let node = self.nodes.remove(node_id).expect("Unknown node");
+        let parent_id = node.parent.expect("Attempted to remove root node");
+        let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
+        parent.children.remove(&node.key[0]);
+        self.decref(parent_id)
+            .expect("Failed to decrease parent refcount");
+        self.nodes.remove(node_id);
+        node
+    }
+
+    fn update_access_time(&mut self, node_id: NodeId) {
+        // Unwrap here, passing in an unknown id is a programming error.
+        let node = self.nodes.get_mut(node_id).expect("Unknown node");
+
+        // Update the ordered leaves set if the node is a leaf.
+        if self.leaves.remove(&(node.last_accessed, node_id)) {
+            self.leaves.insert((self.time, node_id));
+        }
+
+        node.last_accessed = self.time;
+    }
+
+    #[allow(dead_code)]
+    #[doc(hidden)]
+    /// Print debugging output for the trie.
+    ///
+    /// In contrast to `Debug` nicely formatted.
+    pub fn print_debug(&self) {
+        self.print_debug_(self.root, 0);
+    }
+
+    fn print_debug_(&self, node_id: NodeId, indent: usize) {
+        let node = &self.nodes[node_id];
+        eprintln!(
+            "{}{:?}, key: {:?}, blocks: {:?}, ref_count: {}, last_accessed: {}, parent: {:?}, children: {:?}",
+            " ".repeat(indent),
+            node_id,
+            node.key,
+            node.blocks,
+            node.ref_count,
+            node.last_accessed,
+            node.parent,
+            node.children
+        );
+        for child_id in self.nodes[node_id].children.values() {
+            self.print_debug_(*child_id, indent + 2);
+        }
+    }
+
+    pub(crate) fn root_id(&self) -> DefaultKey {
+        self.root
+    }
+}
```
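A note on the split direction, which is easy to misread: the new node becomes the parent and takes the shared head of the key, while the existing node keeps its `NodeId` along with the tail. Outstanding `NodeId`s held by in-flight allocations as `prefix_node` therefore stay valid across splits. Before and after, in sketch form:

```rust
// Before split_node(n, 2), node n holds key [a, b, c, d]:
//
//   root ── n["a b c d"]
//
// After the split, a fresh parent p takes the shared head while n keeps its
// id with the remaining tail, so refcounts pinned on n stay meaningful:
//
//   root ── p["a b"] ── n["c d"]
```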
```diff
+/// Trie node.
+#[derive(Debug)]
+struct TrieNode {
+    blocks: Vec<u32>,
+    children: HashMap<u32, NodeId>,
+    key: Vec<u32>,
+    last_accessed: u64,
+    parent: Option<NodeId>,
+    ref_count: usize,
+}
+
+impl TrieNode {
+    fn new(key: Vec<u32>, blocks: Vec<u32>, last_accessed: u64, parent: Option<NodeId>) -> Self {
+        TrieNode {
+            children: HashMap::new(),
+            key,
+            blocks,
+            last_accessed,
+            parent,
+            ref_count: 0,
+        }
+    }
+}
+
+/// Helper trait to get the length of the shared prefix of two sequences.
+trait SharedPrefixLen {
+    fn shared_prefix_len(&self, other: &Self) -> usize;
+}
+
+impl<T> SharedPrefixLen for [T]
+where
+    T: PartialEq,
+{
+    fn shared_prefix_len(&self, other: &Self) -> usize {
+        self.iter().zip(other).take_while(|(a, b)| a == b).count()
+    }
+}
```
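`shared_prefix_len` is the comparison primitive behind both `find_` and `insert_`; it stops at the first mismatch and at the end of the shorter sequence, so each visited node costs O(min of the two lengths). Its contract in two asserts:

```rust
#[test]
fn shared_prefix_stops_at_first_mismatch() {
    assert_eq!([0u32, 1, 2, 3].shared_prefix_len(&[0, 1, 9]), 2);
    // Differing lengths are fine: zip ends at the shorter sequence.
    assert_eq!([0u32, 1].shared_prefix_len(&[0, 1, 2, 3]), 2);
}
```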
```diff
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use crate::block_allocator::Allocator;
+
+    use super::RadixAllocator;
+
+    #[test]
+    fn allocator_reuses_prefixes() {
+        let mut cache = RadixAllocator::new(1, 12, None);
+        let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
+        assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
+        assert_eq!(allocation.blocks, allocation.slots);
+        assert_eq!(allocation.prefix_len, 0);
+        cache.free(allocation.blocks.clone(), allocation.allocation_id);
+
+        let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
+        assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
+        assert_eq!(allocation.prefix_len, 4);
+    }
+
+    #[test]
+    fn allocator_collects_older_prefixes_first() {
+        let mut cache = RadixAllocator::new(1, 7, None);
+        let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
+        assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]);
+        assert_eq!(allocation1.prefix_len, 0);
+
+        let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap();
+        assert_eq!(allocation2.blocks, vec![1, 2]);
+        assert_eq!(allocation2.prefix_len, 0);
+
+        cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
+        cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
+
+        // We should get the blocks of the first allocation, since they are more recent.
+        let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap();
+        assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]);
+        assert_eq!(allocation3.prefix_len, 0);
+    }
+
+    #[test]
+    fn allocator_frees_fully_overlapping_prefills() {
+        let mut cache = RadixAllocator::new(1, 10, None);
+        let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
+        let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
+
+        cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
+        cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
+
+        let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
+        assert_eq!(allocation3.prefix_len, 4);
+
+        // 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks.
+        assert_eq!(cache.free_blocks.len(), 5);
+    }
+
+    #[test]
+    fn allocator_frees_partially_overlapping_prefills() {
+        let mut cache = RadixAllocator::new(1, 20, None);
+        let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap();
+        assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]);
+        assert_eq!(allocation1.prefix_len, 0);
+
+        cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
+
+        let allocation2 = cache
+            .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
+            .unwrap();
+        assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]);
+        assert_eq!(allocation2.prefix_len, 2);
+
+        let allocation3 = cache
+            .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
+            .unwrap();
+        assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]);
+        assert_eq!(allocation3.prefix_len, 2);
+
+        cache.free(allocation3.blocks.clone(), allocation3.allocation_id);
+        cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
+
+        // 20 blocks, of which 1 reserved for health checks, 6 for allocation3, 2 for allocation2.
+        assert_eq!(cache.free_blocks.len(), 11);
+
+        let allocation4 = cache
+            .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
+            .unwrap();
+        assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]);
+        assert_eq!(allocation4.prefix_len, 6);
+        assert_eq!(cache.free_blocks.len(), 11);
+
+        let allocation5 = cache
+            .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
+            .unwrap();
+        assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]);
+        assert_eq!(allocation5.prefix_len, 6);
+        assert_eq!(cache.free_blocks.len(), 11);
+    }
+
+    #[test]
+    fn trie_insertions_have_correct_prefix_len() {
+        let mut trie = super::RadixTrie::new();
+
+        assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0);
+
+        // Already exists.
+        assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 3);
+
+        // Completely new at root-level
+        assert_eq!(trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(), 0);
+
+        // Contains full prefix, but longer.
+        assert_eq!(trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(), 3);
+
+        // Shares partial prefix, we need a split.
+        assert_eq!(
+            trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
+                .unwrap(),
+            4
+        );
+    }
+
+    #[test]
+    fn trie_get_returns_correct_blocks() {
+        let mut trie = super::RadixTrie::new();
+        trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
+        trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
+        trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
+        trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
+            .unwrap();
+
+        let mut blocks = Vec::new();
+        trie.find(&[0], &mut blocks);
+        assert_eq!(blocks, vec![0]);
+
+        blocks.clear();
+        trie.find(&[0, 1, 2], &mut blocks);
+        assert_eq!(blocks, vec![0, 1, 2]);
+
+        blocks.clear();
+        trie.find(&[1, 2, 3], &mut blocks);
+        assert_eq!(blocks, vec![1, 2, 3]);
+
+        blocks.clear();
+        trie.find(&[0, 1, 2, 3], &mut blocks);
+        assert_eq!(blocks, vec![0, 1, 2, 3]);
+
+        blocks.clear();
+        trie.find(&[0, 1, 2, 3, 4], &mut blocks);
+        assert_eq!(blocks, vec![0, 1, 2, 3, 4]);
+
+        blocks.clear();
+        trie.find(&[0, 1, 2, 3, 5], &mut blocks);
+        assert_eq!(blocks, vec![0, 1, 2, 3, 5]);
+    }
+
+    #[test]
+    fn trie_evict_removes_correct_blocks() {
+        let mut trie = super::RadixTrie::new();
+        trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
+        trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
+            .unwrap();
+        trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
+        trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
+
+        let mut blocks = Vec::new();
+
+        // Remove less than the leaf blocks.
+        assert_eq!(trie.evict(1), vec![7]);
+        trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
+        assert_eq!(blocks, vec![0, 1, 2, 3, 5, 6]);
+
+        // Refresh other leaf.
+        trie.find(&[0, 1, 2, 3, 4], &mut blocks);
+        trie.find(&[1, 2, 3], &mut blocks);
+
+        // Remove the leaf blocks exactly.
+        assert_eq!(trie.evict(2), vec![5, 6]);
+        blocks.clear();
+        trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
+        assert_eq!(blocks, vec![0, 1, 2, 3]);
+
+        trie.find(&[1, 2, 3], &mut blocks);
+
+        // Remove more than the leaf blocks.
+        assert_eq!(trie.evict(3), vec![4, 3, 2]);
+        blocks.clear();
+        trie.find(&[0, 1, 2, 3, 4], &mut blocks);
+        assert_eq!(blocks, vec![0, 1]);
+
+        // Clear out the whole trie.
+        assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]);
+    }
+}
```
```diff
@@ -157,6 +157,7 @@ async fn prefill(
             top_n_tokens: top_n_tokens.unwrap_or(0),
             blocks: vec![],
             slots: vec![],
+            prefix_len: 0,
             adapter_id: None,
         })
         .collect();
```
```diff
@@ -4,21 +4,22 @@ package generate.v3;

 service TextGenerationService {
     /// Model Info
-    rpc Info (InfoRequest) returns (InfoResponse) {}
+    rpc Info(InfoRequest) returns (InfoResponse) {}
     /// Service discovery
-    rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
+    rpc ServiceDiscovery(ServiceDiscoveryRequest)
+        returns (ServiceDiscoveryResponse) {}
     /// Empties batch cache
-    rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
+    rpc ClearCache(ClearCacheRequest) returns (ClearCacheResponse);
     /// Remove requests from a cached batch
-    rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
+    rpc FilterBatch(FilterBatchRequest) returns (FilterBatchResponse);
     /// Warmup the model and compute max cache size
-    rpc Warmup (WarmupRequest) returns (WarmupResponse);
+    rpc Warmup(WarmupRequest) returns (WarmupResponse);
     /// Prefill batch and decode first token
-    rpc Prefill (PrefillRequest) returns (PrefillResponse);
+    rpc Prefill(PrefillRequest) returns (PrefillResponse);
     /// Decode token for a list of prefilled batches
-    rpc Decode (DecodeRequest) returns (DecodeResponse);
+    rpc Decode(DecodeRequest) returns (DecodeResponse);
     /// Health check
-    rpc Health (HealthRequest) returns (HealthResponse);
+    rpc Health(HealthRequest) returns (HealthResponse);
 }

 message HealthRequest {}
```
@@ -68,9 +69,7 @@ message InputChunk {
     }
 }

-message Input {
-    repeated InputChunk chunks = 1;
-}
+message Input { repeated InputChunk chunks = 1; }

 enum GrammarType {
     GRAMMAR_TYPE_NONE = 0;
@@ -136,6 +135,8 @@ message Request {
     repeated uint32 slots = 10;
     /// LORA adapter index
     optional string adapter_id = 11;
+    /// Prefix length that can be retrieved from the KV cache.
+    uint32 prefix_len = 12;
 }

 message Batch {
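For intuition, `prefix_len` counts the leading prompt tokens whose key/value entries already sit in the cache; with block sizes larger than one, the usable prefix is rounded down to a block boundary so whole KV blocks can be reused. A hypothetical helper (not part of this codebase) that computes it against a single cached sequence:

def shared_prefix_len(prompt, cached, block_size=1):
    # Longest common prefix between the new prompt and a cached sequence,
    # rounded down to a block boundary. Illustrative only.
    n = 0
    for a, b in zip(prompt, cached):
        if a != b:
            break
        n += 1
    return (n // block_size) * block_size

assert shared_prefix_len([1, 2, 3, 4], [1, 2, 3, 9]) == 3
assert shared_prefix_len([1, 2, 3, 4], [1, 2, 3, 9], block_size=2) == 2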
@@ -214,7 +215,6 @@ message FilterBatchResponse {
     CachedBatch batch = 1;
 }

 message PrefillRequest {
     /// Batch
     Batch batch = 1;
@@ -11,6 +11,7 @@ use rand::{thread_rng, Rng};
 use serde_json::Value;
 use std::io::Cursor;
 use std::iter;
+use std::sync::Arc;
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 use tokio::sync::mpsc;
@@ -115,13 +116,14 @@ impl Validation {
         }
     }

+    #[allow(clippy::type_complexity)]
     #[instrument(skip(self, inputs))]
     async fn validate_input(
         &self,
         inputs: String,
         truncate: Option<usize>,
         max_new_tokens: Option<u32>,
-    ) -> Result<(Vec<Chunk>, usize, u32), ValidationError> {
+    ) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
         // If we have a fast tokenizer
         if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
             // Create response channel
@@ -156,8 +158,10 @@ impl Validation {
                 ));
             }

+            let input_ids = encoding.get_ids()[..input_length].to_owned();
+
             metrics::histogram!("tgi_request_input_length").record(input_length as f64);
-            Ok((inputs, input_length, max_new_tokens))
+            Ok((inputs, Some(input_ids), input_length, max_new_tokens))
         }
         // Return inputs without validation
         else {
@@ -180,7 +184,12 @@ impl Validation {
                 input_length = input_length.saturating_sub(max_new_tokens as usize);
             }

-            Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens))
+            Ok((
+                vec![Chunk::Text(inputs)],
+                None,
+                input_length,
+                max_new_tokens,
+            ))
         }
     }
@@ -314,7 +323,7 @@ impl Validation {
             .unwrap_or(Ok(None))?;

         // Validate inputs
-        let (inputs, input_length, max_new_tokens) = self
+        let (inputs, input_ids, input_length, max_new_tokens) = self
             .validate_input(request.inputs, truncate, max_new_tokens)
             .await?;
@@ -391,6 +400,7 @@ impl Validation {

         Ok(ValidGenerateRequest {
             inputs,
+            input_ids: input_ids.map(Arc::new),
             decoder_input_details,
             input_length: input_length as u32,
             truncate: truncate.unwrap_or(self.max_input_length) as u32,
@@ -707,6 +717,7 @@ pub struct ValidStoppingParameters {
 #[derive(Debug, Clone)]
 pub struct ValidGenerateRequest {
     pub inputs: Vec<Chunk>,
+    pub input_ids: Option<Arc<Vec<u32>>>,
     pub input_length: u32,
     pub truncate: u32,
     pub decoder_input_details: bool,
@@ -6,7 +6,12 @@ from .common import Seqlen
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
 if SYSTEM == "cuda":
-    from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
+    from .cuda import (
+        attention,
+        paged_attention,
+        reshape_and_cache,
+        SUPPORTS_WINDOWING,
+    )
 elif SYSTEM == "rocm":
     from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
 elif SYSTEM == "ipex":
@@ -221,9 +221,11 @@ SUPPORTS_WINDOWING = V2
 if ATTENTION == "flashinfer":

     def attention(
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
         cu_seqlens,
         max_s,
         softmax_scale,
@@ -231,14 +233,15 @@ if ATTENTION == "flashinfer":
         causal=True,
         softcap=0.0,
     ):
-        from text_generation_server.layers.attention.flash_infer import prefill_state
+        assert window_size_left == -1, "Windowing is not supported with flash infer"
+        from text_generation_server.layers.attention.flash_infer import (
+            prefill_with_paged_kv_state,
+        )

-        return prefill_state.get().forward(
-            q,
-            k,
-            v,
+        return prefill_with_paged_kv_state.get().forward(
+            q.contiguous(),
             causal=causal,
-            window_left=window_size_left,
+            paged_kv_cache=(key_cache, value_cache),
             logits_soft_cap=softcap,
             sm_scale=softmax_scale,
         )
@@ -249,6 +252,8 @@ elif V2:
         q,
         k,
         v,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
         cu_seqlens,
         max_s,
         softmax_scale,

@@ -289,6 +294,8 @@ else:
         q,
         k,
         v,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
         cu_seqlens,
         max_s,
         softmax_scale,
@@ -9,6 +9,10 @@ prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = Con
     "prefill_state"
 )

+prefill_with_paged_kv_state: ContextVar[
+    flashinfer.BatchPrefillWithPagedKVCacheWrapper
+] = ContextVar("prefill_with_paged_kv_state")
+
 decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
     "decode_state"
 )
@@ -24,6 +28,78 @@ def get_workspace(device):
     return workspace


+def create_prefill_with_paged_kv_state(
+    *,
+    device: torch.device,
+):
+    """Create a prefill state that uses the KV cache."""
+    workspace_buffer = get_workspace(device)
+    return flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+        workspace_buffer, kv_layout="NHD", use_cuda_graph=False
+    )
+
+
+@contextmanager
+def use_prefill_with_paged_kv_state(
+    *,
+    state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
+    block_tables: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+    input_lengths: torch.Tensor,
+    num_heads: int,
+    num_kv_heads: int,
+    head_size: int,
+    page_size: int,
+    query_dtype: str = "float16",
+):
+    """
+    Context manager to set the active flashinfer prefill state to the given
+    `state` and parameters. This state will be used by all calls to the
+    `attention` function while the context manager is active.
+    """
+
+    indptr = torch.zeros(
+        input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
+    )
+    # Round up to page size and then calculate the cumulative sum to get
+    # the indices into the block table.
+    torch.add(input_lengths, page_size - 1, out=indptr[1:])
+    indptr[1:].div_(page_size, rounding_mode="floor")
+    indptr[1:].cumsum_(-1)
+
+    # Get the lengths of the last page in a block.
+    if page_size == 1:
+        last_page_len = torch.ones(
+            input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
+        )
+    else:
+        last_page_len = torch.empty(
+            input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
+        )
+        torch.sub(input_lengths, 1, out=last_page_len)
+        last_page_len.remainder_(page_size)
+        last_page_len += 1
+
+    token = prefill_with_paged_kv_state.set(state)
+    try:
+        state.begin_forward(
+            qo_indptr=cu_seqlens,
+            paged_kv_indptr=indptr,
+            paged_kv_indices=block_tables,
+            paged_kv_last_page_len=last_page_len,
+            num_qo_heads=num_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim=head_size,
+            q_data_type=query_dtype,
+            page_size=page_size,
+        )
+        yield
+    finally:
+        state.end_forward()
+        if token is not None:
+            prefill_with_paged_kv_state.reset(token)
+
+
 def create_prefill_state(
     *,
     device: torch.device,
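A small worked example of the `indptr` / `last_page_len` computation above, runnable with plain PyTorch (the lengths are illustrative, not taken from the repo):

import torch

input_lengths = torch.tensor([5, 8], dtype=torch.int32)
page_size = 4

# Each sequence needs ceil(len / page_size) pages; the cumulative sum
# gives offsets into the flattened page (block) table.
indptr = torch.zeros(input_lengths.shape[0] + 1, dtype=torch.int32)
torch.add(input_lengths, page_size - 1, out=indptr[1:])
indptr[1:].div_(page_size, rounding_mode="floor")
indptr[1:].cumsum_(-1)
assert indptr.tolist() == [0, 2, 4]  # seq 0 uses pages [0, 2), seq 1 uses [2, 4)

# The last page of each sequence holds ((len - 1) % page_size) + 1 tokens.
last_page_len = torch.empty_like(input_lengths)
torch.sub(input_lengths, 1, out=last_page_len)
last_page_len.remainder_(page_size)
last_page_len += 1
assert last_page_len.tolist() == [1, 4]  # 5 = 4 + 1, 8 = 4 + 4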
@@ -298,6 +298,8 @@ class FlashCohereAttention(torch.nn.Module):
                 query,
                 key,
                 value,
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -337,6 +337,8 @@ class DbrxAttention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -365,6 +365,8 @@ class DeepseekV2Attention(torch.nn.Module):
                 query,
                 key,
                 value,
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -238,6 +238,8 @@ class FlashGemma2Attention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -232,6 +232,8 @@ class FlashGemmaAttention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -232,6 +232,8 @@ class FlashGPT2Attention(torch.nn.Module):
                 query,
                 key,
                 value,
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -21,6 +21,7 @@
 from contextlib import contextmanager
 from typing import List, Optional, Tuple

+from loguru import logger
 import torch
 import torch.distributed

@@ -220,6 +221,8 @@ class FlashLlamaAttention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -219,6 +219,8 @@ class MistralAttention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -276,6 +276,8 @@ class MixtralAttention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -173,6 +173,8 @@ class FlashNeoxAttention(torch.nn.Module):
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -194,6 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -137,6 +137,8 @@ class Qwen2Attention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -208,6 +208,8 @@ class FlashRWAttention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -326,6 +328,8 @@ class FlashRWLargeAttention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=2, index=0),
                 torch.select(kv, dim=2, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -293,6 +293,8 @@ class FlashMQAttention(torch.nn.Module):
                 query,
                 torch.select(key_value, dim=1, index=0),
                 torch.select(key_value, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -242,6 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -20,6 +20,9 @@ from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, D

 from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
+from text_generation_server.layers.attention.flash_infer import (
+    create_prefill_with_paged_kv_state,
+)
 from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models import Model
@@ -43,6 +46,7 @@ from text_generation_server.models.globals import (
     ATTENTION,
     BLOCK_SIZE,
     CUDA_GRAPHS,
+    PREFIX_CACHING,
     get_adapter_to_index,
 )
 from text_generation_server.layers.attention import Seqlen
@@ -108,6 +112,9 @@ class FlashCausalLMBatch(Batch):
     block_tables_tensor: torch.Tensor
     # tensor of length \sum_{i=0}^{b} max_s_i  holding the paged attention slots for all sequences
     slots: torch.Tensor
+    # size [b], containing the number of blocks that can be retrieved from the cache
+    prefix_lens: List[int]
+    prefix_lens_tensor: torch.Tensor

     max_seqlen: int
@@ -116,6 +123,9 @@ class FlashCausalLMBatch(Batch):
     prefill_next_token_indices: Optional[torch.tensor]
     prefill_cu_outlens: Optional[List[int]]

+    # Prefixes
+    prefix_ids: List[List[int]]
+
     # All tokens
     all_input_ids: List[List[int]]
     all_input_ids_tensor: torch.Tensor
@@ -183,6 +193,7 @@ class FlashCausalLMBatch(Batch):
         prefix_offsets = []
         read_offsets = []
         all_input_ids = []
+        prefix_ids = []
         requests_idx_mapping = {}

         all_prefill_logprobs = True
@@ -200,7 +211,7 @@ class FlashCausalLMBatch(Batch):

         # Cumulative length
         cumulative_length = 0
-        cumulative_max_length = 0
+        cumulative_slot_tokens = 0
         prefill_out_cumulative_length = 0

         num_blocks = 0
@@ -210,6 +221,7 @@ class FlashCausalLMBatch(Batch):

         block_tables = []
         slots = []
+        prefix_lens = []

         # Parse batch
         for i, (r, tokenized_input) in enumerate(
@@ -225,6 +237,19 @@ class FlashCausalLMBatch(Batch):
             ):
                 tokenized_input = tokenized_input[1:]

+            orig_input_length = len(tokenized_input)
+
+            if PREFIX_CACHING:
+                prefix_len = r.prefix_len
+                if prefix_len == orig_input_length:
+                    assert prefix_len > 0
+                    prefix_len -= 1
+            else:
+                prefix_len = 0
+
+            prefix_ids.append(tokenized_input[:prefix_len])
+            tokenized_input = tokenized_input[prefix_len:]
+
             input_length = len(tokenized_input)
             input_lengths.append(input_length)
@@ -234,7 +259,9 @@ class FlashCausalLMBatch(Batch):
             all_input_ids.append(tokenized_input)

             # Position ids
-            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
+            request_position_ids = torch.arange(
+                prefix_len, orig_input_length, dtype=torch.int32
+            )
             position_ids.append(request_position_ids)

             # Add cumulative lengths of all previous inputs
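Worked example of the two slices above (illustrative numbers): a six-token prompt with four cached tokens keeps only the suffix for prefill, and its position ids start at the prefix length so the suffix is encoded at its true positions.

import torch

tokenized_input = [101, 102, 103, 104, 105, 106]
prefix_len = 4
orig_input_length = len(tokenized_input)

prefix_ids = tokenized_input[:prefix_len]       # [101, 102, 103, 104] served from cache
tokenized_input = tokenized_input[prefix_len:]  # [105, 106] still needs prefill

request_position_ids = torch.arange(prefix_len, orig_input_length, dtype=torch.int32)
assert request_position_ids.tolist() == [4, 5]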
@@ -258,11 +285,17 @@ class FlashCausalLMBatch(Batch):
             # Remove one as the first token does not have a past
             speculative_length = get_speculate()
             speculative_length = 0 if speculative_length is None else speculative_length
-            total_tokens = input_length + max_new_tokens - 1 + speculative_length
+
+            # Tokens that need to be mapped to blocks.
+            block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length
+
+            # Tokens that need to be mapped to slots. We don't need slots for the
+            # cached prefix (if present).
+            slot_tokens = input_length + max_new_tokens - 1 + speculative_length

             # blocks and slots can be empty (for example in warmup)
             if not r.blocks:
-                needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
+                needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)
                 request_blocks = [
                     b for b in range(num_blocks, num_blocks + needed_blocks)
                 ]
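The split between the two counts matters: blocks must cover the whole sequence, because the cached prefix still occupies its pages, while slots are only written for tokens processed in this batch. Illustrative arithmetic, assuming BLOCK_SIZE = 16:

import math

orig_input_length = 10  # full prompt, including the cached prefix
prefix_len = 4          # tokens already in the KV cache
input_length = orig_input_length - prefix_len  # 6 tokens left to prefill
max_new_tokens, speculative_length = 5, 0
BLOCK_SIZE = 16

block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length  # 14
slot_tokens = input_length + max_new_tokens - 1 + speculative_length        # 10
needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)                        # 1

assert (block_tokens, slot_tokens, needed_blocks) == (14, 10, 1)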
@@ -273,16 +306,20 @@ class FlashCausalLMBatch(Batch):
                 ]
             else:
                 request_blocks = r.blocks
-                request_slots = r.slots
+                request_slots = r.slots[
+                    prefix_len:  #: orig_input_length + max_new_tokens + speculative_length
+                ]

             block_tables.append(request_blocks)
-            slots.extend(request_slots[:total_tokens])
+            slots.extend(request_slots)
+            prefix_lens.append(prefix_len)
             num_blocks += len(request_blocks)
-            start_slots.append(cumulative_max_length)
+            start_slots.append(cumulative_slot_tokens)

             request_slot_indices = torch.arange(
-                cumulative_max_length,
-                cumulative_max_length + input_length,
+                cumulative_slot_tokens,
+                cumulative_slot_tokens + input_length,
                 dtype=torch.int64,
             )
             slot_indices.append(request_slot_indices)
@@ -318,7 +355,7 @@ class FlashCausalLMBatch(Batch):

             # Update
             cumulative_length += input_length
-            cumulative_max_length += total_tokens
+            cumulative_slot_tokens += slot_tokens
             max_seqlen = max(max_seqlen, input_length)
             max_blocks = max(max_blocks, len(request_blocks))
             max_length = max(
@@ -395,12 +432,14 @@ class FlashCausalLMBatch(Batch):
         )

         slots = torch.tensor(slots, dtype=torch.int64, device=device)

         block_tables_tensor = torch.zeros(
             (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
         )
         for i, request_blocks in enumerate(block_tables):
             block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
         block_tables_tensor = block_tables_tensor.to(device)
+        prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device)

         return cls(
             batch_id=pb.id,
@@ -415,6 +454,8 @@ class FlashCausalLMBatch(Batch):
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
             slots=slots,
+            prefix_lens=prefix_lens,
+            prefix_lens_tensor=prefix_lens_tensor,
             max_seqlen=max_seqlen,
             prefill_head_indices=prefill_head_indices,
             prefill_next_token_indices=prefill_next_token_indices,
@@ -425,6 +466,7 @@ class FlashCausalLMBatch(Batch):
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
+            prefix_ids=prefix_ids,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
@@ -480,6 +522,7 @@ class FlashCausalLMBatch(Batch):
         start_slots = []
         block_tables = []
         all_input_ids = []
+        prefix_ids = []

         input_lengths = []
         prefix_offsets = []
@@ -506,6 +549,7 @@ class FlashCausalLMBatch(Batch):
             max_seqlen = max(max_seqlen, request_input_length)

             all_input_ids.append(self.all_input_ids[idx])
+            prefix_ids.append(self.prefix_ids[idx])

             input_lengths.append(request_input_length)
             prefix_offsets.append(self.prefix_offsets[idx])
@@ -591,6 +635,7 @@ class FlashCausalLMBatch(Batch):
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
+            prefix_ids=prefix_ids,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
@@ -651,6 +696,7 @@ class FlashCausalLMBatch(Batch):
         block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
             (total_batch_size, max_blocks)
         )
+        prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size)
         all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
             (total_batch_size, max_length)
         )
@@ -668,7 +714,9 @@ class FlashCausalLMBatch(Batch):

         start_slots = []
         block_tables = []
+        prefix_lens = []
         all_input_ids = []
+        prefix_ids = []

         input_lengths = []
         prefix_offsets = []
@@ -730,10 +778,14 @@ class FlashCausalLMBatch(Batch):
                 start_index:end_index, : batch.block_tables_tensor.shape[1]
             ] = batch.block_tables_tensor[:, :max_blocks]

+            prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor
+
             start_slots.append(batch.start_slots + cumulative_slots)

             block_tables.extend(batch.block_tables)
+            prefix_lens.extend(batch.prefix_lens)
             all_input_ids.extend(batch.all_input_ids)
+            prefix_ids.extend(batch.prefix_ids)

             input_lengths.extend(batch.input_lengths)
             prefix_offsets.extend(batch.prefix_offsets)
@@ -779,6 +831,8 @@ class FlashCausalLMBatch(Batch):
             slot_indices=slot_indices,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
+            prefix_lens=prefix_lens,
+            prefix_lens_tensor=prefix_lens_tensor,
             slots=slots,
             max_seqlen=max_seqlen,
             prefill_head_indices=None,
@@ -790,6 +844,7 @@ class FlashCausalLMBatch(Batch):
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
+            prefix_ids=prefix_ids,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
@@ -944,6 +999,9 @@ class FlashCausalLM(Model):
         )

         self.prefill_state = create_prefill_state(device=device)
+        self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
+            device=device
+        )

         if not CUDA_GRAPHS:
             self.decode_state = create_decode_state(
@@ -1042,11 +1100,22 @@ class FlashCausalLM(Model):
         input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
         slots = torch.arange(bs, dtype=torch.int64, device=self.device)
-        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
-        block_tables = (
-            torch.arange(max_bt, dtype=torch.int32, device=self.device)
-            .repeat(bs)
-            .reshape((bs, max_bt))
-        )
+        input_lengths = [max_s] * bs
+        prefix_lengths = [0] * bs
+        input_lengths_tensor = (
+            torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
+        )
+        prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
+        block_tables = torch.arange(
+            max_bt, dtype=torch.int32, device=self.device
+        ).repeat(bs)
+        block_tables = block_tables.reshape((bs, max_bt))
+
+        if ATTENTION == "flashinfer":
+            block_tables = block_tables_to_ragged(
+                block_tables=block_tables,
+                input_lengths=input_lengths,
+                prefix_lens=prefix_lengths,
+            )

         self.cuda_graphs[bs] = {
@@ -1055,9 +1124,9 @@ class FlashCausalLM(Model):
             "kv_cache": self.kv_cache,
             "block_tables": block_tables,
             "slots": slots,
-            "input_lengths": input_lengths,
+            "input_lengths": input_lengths_tensor,
         }
-        input_lengths_ = Seqlen(input_lengths=input_lengths)
+        input_lengths_ = Seqlen(input_lengths=input_lengths_tensor)
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph
@@ -1072,7 +1141,7 @@ class FlashCausalLM(Model):
             last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
             state = create_decode_state_cuda_graphs(
                 device=input_ids.device,
-                block_tables=block_tables.view(-1),
+                block_tables=block_tables,
                 block_tables_ptr=block_tables_ptr,
                 last_page_len=last_page_len,
                 num_heads=self.num_heads,
@@ -1088,7 +1157,10 @@ class FlashCausalLM(Model):
             block_tables=block_tables,
             cu_seqlen_prefill=None,
             input_lengths=input_lengths,
+            input_lengths_tensor=input_lengths_tensor,
             state=state,
+            prefix_lens=prefix_lengths,
+            prefix_lens_tensor=prefix_lengths_tensor,
         ):
             self.model.forward(
                 input_ids=input_ids,
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
with torch.cuda.graph(graph, pool=MEM_POOL):
|
with torch.cuda.graph(graph, pool=MEM_POOL):
|
||||||
input_lengths = Seqlen(input_lengths=input_lengths)
|
input_lengths_tensor = Seqlen(input_lengths=input_lengths_tensor)
|
||||||
logits, speculative_logits = self.model.forward(
|
logits, speculative_logits = self.model.forward(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
|
@ -1114,7 +1186,7 @@ class FlashCausalLM(Model):
|
||||||
kv_cache=self.kv_cache,
|
kv_cache=self.kv_cache,
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
slots=slots,
|
slots=slots,
|
||||||
input_lengths=input_lengths,
|
input_lengths=input_lengths_tensor,
|
||||||
max_s=max_s,
|
max_s=max_s,
|
||||||
prefill_cache_indices=None,
|
prefill_cache_indices=None,
|
||||||
lm_head_indices=None,
|
lm_head_indices=None,
|
||||||
|
@@ -1343,7 +1415,10 @@ class FlashCausalLM(Model):
         with self._forward_context(
             block_tables=block_tables,
             cu_seqlen_prefill=cu_seqlen_prefill,
-            input_lengths=input_lengths,
+            input_lengths=batch.input_lengths,
+            input_lengths_tensor=input_lengths,
+            prefix_lens=batch.prefix_lens,
+            prefix_lens_tensor=batch.prefix_lens_tensor,
         ):
             input_lengths = Seqlen(input_lengths=input_lengths)
             logits, speculative_logits = self.model.forward(
@@ -1367,19 +1442,32 @@ class FlashCausalLM(Model):
         # Static inputs are potentially padded
         cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
         cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        if ATTENTION == "flashinfer":
+            block_tables = block_tables_to_ragged(
+                block_tables=block_tables,
+                input_lengths=batch.input_lengths,
+                prefix_lens=batch.prefix_lens,
+            )
+            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
+        else:
             cuda_graph["block_tables"][
                 : block_tables.shape[0], : block_tables.shape[1]
             ] = block_tables
         cuda_graph["slots"].fill_(-1)
         cuda_graph["slots"][: slots.shape[0]] = slots
         cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
+            input_lengths + batch.prefix_lens_tensor
+        )

         state = cuda_graph.get("state")
         with self._forward_context(
             block_tables=block_tables,
             cu_seqlen_prefill=None,
-            input_lengths=input_lengths,
+            input_lengths=batch.input_lengths,
+            input_lengths_tensor=input_lengths,
+            prefix_lens=batch.prefix_lens,
+            prefix_lens_tensor=batch.prefix_lens_tensor,
             state=state,
         ):
             # Replay the graph
@@ -1578,6 +1666,7 @@ class FlashCausalLM(Model):
             batch.read_offsets,
             batch.stopping_criterias,
             batch.all_input_ids,
+            batch.prefix_ids,
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
             batch.top_n_tokens,
@@ -1595,6 +1684,7 @@ class FlashCausalLM(Model):
             read_offset,
             stopping_criteria,
             all_input_ids,
+            prefix_ids,
             do_sample,
             seed,
             top_n_tokens,
@@ -1669,18 +1759,18 @@ class FlashCausalLM(Model):
                 out_end_index = batch.prefill_cu_outlens[i + 1]

                 # Remove generated token to only have prefill and add nan for first prompt token
-                request_prefill_logprobs = [float("nan")] + prefill_logprobs[
-                    out_start_index : out_end_index - 1
-                ]
+                request_prefill_logprobs = (
+                    [float("nan")] * (len(prefix_ids) + 1)
+                ) + prefill_logprobs[out_start_index : out_end_index - 1]
                 prefill_token_ids = all_input_ids[:-1]
                 prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
+                    prefix_ids + prefill_token_ids,
                     clean_up_tokenization_spaces=False,
                     skip_special_tokens=False,
                 )

                 prefill_tokens = Tokens(
-                    prefill_token_ids,
+                    prefix_ids + prefill_token_ids,
                     request_prefill_logprobs,
                     prefill_texts,
                     is_special=[],
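Shape check for the padding above, with toy values: tokens served from the cache were never prefilled in this pass, so they are reported with NaN logprobs, plus the usual NaN for the first prompt token, which has no past.

prefix_ids = [101, 102]                # cached prefix tokens (toy values)
prefill_logprobs = [-0.5, -1.2, -0.1]  # logprobs for the prefilled suffix (toy values)

request_prefill_logprobs = (
    [float("nan")] * (len(prefix_ids) + 1)
) + prefill_logprobs[0:2]  # drop the generated token, as in the slice above

assert len(request_prefill_logprobs) == len(prefix_ids) + 1 + 2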
@@ -1762,7 +1852,10 @@ class FlashCausalLM(Model):
         *,
         block_tables: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
-        input_lengths: torch.Tensor,
+        input_lengths: List[int],
+        input_lengths_tensor: torch.Tensor,
+        prefix_lens: List[int],
+        prefix_lens_tensor: torch.Tensor,
         state: Optional[Any] = None,
     ) -> ContextManager:
         if ATTENTION != "flashinfer":
@@ -1771,9 +1864,30 @@ class FlashCausalLM(Model):
         from text_generation_server.layers.attention.flash_infer import (
             use_decode_state,
             use_prefill_state,
+            use_prefill_with_paged_kv_state,
         )

+        # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens)
+
         if cu_seqlen_prefill is not None:
+            if True:  # has_prefix_lens:
+                return use_prefill_with_paged_kv_state(
+                    state=(
+                        state if state is not None else self.prefill_with_paged_kv_state
+                    ),
+                    block_tables=block_tables_to_ragged(
+                        block_tables=block_tables,
+                        input_lengths=input_lengths,
+                        prefix_lens=prefix_lens,
+                    ),
+                    cu_seqlens=cu_seqlen_prefill,
+                    input_lengths=input_lengths_tensor + prefix_lens_tensor,
+                    num_heads=self.num_heads,
+                    num_kv_heads=self.num_kv_heads,
+                    head_size=self.head_size,
+                    page_size=BLOCK_SIZE,
+                )
+            else:
                 return use_prefill_state(
                     state=state if state is not None else self.prefill_state,
                     cu_seqlens=cu_seqlen_prefill,
@@ -1782,13 +1896,33 @@ class FlashCausalLM(Model):
                     head_size=self.head_size,
                 )
         else:
-            assert input_lengths is not None
+            assert input_lengths_tensor is not None
            return use_decode_state(
                 state=state if state is not None else self.decode_state,
-                input_lengths=input_lengths,
-                block_tables=block_tables.view(-1),
+                input_lengths=input_lengths_tensor + prefix_lens_tensor,
+                block_tables=block_tables,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_kv_heads,
                 head_size=self.head_size,
                 page_size=BLOCK_SIZE,
             )
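Note that on both the prefill and decode paths flashinfer is handed `input_lengths_tensor + prefix_lens_tensor`: it must size and index pages for the total sequence, cached prefix included, even though only the uncached tokens run through the model. Toy illustration:

import torch

input_lengths_tensor = torch.tensor([2, 6], dtype=torch.int32)  # uncached tokens
prefix_lens_tensor = torch.tensor([4, 0], dtype=torch.int32)    # cached tokens

total_lengths = input_lengths_tensor + prefix_lens_tensor
assert total_lengths.tolist() == [6, 6]  # what flashinfer sees per sequence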
+def block_tables_to_ragged(
+    *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int]
+) -> torch.Tensor:
+    """Convert block table to ragged format compatible with FlashInfer."""
+    assert len(input_lengths) == len(prefix_lens)
+
+    total_len = sum(input_lengths) + sum(prefix_lens)
+    block_tables_ragged = torch.empty(
+        total_len, dtype=torch.int32, device=block_tables.device
+    )
+
+    offset = 0
+    for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)):
+        seq_len = prefix_len + input_length
+        block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
+        offset += seq_len
+
+    return block_tables_ragged
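Usage sketch for block_tables_to_ragged with toy values; since flashinfer runs with a page size of 1 here, each token maps to one block, so each sequence contributes prefix_len + input_length entries:

import torch

# Assumed import path, per the hunks above (the function lives in flash_causal_lm.py).
from text_generation_server.models.flash_causal_lm import block_tables_to_ragged

block_tables = torch.tensor(
    [[10, 11, 12, 0],  # sequence 0, padded [bs, max_blocks] layout
     [20, 21, 0, 0]],  # sequence 1
    dtype=torch.int32,
)
ragged = block_tables_to_ragged(
    block_tables=block_tables,
    input_lengths=[2, 1],  # uncached tokens per sequence
    prefix_lens=[1, 0],    # cached tokens per sequence
)
assert ragged.tolist() == [10, 11, 12, 20]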
@@ -5,17 +5,29 @@ from typing import Dict, Optional

 from text_generation_server.utils.log import log_master

-ATTENTION = os.getenv("ATTENTION", "paged")
+PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "false").lower() in {"1", "true"}
+log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
+
+ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged")
 _expected = {"paged", "flashdecoding", "flashinfer"}
 assert (
     ATTENTION in _expected
 ), f"Attention is not valid {ATTENTION}, expected {_expected}"
 log_master(logger.info, f"Using Attention = {ATTENTION}")

+if PREFIX_CACHING and ATTENTION != "flashinfer":
+    raise RuntimeError("Prefix caching is only supported with flashinfer")
+
 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 # This is overridden by the cli
-BLOCK_SIZE: int = 256 if ATTENTION == "flashdecoding" else 16
+BLOCK_SIZE: int
+if ATTENTION == "flashdecoding":
+    BLOCK_SIZE = 256
+elif ATTENTION == "flashinfer":
+    BLOCK_SIZE = 1
+else:
+    BLOCK_SIZE = 16

 cuda_graphs = os.getenv("CUDA_GRAPHS")
 if cuda_graphs is not None:
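The module-level branching above is easier to audit as a pure function. This is a sketch mirroring it, not code from the repo:

def resolve_attention_config(env: dict):
    """Return (attention, block_size, prefix_caching) for a given environment.
    Sketch only; mirrors the module-level logic above."""
    prefix_caching = str(env.get("USE_PREFIX_CACHING", "")).lower() in {"1", "true"}
    attention = env.get("ATTENTION", "flashinfer" if prefix_caching else "paged")
    if prefix_caching and attention != "flashinfer":
        raise RuntimeError("Prefix caching is only supported with flashinfer")
    block_size = {"flashdecoding": 256, "flashinfer": 1}.get(attention, 16)
    return attention, block_size, prefix_caching

assert resolve_attention_config({}) == ("paged", 16, False)
assert resolve_attention_config({"USE_PREFIX_CACHING": "1"}) == ("flashinfer", 1, True)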