Lots of improvements (Still 2 allocators) (#2449)

* Making prefix/flashinfer the default and testing the full release tests.

* Include flashinfer in the docker.

* Using prebuilt.

* Allowing window_left_size (dummy version).

* Disabling flashinfer/prefix caching on odd head_dim

* Disable prefix caching for lora.

* More specific codes.

* Update lock

* Updating integration tests with new values with FI/FD.

Remove paged as a default too, and using FD everywhere.

* Update cargo lock ?

* Upgrade to 1.80 because of bitstream...

* Everywhere 1.80

* Forgot last default place.

* Apply suggestions from code review

Co-authored-by: drbh <david.richard.holtz@gmail.com>

* Updated flake lock

* Tmp

* Upgrade resolution system for less errors in resolution.

* Remove lambda for cleaner function.

* Handling debugger.

* OVerride the env in server tests.

* Is this enough to make it work ?

* This seems to be working.

* Downgrade some logs.

* Fixing the default for vlm.

* Don't enable prefix caching on VLM just yet.

* Change `add_special_tokens` in order to have the correct tokens for chat
input and not (since it's super important with the prefixing now)

* Fixing prefix caching for flashdecoding.

* Update all models.

* Fixed flashinfer version.

* add_special_tokens is internal only

* Fixing seqlen with the new vlms.

* Fixing the issue with `add_special_tokens` not being passed around.

* Fixing the test.

* Removing encoder_decoder (seq2seq).

* Update the chat test.

* Fixing the batching tokenization in flash causal lm.

* Truncating left for radix purposes.

* Oops this doesn't belong here.

* Put back default pure shell.

* Update server tests

- Default to throughput test in k6
- Use TGI_WIGGLE_ROOM to adjust wiggle room

* Only n_heads / process_group.size() are necessary.

* Revert the integrationt tests change (seem linked to head_size
modification).

* Adding error message when assert is violated.

* Fixing the free algorithm to handle times where the common prefix is
smaller.

* Apply suggestions from code review

Co-authored-by: OlivierDehaene <olivier@huggingface.co>

* Update server/text_generation_server/layers/attention/common.py

Co-authored-by: OlivierDehaene <olivier@huggingface.co>

* Fix disabling prefix caching - Fix windowing checks.

* Revert the Cohere tokenizer change (for now using a revision instead).

* Fmt.

---------

Co-authored-by: drbh <david.richard.holtz@gmail.com>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
This commit is contained in:
Nicolas Patry 2024-08-29 16:29:01 +02:00 committed by GitHub
parent 4e821c003a
commit e415b690a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
59 changed files with 1234 additions and 934 deletions

View File

@ -35,7 +35,7 @@ jobs:
with: with:
# Released on: 02 May, 2024 # Released on: 02 May, 2024
# https://releases.rs/docs/1.78.0/ # https://releases.rs/docs/1.78.0/
toolchain: 1.79.0 toolchain: 1.80.0
override: true override: true
components: rustfmt, clippy components: rustfmt, clippy
- name: Install Protoc - name: Install Protoc

437
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,5 @@
# Rust builder # Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
WORKDIR /usr/src WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -184,6 +184,12 @@ WORKDIR /usr/src
COPY server/Makefile-selective-scan Makefile COPY server/Makefile-selective-scan Makefile
RUN make build-all RUN make build-all
# Build flashinfer
FROM kernel-builder AS flashinfer-builder
WORKDIR /usr/src
COPY server/Makefile-flashinfer Makefile
RUN make install-flashinfer
# Text Generation Inference base image # Text Generation Inference base image
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS base FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS base
@ -236,6 +242,7 @@ COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/c
# Copy build artifacts from mamba builder # Copy build artifacts from mamba builder
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
COPY --from=flashinfer-builder /opt/conda/lib/python3.10/site-packages/flashinfer/ /opt/conda/lib/python3.10/site-packages/flashinfer/
# Install flash-attention dependencies # Install flash-attention dependencies
RUN pip install einops --no-cache-dir RUN pip install einops --no-cache-dir

View File

@ -1,5 +1,5 @@
# Rust builder # Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
WORKDIR /usr/src WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

View File

@ -1,6 +1,6 @@
ARG PLATFORM=xpu ARG PLATFORM=xpu
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
WORKDIR /usr/src WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

View File

@ -153,6 +153,8 @@ impl Client {
}), }),
// We truncate the input on the server side to be sure that it has the correct size // We truncate the input on the server side to be sure that it has the correct size
truncate, truncate,
// Most request will have that
add_special_tokens: true,
// Blocks and slots will be set on the server side if we use paged attention // Blocks and slots will be set on the server side if we use paged attention
blocks: vec![], blocks: vec![],
slots: vec![], slots: vec![],

View File

@ -221,6 +221,7 @@ impl Health for ShardedClient {
chunks: vec![Chunk::Text("liveness".into()).into()], chunks: vec![Chunk::Text("liveness".into()).into()],
}), }),
truncate: 10, truncate: 10,
add_special_tokens: true,
prefill_logprobs: false, prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters { parameters: Some(NextTokenChooserParameters {
temperature: 1.0, temperature: 1.0,

View File

@ -35,27 +35,15 @@ impl BackendV3 {
window_size: Option<u32>, window_size: Option<u32>,
speculate: u32, speculate: u32,
) -> Self { ) -> Self {
let prefix_caching = if let Ok(prefix_caching) = std::env::var("USE_PREFIX_CACHING") { let prefix_caching =
matches!(prefix_caching.as_str(), "true" | "1") std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
} else { let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
false let attention: String = std::env::var("ATTENTION").expect("attention env var");
};
let attention = if let Ok(attention) = std::env::var("ATTENTION") { let attention: Attention = attention
attention .parse()
.parse() .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")) let block_size = attention.block_size();
} else if prefix_caching {
Attention::FlashInfer
} else {
Attention::Paged
};
let block_size = if attention == Attention::FlashDecoding {
256
} else if attention == Attention::FlashInfer {
1
} else {
16
};
let queue = Queue::new( let queue = Queue::new(
requires_padding, requires_padding,

View File

@ -1,4 +1,4 @@
use std::{cmp::min, sync::Arc}; use std::sync::Arc;
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
use crate::radix::RadixAllocator; use crate::radix::RadixAllocator;
@ -137,7 +137,6 @@ pub trait Allocator {
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64); fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
} }
pub struct SimpleAllocator { pub struct SimpleAllocator {
free_blocks: Vec<u32>, free_blocks: Vec<u32>,
block_size: u32, block_size: u32,
@ -167,7 +166,7 @@ impl Allocator for SimpleAllocator {
None => (tokens, 1), None => (tokens, 1),
Some(window_size) => { Some(window_size) => {
let repeats = (tokens + window_size - 1) / window_size; let repeats = (tokens + window_size - 1) / window_size;
let tokens = min(tokens, window_size); let tokens = core::cmp::min(tokens, window_size);
(tokens, repeats as usize) (tokens, repeats as usize)
} }
}; };

View File

@ -149,6 +149,7 @@ impl Client {
requests.push(Request { requests.push(Request {
id: 0, id: 0,
inputs, inputs,
add_special_tokens: true,
input_chunks: Some(Input { input_chunks: Some(Input {
chunks: input_chunks, chunks: input_chunks,
}), }),

View File

@ -222,6 +222,7 @@ impl Health for ShardedClient {
chunks: vec![Chunk::Text("liveness".into()).into()], chunks: vec![Chunk::Text("liveness".into()).into()],
}), }),
truncate: 10, truncate: 10,
add_special_tokens: true,
prefill_logprobs: false, prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters { parameters: Some(NextTokenChooserParameters {
temperature: 1.0, temperature: 1.0,

View File

@ -383,6 +383,7 @@ impl State {
}), }),
inputs: entry.request.inputs.chunks_to_string(), inputs: entry.request.inputs.chunks_to_string(),
truncate: entry.request.truncate, truncate: entry.request.truncate,
add_special_tokens: entry.request.add_special_tokens,
parameters: Some(NextTokenChooserParameters::from( parameters: Some(NextTokenChooserParameters::from(
entry.request.parameters.clone(), entry.request.parameters.clone(),
)), )),
@ -517,6 +518,7 @@ mod tests {
inputs: vec![], inputs: vec![],
input_ids: Some(Arc::new(vec![])), input_ids: Some(Arc::new(vec![])),
input_length: 0, input_length: 0,
add_special_tokens: true,
truncate: 0, truncate: 0,
decoder_input_details: false, decoder_input_details: false,
parameters: ValidParameters { parameters: ValidParameters {

View File

@ -1,12 +1,10 @@
use crate::block_allocator::{Allocator, BlockAllocation};
use slotmap::{DefaultKey, SlotMap};
use std::{ use std::{
collections::{BTreeSet, HashMap}, collections::{BTreeSet, HashMap},
sync::Arc, sync::Arc,
}; };
use slotmap::{DefaultKey, SlotMap};
use crate::block_allocator::{Allocator, BlockAllocation};
pub struct RadixAllocator { pub struct RadixAllocator {
allocation_id: u64, allocation_id: u64,
@ -16,26 +14,26 @@ pub struct RadixAllocator {
/// Blocks that are immediately available for allocation. /// Blocks that are immediately available for allocation.
free_blocks: Vec<u32>, free_blocks: Vec<u32>,
#[allow(dead_code)]
// This isn't used because the prefix need to match without the windowing
// mecanism. This at worst is overallocating, not necessarily being wrong.
window_size: Option<u32>,
block_size: u32,
} }
impl RadixAllocator { impl RadixAllocator {
pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self { pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self {
assert_eq!(
block_size, 1,
"Radix tree allocator only works with block_size=1, was: {}",
block_size
);
if window_size.is_some() {
unimplemented!("Window size not supported in the prefix-caching block allocator yet");
}
RadixAllocator { RadixAllocator {
allocation_id: 0, allocation_id: 0,
allocations: HashMap::new(), allocations: HashMap::new(),
cache_blocks: RadixTrie::new(), cache_blocks: RadixTrie::new(block_size as usize),
// Block 0 is reserved for health checks. // Block 0 is reserved for health checks.
free_blocks: (1..n_blocks).collect(), free_blocks: (1..n_blocks).collect(),
window_size,
block_size,
} }
} }
@ -63,6 +61,7 @@ impl RadixAllocator {
} }
} }
// Allocator trait
impl Allocator for RadixAllocator { impl Allocator for RadixAllocator {
fn allocate( fn allocate(
&mut self, &mut self,
@ -86,10 +85,12 @@ impl Allocator for RadixAllocator {
.incref(prefix_node) .incref(prefix_node)
.expect("Failed to increment refcount"); .expect("Failed to increment refcount");
let prefix_len = blocks.len(); let prefix_len = blocks.len() * self.block_size as usize;
let suffix_len = tokens - prefix_len as u32; let suffix_len = tokens - prefix_len as u32;
match self.alloc_or_reclaim(suffix_len as usize) { let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size;
match self.alloc_or_reclaim(suffix_blocks as usize) {
Some(suffix_blocks) => blocks.extend(suffix_blocks), Some(suffix_blocks) => blocks.extend(suffix_blocks),
None => { None => {
self.cache_blocks self.cache_blocks
@ -100,7 +101,20 @@ impl Allocator for RadixAllocator {
} }
// 1:1 mapping of blocks and slots. // 1:1 mapping of blocks and slots.
let slots = blocks.clone(); let slots = if self.block_size == 1 {
blocks.clone()
} else {
let mut slots = Vec::with_capacity(blocks.len() * self.block_size as usize);
'slots: for block_id in &blocks {
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
slots.push(s);
if slots.len() as u32 == tokens {
break 'slots;
}
}
}
slots
};
let allocation = RadixAllocation { let allocation = RadixAllocation {
prefix_node, prefix_node,
@ -108,6 +122,8 @@ impl Allocator for RadixAllocator {
prefill_tokens: prefill_tokens.clone(), prefill_tokens: prefill_tokens.clone(),
}; };
tracing::debug!("Blocks {blocks:?}");
self.allocation_id += 1; self.allocation_id += 1;
self.allocations.insert(self.allocation_id, allocation); self.allocations.insert(self.allocation_id, allocation);
@ -136,27 +152,38 @@ impl Allocator for RadixAllocator {
// If there are prefill tokens that did not come from the cache, // If there are prefill tokens that did not come from the cache,
// add them to the cache. // add them to the cache.
if prefill_tokens.len() > allocation.cached_prefix_len { if prefill_tokens.len() > allocation.cached_prefix_len {
let prefix_len = self let aligned =
.cache_blocks (prefill_tokens.len() / self.block_size as usize) * self.block_size as usize;
.insert(prefill_tokens, &blocks[..prefill_tokens.len()]) if aligned > 0 {
// Unwrap, failing is a programming error. let prefix_len = self
.expect("Failed to store prefill tokens"); .cache_blocks
.insert(
// We can have a prefill with the following structure: &prefill_tokens[..aligned],
// &blocks[..aligned / self.block_size as usize],
// |---| From the prefix cache. )
// A B C D E F G // Unwrap, failing is a programming error.
//|--------| Found in the trie during insertion. .expect("Failed to store prefill tokens");
// // We can have a prefill with the following structure:
// This means that while processing this request there was a //
// partially overlapping request that had A..=E in its // |---| From the prefix cache.
// prefill. In this case we need to free the blocks D E. // A B C D E F G
self.free_blocks //|--------| Found in the trie during insertion.
.extend(&blocks[allocation.cached_prefix_len..prefix_len]); //
// This means that while processing this request there was a
// partially overlapping request that had A..=E in its
// prefill. In this case we need to free the blocks D E.
if prefix_len > allocation.cached_prefix_len {
self.free_blocks.extend(
&blocks[allocation.cached_prefix_len / self.block_size as usize
..prefix_len / self.block_size as usize],
);
}
}
} }
// Free non-prefill blocks. // Free non-prefill blocks.
self.free_blocks.extend(&blocks[prefill_tokens.len()..]); self.free_blocks
.extend(&blocks[prefill_tokens.len() / self.block_size as usize..]);
} else { } else {
self.free_blocks.extend(blocks); self.free_blocks.extend(blocks);
} }
@ -204,17 +231,14 @@ pub struct RadixTrie {
/// Time as a monotonically increating counter to avoid the system /// Time as a monotonically increating counter to avoid the system
/// call that a real time lookup would require. /// call that a real time lookup would require.
time: u64, time: u64,
}
impl Default for RadixTrie { /// All blocks need to be aligned with this
fn default() -> Self { block_size: usize,
Self::new()
}
} }
impl RadixTrie { impl RadixTrie {
/// Construct a new radix trie. /// Construct a new radix trie.
pub fn new() -> Self { pub fn new(block_size: usize) -> Self {
let root = TrieNode::new(vec![], vec![], 0, None); let root = TrieNode::new(vec![], vec![], 0, None);
let mut nodes = SlotMap::new(); let mut nodes = SlotMap::new();
let root = nodes.insert(root); let root = nodes.insert(root);
@ -223,13 +247,14 @@ impl RadixTrie {
nodes, nodes,
root, root,
time: 0, time: 0,
block_size,
} }
} }
/// Find the prefix of the given tokens. /// Find the prefix of the given tokens.
/// ///
/// The blocks corresponding to the part of the prefix that could be found /// The blocks corresponding to the part of the prefix that could be found
/// are writteng to `blocks`. The number of blocks is in `0..=tokens.len()`. /// are written to `blocks`. The number of blocks is in `0..=tokens.len()`.
/// Returns the identifier of the trie node that contains the longest /// Returns the identifier of the trie node that contains the longest
/// prefix. The node identifier can be used by callers to e.g. increase its /// prefix. The node identifier can be used by callers to e.g. increase its
/// reference count. /// reference count.
@ -247,8 +272,9 @@ impl RadixTrie {
if let Some(&child_id) = node.children.get(&key[0]) { if let Some(&child_id) = node.children.get(&key[0]) {
self.update_access_time(child_id); self.update_access_time(child_id);
let child = self.nodes.get(child_id).expect("Invalid child identifier"); let child = self.nodes.get(child_id).expect("Invalid child identifier");
let shared_prefix_len = child.key.shared_prefix_len(key); let shared_prefix_len = shared_prefix(&child.key, key, self.block_size);
blocks.extend(&child.blocks[..shared_prefix_len]); assert_eq!(shared_prefix_len % self.block_size, 0);
blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);
let key = &key[shared_prefix_len..]; let key = &key[shared_prefix_len..];
if !key.is_empty() { if !key.is_empty() {
@ -349,7 +375,8 @@ impl RadixTrie {
/// the first 10 elements of the tree **the blocks are not updated**. /// the first 10 elements of the tree **the blocks are not updated**.
pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result<usize, TrieError> { pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result<usize, TrieError> {
self.time += 1; self.time += 1;
self.insert_(self.root, tokens, blocks) let common = self.insert_(self.root, tokens, blocks)?;
Ok(common)
} }
/// Insertion worker. /// Insertion worker.
@ -363,7 +390,7 @@ impl RadixTrie {
// the part of the prefix that is already in the trie to detect // the part of the prefix that is already in the trie to detect
// mismatches. // mismatches.
if tokens.len() != blocks.len() { if tokens.len() != blocks.len() * self.block_size {
return Err(TrieError::BlockTokenCountMismatch); return Err(TrieError::BlockTokenCountMismatch);
} }
@ -374,10 +401,10 @@ impl RadixTrie {
.get_mut(child_id) .get_mut(child_id)
// Unwrap here, since failure is a bug. // Unwrap here, since failure is a bug.
.expect("Child node does not exist"); .expect("Child node does not exist");
let shared_prefix_len = child.key.shared_prefix_len(tokens); let shared_prefix_len = shared_prefix(&child.key, tokens, self.block_size);
// We are done, the prefix is already in the trie. // We are done, the prefix is already in the trie.
if shared_prefix_len == tokens.len() { if shared_prefix_len == tokens.len() || shared_prefix_len == 0 {
return Ok(shared_prefix_len); return Ok(shared_prefix_len);
} }
@ -387,7 +414,7 @@ impl RadixTrie {
+ self.insert_( + self.insert_(
child_id, child_id,
&tokens[shared_prefix_len..], &tokens[shared_prefix_len..],
&blocks[shared_prefix_len..], &blocks[shared_prefix_len / self.block_size..],
)?); )?);
} }
@ -396,7 +423,7 @@ impl RadixTrie {
// remainder of the prefix into the node again // remainder of the prefix into the node again
let child_id = self.split_node(child_id, shared_prefix_len); let child_id = self.split_node(child_id, shared_prefix_len);
let key = &tokens[shared_prefix_len..]; let key = &tokens[shared_prefix_len..];
let blocks = &blocks[shared_prefix_len..]; let blocks = &blocks[shared_prefix_len / self.block_size..];
Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?) Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?)
} else { } else {
self.add_node(node_id, tokens, blocks); self.add_node(node_id, tokens, blocks);
@ -550,34 +577,53 @@ impl TrieNode {
} }
} }
/// Helper trait to get the length of the shared prefix of two sequences. fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize {
trait SharedPrefixLen { let full = left.iter().zip(right).take_while(|(a, b)| a == b).count();
fn shared_prefix_len(&self, other: &Self) -> usize; (full / block_size) * block_size
}
impl<T> SharedPrefixLen for [T]
where
T: PartialEq,
{
fn shared_prefix_len(&self, other: &Self) -> usize {
self.iter().zip(other).take_while(|(a, b)| a == b).count()
}
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::sync::Arc; use std::sync::Arc;
use crate::block_allocator::Allocator; use super::*;
use super::RadixAllocator; #[test]
fn allocator_block_size() {
let mut cache = RadixAllocator::new(2, 12, None);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
assert_eq!(allocation.prefix_len, 0);
cache.free(allocation.blocks.clone(), allocation.allocation_id);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
assert_eq!(allocation.prefix_len, 4);
}
#[test]
fn allocator_block_size_non_aligned() {
let mut cache = RadixAllocator::new(2, 12, None);
let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
assert_eq!(allocation.prefix_len, 0);
cache.free(allocation.blocks.clone(), allocation.allocation_id);
let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
assert_eq!(allocation.prefix_len, 2);
}
#[test] #[test]
fn allocator_reuses_prefixes() { fn allocator_reuses_prefixes() {
let mut cache = RadixAllocator::new(1, 12, None); let mut cache = RadixAllocator::new(1, 12, None);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
assert_eq!(allocation.slots, allocation.slots); assert_eq!(allocation.blocks, allocation.slots);
assert_eq!(allocation.prefix_len, 0); assert_eq!(allocation.prefix_len, 0);
cache.free(allocation.blocks.clone(), allocation.allocation_id); cache.free(allocation.blocks.clone(), allocation.allocation_id);
@ -666,7 +712,7 @@ mod tests {
#[test] #[test]
fn trie_insertions_have_correct_prefix_len() { fn trie_insertions_have_correct_prefix_len() {
let mut trie = super::RadixTrie::new(); let mut trie = RadixTrie::new(1);
assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0); assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0);
@ -687,9 +733,33 @@ mod tests {
); );
} }
#[test]
fn trie_insertions_block_size() {
let mut trie = RadixTrie::new(2);
assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 0);
// Already exists.
// But needs to be block_size aligned
assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 4);
// Completely new at root-level
assert_eq!(trie.insert(&[1, 2, 3, 4], &[1, 2]).unwrap(), 0);
// Contains full prefix, but longer.
assert_eq!(trie.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2]).unwrap(), 4);
// Shares partial prefix, we need a split.
assert_eq!(
trie.insert(&[0, 1, 3, 4, 5, 6, 7, 8], &[0, 1, 2, 3])
.unwrap(),
2
);
}
#[test] #[test]
fn trie_get_returns_correct_blocks() { fn trie_get_returns_correct_blocks() {
let mut trie = super::RadixTrie::new(); let mut trie = RadixTrie::new(1);
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(); trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(); trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
@ -723,7 +793,7 @@ mod tests {
#[test] #[test]
fn trie_evict_removes_correct_blocks() { fn trie_evict_removes_correct_blocks() {
let mut trie = super::RadixTrie::new(); let mut trie = RadixTrie::new(1);
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
.unwrap(); .unwrap();

View File

@ -148,6 +148,7 @@ async fn prefill(
}), }),
inputs: sequence.clone(), inputs: sequence.clone(),
truncate: sequence_length, truncate: sequence_length,
add_special_tokens: true,
parameters: Some(parameters.clone()), parameters: Some(parameters.clone()),
stopping_parameters: Some(StoppingCriteriaParameters { stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: decode_length, max_new_tokens: decode_length,

View File

@ -835,11 +835,11 @@
] ]
}, },
"locked": { "locked": {
"lastModified": 1724206841, "lastModified": 1724638882,
"narHash": "sha256-L8dKaX4T3k+TR2fEHCfGbH4UXdspovz/pj87iai9qmc=", "narHash": "sha256-ap2jIQi/FuUHR6HCht6ASWhoz8EiB99XmI8Esot38VE=",
"owner": "oxalica", "owner": "oxalica",
"repo": "rust-overlay", "repo": "rust-overlay",
"rev": "45e98fbd62c32e5927e952d2833fa1ba4fb35a61", "rev": "19b70f147b9c67a759e35824b241f1ed92e46694",
"type": "github" "type": "github"
}, },
"original": { "original": {

View File

@ -5,7 +5,7 @@
"index": 0, "index": 0,
"logprobs": null, "logprobs": null,
"message": { "message": {
"content": "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to explore in the middle of urban confines. In fact, typical times for humidity levels in Brooklyn include:\n\n- Early morning: 80-85% humidity, with occas", "content": "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to appreciate nature.\n\nIn terms of temperature, the warmest times of the year are from June to August, when average high temperatures typically range from around 73°F or 23°C",
"name": null, "name": null,
"role": "assistant", "role": "assistant",
"tool_calls": null "tool_calls": null
@ -13,14 +13,14 @@
"usage": null "usage": null
} }
], ],
"created": 1716553098, "created": 1724792495,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "chat.completion",
"system_fingerprint": "2.0.5-dev0-native", "system_fingerprint": "2.2.1-dev0-native",
"usage": { "usage": {
"completion_tokens": 100, "completion_tokens": 100,
"prompt_tokens": 62, "prompt_tokens": 61,
"total_tokens": 162 "total_tokens": 161
} }
} }

View File

@ -8,11 +8,11 @@
"text": "\n" "text": "\n"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -23,11 +23,11 @@
"text": "\n" "text": "\n"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -38,11 +38,11 @@
"text": "\n" "text": "\n"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -53,11 +53,11 @@
"text": "hd" "text": "hd"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -68,11 +68,11 @@
"text": "\n" "text": "\n"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -83,11 +83,11 @@
"text": "\n" "text": "\n"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -98,11 +98,11 @@
"text": "\n" "text": "\n"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -113,11 +113,11 @@
"text": "aho" "text": "aho"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -128,11 +128,11 @@
"text": "2" "text": "2"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -143,11 +143,11 @@
"text": "2" "text": "2"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -158,11 +158,11 @@
"text": "2" "text": "2"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -173,11 +173,11 @@
"text": "ima" "text": "ima"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -188,11 +188,11 @@
"text": "." "text": "."
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -203,11 +203,11 @@
"text": "." "text": "."
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -218,11 +218,11 @@
"text": "." "text": "."
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -233,11 +233,11 @@
"text": "\n" "text": "\n"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -248,11 +248,11 @@
"text": " Sarah" "text": " Sarah"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -263,11 +263,11 @@
"text": " Yes" "text": " Yes"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -278,11 +278,11 @@
"text": " And" "text": " And"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -293,11 +293,11 @@
"text": "i" "text": "i"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -308,11 +308,11 @@
"text": "'" "text": "'"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -323,11 +323,11 @@
"text": "," "text": ","
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -338,11 +338,11 @@
"text": " what" "text": " what"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -353,11 +353,11 @@
"text": "'" "text": "'"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -368,11 +368,11 @@
"text": "s" "text": "s"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -383,11 +383,11 @@
"text": " Moh" "text": " Moh"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -398,11 +398,11 @@
"text": " is" "text": " is"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -413,11 +413,11 @@
"text": "m" "text": "m"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -428,11 +428,11 @@
"text": " Room" "text": " Room"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -443,11 +443,11 @@
"text": "s" "text": "s"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -458,11 +458,11 @@
"text": " the" "text": " the"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -473,11 +473,11 @@
"text": " tired" "text": " tired"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -488,11 +488,11 @@
"text": ":" "text": ":"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -503,11 +503,11 @@
"text": "'" "text": "'"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -518,11 +518,11 @@
"text": " capital" "text": " capital"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -530,73 +530,73 @@
"finish_reason": "", "finish_reason": "",
"index": 3, "index": 3,
"logprobs": null, "logprobs": null,
"text": " of" "text": ","
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
{ {
"finish_reason": "", "finish_reason": "length",
"index": 0, "index": 0,
"logprobs": null, "logprobs": null,
"text": " She" "text": " She"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
{ {
"finish_reason": "", "finish_reason": "length",
"index": 1, "index": 1,
"logprobs": null, "logprobs": null,
"text": " scale" "text": " scale"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
{ {
"finish_reason": "", "finish_reason": "length",
"index": 2, "index": 2,
"logprobs": null, "logprobs": null,
"text": " of" "text": " of"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
{ {
"finish_reason": "", "finish_reason": "length",
"index": 3, "index": 3,
"logprobs": null, "logprobs": null,
"text": " being" "text": " its"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
} }
] ]

View File

@ -16,7 +16,7 @@
}, },
{ {
"id": 3102, "id": 3102,
"logprob": -11.1875, "logprob": -11.25,
"text": " request" "text": " request"
} }
], ],
@ -24,66 +24,66 @@
"tokens": [ "tokens": [
{ {
"id": 185, "id": 185,
"logprob": -1.5546875, "logprob": -1.546875,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
{ {
"id": 549, "id": 549,
"logprob": -2.84375, "logprob": -2.859375,
"special": false, "special": false,
"text": "The" "text": "The"
}, },
{ {
"id": 1727, "id": 1727,
"logprob": -2.34375, "logprob": -2.484375,
"special": false, "special": false,
"text": " test" "text": " test"
}, },
{ {
"id": 3102, "id": 3102,
"logprob": -0.8359375, "logprob": -0.83203125,
"special": false, "special": false,
"text": " request" "text": " request"
}, },
{ {
"id": 317, "id": 317,
"logprob": -1.0859375, "logprob": -1.1484375,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 254, "id": 245,
"logprob": -1.5390625, "logprob": -1.578125,
"special": false, "special": false,
"text": " the" "text": " a"
}, },
{ {
"id": 1022, "id": 3412,
"logprob": -1.1875, "logprob": -2.578125,
"special": false, "special": false,
"text": " first" "text": " document"
}, },
{ {
"id": 3458, "id": 344,
"logprob": -0.35546875, "logprob": -1.125,
"special": false, "special": false,
"text": " step" "text": " that"
}, },
{ {
"id": 279, "id": 317,
"logprob": -0.8828125, "logprob": -1.6953125,
"special": false, "special": false,
"text": " in" "text": " is"
}, },
{ {
"id": 254, "id": 1222,
"logprob": -0.71484375, "logprob": -1.71875,
"special": false, "special": false,
"text": " the" "text": " used"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\nThe test request is the first step in the" "generated_text": "\nThe test request is a document that is used"
} }

View File

@ -37,56 +37,56 @@
}, },
{ {
"id": 1727, "id": 1727,
"logprob": -2.359375, "logprob": -2.4375,
"special": false, "special": false,
"text": " test" "text": " test"
}, },
{ {
"id": 3102, "id": 3102,
"logprob": -0.83203125, "logprob": -0.83984375,
"special": false, "special": false,
"text": " request" "text": " request"
}, },
{ {
"id": 317, "id": 317,
"logprob": -1.125, "logprob": -1.1328125,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 245, "id": 254,
"logprob": -1.5703125, "logprob": -1.515625,
"special": false, "special": false,
"text": " a" "text": " the"
}, },
{ {
"id": 3412, "id": 1022,
"logprob": -2.578125, "logprob": -1.15625,
"special": false, "special": false,
"text": " document" "text": " first"
}, },
{ {
"id": 344, "id": 3458,
"logprob": -1.125, "logprob": -0.3671875,
"special": false, "special": false,
"text": " that" "text": " step"
}, },
{ {
"id": 317, "id": 279,
"logprob": -1.6953125, "logprob": -0.88671875,
"special": false, "special": false,
"text": " is" "text": " in"
}, },
{ {
"id": 1222, "id": 254,
"logprob": -1.75, "logprob": -0.69140625,
"special": false, "special": false,
"text": " used" "text": " the"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\nThe test request is a document that is used" "generated_text": "\nThe test request is the first step in the"
}, },
{ {
"details": { "details": {
@ -126,56 +126,56 @@
}, },
{ {
"id": 1727, "id": 1727,
"logprob": -2.359375, "logprob": -2.4375,
"special": false, "special": false,
"text": " test" "text": " test"
}, },
{ {
"id": 3102, "id": 3102,
"logprob": -0.83203125, "logprob": -0.83984375,
"special": false, "special": false,
"text": " request" "text": " request"
}, },
{ {
"id": 317, "id": 317,
"logprob": -1.125, "logprob": -1.1328125,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 245, "id": 254,
"logprob": -1.5703125, "logprob": -1.515625,
"special": false, "special": false,
"text": " a" "text": " the"
}, },
{ {
"id": 3412, "id": 1022,
"logprob": -2.578125, "logprob": -1.15625,
"special": false, "special": false,
"text": " document" "text": " first"
}, },
{ {
"id": 344, "id": 3458,
"logprob": -1.125, "logprob": -0.3671875,
"special": false, "special": false,
"text": " that" "text": " step"
}, },
{ {
"id": 317, "id": 279,
"logprob": -1.6953125, "logprob": -0.88671875,
"special": false, "special": false,
"text": " is" "text": " in"
}, },
{ {
"id": 1222, "id": 254,
"logprob": -1.75, "logprob": -0.69140625,
"special": false, "special": false,
"text": " used" "text": " the"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\nThe test request is a document that is used" "generated_text": "\nThe test request is the first step in the"
}, },
{ {
"details": { "details": {
@ -215,56 +215,56 @@
}, },
{ {
"id": 1727, "id": 1727,
"logprob": -2.359375, "logprob": -2.4375,
"special": false, "special": false,
"text": " test" "text": " test"
}, },
{ {
"id": 3102, "id": 3102,
"logprob": -0.83203125, "logprob": -0.83984375,
"special": false, "special": false,
"text": " request" "text": " request"
}, },
{ {
"id": 317, "id": 317,
"logprob": -1.125, "logprob": -1.1328125,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 245, "id": 254,
"logprob": -1.5703125, "logprob": -1.515625,
"special": false, "special": false,
"text": " a" "text": " the"
}, },
{ {
"id": 3412, "id": 1022,
"logprob": -2.578125, "logprob": -1.15625,
"special": false, "special": false,
"text": " document" "text": " first"
}, },
{ {
"id": 344, "id": 3458,
"logprob": -1.125, "logprob": -0.3671875,
"special": false, "special": false,
"text": " that" "text": " step"
}, },
{ {
"id": 317, "id": 279,
"logprob": -1.6953125, "logprob": -0.88671875,
"special": false, "special": false,
"text": " is" "text": " in"
}, },
{ {
"id": 1222, "id": 254,
"logprob": -1.75, "logprob": -0.69140625,
"special": false, "special": false,
"text": " used" "text": " the"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\nThe test request is a document that is used" "generated_text": "\nThe test request is the first step in the"
}, },
{ {
"details": { "details": {
@ -304,55 +304,55 @@
}, },
{ {
"id": 1727, "id": 1727,
"logprob": -2.359375, "logprob": -2.4375,
"special": false, "special": false,
"text": " test" "text": " test"
}, },
{ {
"id": 3102, "id": 3102,
"logprob": -0.83203125, "logprob": -0.83984375,
"special": false, "special": false,
"text": " request" "text": " request"
}, },
{ {
"id": 317, "id": 317,
"logprob": -1.125, "logprob": -1.1328125,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 245, "id": 254,
"logprob": -1.5703125, "logprob": -1.515625,
"special": false, "special": false,
"text": " a" "text": " the"
}, },
{ {
"id": 3412, "id": 1022,
"logprob": -2.578125, "logprob": -1.15625,
"special": false, "special": false,
"text": " document" "text": " first"
}, },
{ {
"id": 344, "id": 3458,
"logprob": -1.125, "logprob": -0.3671875,
"special": false, "special": false,
"text": " that" "text": " step"
}, },
{ {
"id": 317, "id": 279,
"logprob": -1.6953125, "logprob": -0.88671875,
"special": false, "special": false,
"text": " is" "text": " in"
}, },
{ {
"id": 1222, "id": 254,
"logprob": -1.75, "logprob": -0.69140625,
"special": false, "special": false,
"text": " used" "text": " the"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\nThe test request is a document that is used" "generated_text": "\nThe test request is the first step in the"
} }
] ]

View File

@ -1,8 +1,8 @@
{ {
"details": { "details": {
"best_of_sequences": null, "best_of_sequences": null,
"finish_reason": "length", "finish_reason": "stop_sequence",
"generated_tokens": 10, "generated_tokens": 5,
"prefill": [ "prefill": [
{ {
"id": 128000, "id": 128000,
@ -16,7 +16,7 @@
}, },
{ {
"id": 1715, "id": 1715,
"logprob": -10.375, "logprob": -10.4375,
"text": " request" "text": " request"
} }
], ],
@ -29,61 +29,31 @@
"text": ":" "text": ":"
}, },
{ {
"id": 2209, "id": 923,
"logprob": -2.78125, "logprob": -2.84375,
"special": false, "special": false,
"text": " Is" "text": " add"
}, },
{ {
"id": 279, "id": 264,
"logprob": -0.6328125, "logprob": 0.0,
"special": false, "special": false,
"text": " the" "text": " a"
},
{
"id": 734,
"logprob": -2.703125,
"special": false,
"text": " function"
}, },
{ {
"id": 330, "id": 330,
"logprob": -0.34179688, "logprob": -0.31640625,
"special": false, "special": false,
"text": " \"" "text": " \""
}, },
{ {
"id": 4110, "id": 1985,
"logprob": -2.359375, "logprob": 0.0,
"special": false, "special": false,
"text": "Create" "text": "test"
},
{
"id": 7575,
"logprob": -2.1875,
"special": false,
"text": "Process"
},
{
"id": 1,
"logprob": -0.07910156,
"special": false,
"text": "\""
},
{
"id": 304,
"logprob": -0.83203125,
"special": false,
"text": " in"
},
{
"id": 12468,
"logprob": -1.8203125,
"special": false,
"text": " Win"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "Test request: Is the function \"CreateProcess\" in Win" "generated_text": "Test request: add a \"test"
} }

View File

@ -16,7 +16,7 @@
}, },
{ {
"id": 100, "id": 100,
"logprob": -0.38549805, "logprob": -0.38305664,
"text": "_" "text": "_"
}, },
{ {
@ -29,7 +29,7 @@
"tokens": [ "tokens": [
{ {
"id": 2284, "id": 2284,
"logprob": -0.31323242, "logprob": -0.296875,
"special": false, "special": false,
"text": "():" "text": "():"
}, },
@ -59,19 +59,19 @@
}, },
{ {
"id": 10914, "id": 10914,
"logprob": -0.7817383, "logprob": -0.7734375,
"special": false, "special": false,
"text": " World" "text": " World"
}, },
{ {
"id": 16013, "id": 16013,
"logprob": -0.6328125, "logprob": -0.61816406,
"special": false, "special": false,
"text": "!\")" "text": "!\")"
}, },
{ {
"id": 222, "id": 222,
"logprob": -0.0619812, "logprob": -0.054870605,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
@ -83,7 +83,7 @@
}, },
{ {
"id": 610, "id": 610,
"logprob": -0.4086914, "logprob": -0.4152832,
"special": false, "special": false,
"text": "def" "text": "def"
}, },
@ -113,7 +113,7 @@
}, },
{ {
"id": 444, "id": 444,
"logprob": -0.21826172, "logprob": -0.21618652,
"special": false, "special": false,
"text": "name" "text": "name"
}, },
@ -173,7 +173,7 @@
}, },
{ {
"id": 11571, "id": 11571,
"logprob": -0.10021973, "logprob": -0.08892822,
"special": false, "special": false,
"text": "!\"" "text": "!\""
}, },

View File

@ -30,19 +30,19 @@
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.37573242, "logprob": -0.38061523,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 633, "id": 633,
"logprob": -0.09161377, "logprob": -0.09301758,
"special": false, "special": false,
"text": " new" "text": " new"
}, },
{ {
"id": 4480, "id": 4480,
"logprob": -0.26171875, "logprob": -0.26782227,
"special": false, "special": false,
"text": " feature" "text": " feature"
}, },
@ -78,7 +78,7 @@
}, },
{ {
"id": 13, "id": 13,
"logprob": 0.0, "logprob": -0.10632324,
"special": false, "special": false,
"text": "\n" "text": "\n"
} }

View File

@ -35,6 +35,6 @@ async def test_flash_llama_simple(flash_llama_chat, response_snapshot):
print(repr(response.choices[0].message.content)) print(repr(response.choices[0].message.content))
assert ( assert (
response.choices[0].message.content response.choices[0].message.content
== "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to explore in the middle of urban confines. In fact, typical times for humidity levels in Brooklyn include:\n\n- Early morning: 80-85% humidity, with occas" == "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to appreciate nature.\n\nIn terms of temperature, the warmest times of the year are from June to August, when average high temperatures typically range from around 73°F or 23°C"
) )
assert response == response_snapshot assert response == response_snapshot

View File

@ -8,7 +8,7 @@ use nix::unistd::Pid;
use serde::Deserialize; use serde::Deserialize;
use std::env; use std::env;
use std::ffi::OsString; use std::ffi::OsString;
use std::io::{BufRead, BufReader, Lines}; use std::io::{BufRead, BufReader};
use std::os::unix::process::{CommandExt, ExitStatusExt}; use std::os::unix::process::{CommandExt, ExitStatusExt};
use std::path::Path; use std::path::Path;
use std::process::{Child, Command, ExitStatus, Stdio}; use std::process::{Child, Command, ExitStatus, Stdio};
@ -18,12 +18,103 @@ use std::sync::{mpsc, Arc};
use std::thread; use std::thread;
use std::thread::sleep; use std::thread::sleep;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use std::{fs, io}; use std::{
fs, io,
io::{Read, Write},
};
use thiserror::Error; use thiserror::Error;
use tracing_subscriber::{filter::LevelFilter, EnvFilter}; use tracing_subscriber::{filter::LevelFilter, EnvFilter};
mod env_runtime; mod env_runtime;
fn get_config(
model_id: &str,
revision: &Option<String>,
) -> Result<Config, Box<dyn std::error::Error>> {
let mut path = std::path::Path::new(model_id).to_path_buf();
let model_id = model_id.to_string();
let filename = if !path.exists() {
// Assume it's a hub id
let api = if let Ok(token) = std::env::var("HF_TOKEN") {
// env variable has precedence over on file token.
ApiBuilder::new().with_token(Some(token)).build()?
} else {
Api::new()?
};
let repo = if let Some(ref revision) = revision {
api.repo(Repo::with_revision(
model_id,
RepoType::Model,
revision.to_string(),
))
} else {
api.model(model_id)
};
repo.get("config.json")?
} else {
path.push("config.json");
path
};
let content = std::fs::read_to_string(filename)?;
let config: RawConfig = serde_json::from_str(&content)?;
let config: Config = config.into();
Ok(config)
}
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
if let Some(config) = config {
if prefix_caching.is_none() {
if config.vision_config.is_some() {
tracing::info!("Disabling prefix caching because of VLM model");
prefix_caching = Some("0".to_string());
} else if config.is_encoder_decoder {
tracing::info!("Disabling prefix caching because of seq2seq model");
prefix_caching = Some("0".to_string());
}
}
match config.head_dim {
Some(h) if h == 64 || h == 128 || h == 256 => {
if lora_adapters.is_some() && prefix_caching.is_none() {
tracing::info!("Disabling prefix caching because of lora adapters");
prefix_caching = Some("0".to_string());
}
match config.model_type.as_deref() {
Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
// Required because gemma2 needs bfloat16 which is not supported by
// flashinfer ?
if attention.is_none() {
tracing::info!(
"Forcing flash decoding because model {} requires it",
config.model_type.as_ref().unwrap()
);
attention = Some("flashdecoding".to_string());
}
}
Some("t5") => {}
_ => {}
}
}
_ => {
if attention.is_none() {
tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
attention = Some("flashdecoding".to_string());
}
if prefix_caching.is_none() {
prefix_caching = Some("0".to_string());
}
}
}
}
let prefix_caching = prefix_caching.unwrap_or("true".to_string());
let attention = attention.unwrap_or("flashinfer".to_string());
(prefix_caching, attention)
}
#[derive(Deserialize)] #[derive(Deserialize)]
struct RawConfig { struct RawConfig {
max_position_embeddings: Option<usize>, max_position_embeddings: Option<usize>,
@ -31,6 +122,12 @@ struct RawConfig {
model_type: Option<String>, model_type: Option<String>,
max_seq_len: Option<usize>, max_seq_len: Option<usize>,
quantization_config: Option<QuantizationConfig>, quantization_config: Option<QuantizationConfig>,
n_embd: Option<usize>,
hidden_size: Option<usize>,
num_attention_heads: Option<usize>,
head_dim: Option<usize>,
vision_config: Option<VisionConfig>,
is_encoder_decoder: Option<bool>,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -38,10 +135,17 @@ struct QuantizationConfig {
quant_method: Option<Quantization>, quant_method: Option<Quantization>,
} }
#[derive(Deserialize)]
struct VisionConfig {}
#[derive(Deserialize)] #[derive(Deserialize)]
struct Config { struct Config {
max_position_embeddings: Option<usize>, max_position_embeddings: Option<usize>,
quantize: Option<Quantization>, quantize: Option<Quantization>,
head_dim: Option<usize>,
model_type: Option<String>,
vision_config: Option<VisionConfig>,
is_encoder_decoder: bool,
} }
impl From<RawConfig> for Config { impl From<RawConfig> for Config {
@ -51,9 +155,32 @@ impl From<RawConfig> for Config {
.or(other.max_seq_len) .or(other.max_seq_len)
.or(other.n_positions); .or(other.n_positions);
let quantize = other.quantization_config.and_then(|q| q.quant_method); let quantize = other.quantization_config.and_then(|q| q.quant_method);
let head_dim = other.head_dim.or_else(|| {
match (other.hidden_size, other.n_embd, other.num_attention_heads) {
(Some(hidden_size), _, Some(num_attention_heads))
if hidden_size % num_attention_heads == 0 =>
{
Some(hidden_size / num_attention_heads)
}
// Legacy
(_, Some(hidden_size), Some(num_attention_heads))
if hidden_size % num_attention_heads == 0 =>
{
Some(hidden_size / num_attention_heads)
}
_ => None,
}
});
let model_type = other.model_type;
let vision_config = other.vision_config;
let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
Config { Config {
max_position_embeddings, max_position_embeddings,
quantize, quantize,
head_dim,
model_type,
vision_config,
is_encoder_decoder,
} }
} }
} }
@ -731,6 +858,7 @@ fn shard_manager(
.args(shard_args) .args(shard_args)
.env_clear() .env_clear()
.envs(envs) .envs(envs)
.stdin(Stdio::piped())
.stdout(Stdio::piped()) .stdout(Stdio::piped())
.stderr(Stdio::piped()) .stderr(Stdio::piped())
.process_group(0) .process_group(0)
@ -752,12 +880,13 @@ fn shard_manager(
}; };
// Redirect STDOUT to the console // Redirect STDOUT to the console
let mut pstdin = p.stdin.take().unwrap();
let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap()); let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap()); let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());
//stdout tracing thread //stdout tracing thread
thread::spawn(move || { thread::spawn(move || {
log_lines(shard_stdout_reader.lines()); log_lines(shard_stdout_reader);
}); });
// We read stderr in another thread as it seems that lines() can block in some cases // We read stderr in another thread as it seems that lines() can block in some cases
let (err_sender, err_receiver) = mpsc::channel(); let (err_sender, err_receiver) = mpsc::channel();
@ -766,6 +895,18 @@ fn shard_manager(
err_sender.send(line).unwrap_or(()); err_sender.send(line).unwrap_or(());
} }
}); });
// We read stdin in another thread as it seems that lines() can block in some cases
thread::spawn(move || {
let mut stdin = io::stdin(); // We get `Stdin` here.
loop {
let mut buffer = vec![0; 4096];
if let Ok(n) = stdin.read(&mut buffer) {
if n > 0 {
let _ = pstdin.write_all(&buffer[..n]);
}
}
}
});
let mut ready = false; let mut ready = false;
let start_time = Instant::now(); let start_time = Instant::now();
@ -872,19 +1013,36 @@ impl PythonLogMessage {
} }
} }
impl TryFrom<&String> for PythonLogMessage { impl TryFrom<&[u8]> for PythonLogMessage {
type Error = serde_json::Error; type Error = serde_json::Error;
fn try_from(value: &String) -> Result<Self, Self::Error> { fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
serde_json::from_str::<Self>(value) serde_json::from_slice::<Self>(value)
} }
} }
fn log_lines<S: Sized + BufRead>(lines: Lines<S>) { fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
for line in lines.map_while(Result::ok) { let mut buffer = vec![0u8; 8 * 4096];
match PythonLogMessage::try_from(&line) { let mut stdout = std::io::stdout();
Ok(log) => log.trace(), loop {
Err(_) => tracing::debug!("{line}"), let n = bufread.read(&mut buffer);
if let Ok(n) = n {
if n > 0 {
let mut lines = buffer[..n].split(|i| *i == b'\n').peekable();
while let Some(line) = lines.next() {
match PythonLogMessage::try_from(line) {
Ok(log) => log.trace(),
// For interactive debugging ?
Err(_) => {
stdout.write_all(line).unwrap();
if lines.peek().is_some() {
stdout.write_all(b"\n").unwrap();
}
stdout.flush().unwrap();
}
}
}
}
} }
} }
} }
@ -1044,7 +1202,7 @@ fn download_convert_model(
let download_stdout = BufReader::new(download_process.stdout.take().unwrap()); let download_stdout = BufReader::new(download_process.stdout.take().unwrap());
thread::spawn(move || { thread::spawn(move || {
log_lines(download_stdout.lines()); log_lines(download_stdout);
}); });
let download_stderr = BufReader::new(download_process.stderr.take().unwrap()); let download_stderr = BufReader::new(download_process.stderr.take().unwrap());
@ -1439,68 +1597,35 @@ fn main() -> Result<(), LauncherError> {
tracing::info!("{:#?}", args); tracing::info!("{:#?}", args);
let get_max_positions_quantize = let config: Option<Config> = get_config(&args.model_id, &args.revision).ok();
|| -> Result<(usize, Option<Quantization>), Box<dyn std::error::Error>> { let quantize = config.as_ref().and_then(|c| c.quantize);
let model_id = args.model_id.clone(); // Quantization usually means you're even more RAM constrained.
let mut path = std::path::Path::new(&args.model_id).to_path_buf(); let max_default = 4096;
let filename = if !path.exists() {
// Assume it's a hub id
let api = if let Ok(token) = std::env::var("HF_TOKEN") { let max_position_embeddings = if let Some(config) = &config {
// env variable has precedence over on file token. if let Some(max_position_embeddings) = config.max_position_embeddings {
ApiBuilder::new().with_token(Some(token)).build()? if max_position_embeddings > max_default {
} else { let max = max_position_embeddings;
Api::new()? if args.max_input_tokens.is_none()
}; && args.max_total_tokens.is_none()
let repo = if let Some(ref revision) = args.revision { && args.max_batch_prefill_tokens.is_none()
api.repo(Repo::with_revision( {
model_id, tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
RepoType::Model,
revision.to_string(),
))
} else {
api.model(model_id)
};
repo.get("config.json")?
} else {
path.push("config.json");
path
};
let content = std::fs::read_to_string(filename)?;
let config: RawConfig = serde_json::from_str(&content)?;
if config.model_type == Some("gemma2".to_string()) {
tracing::info!("Forcing flash decoding because of softcap usage");
std::env::set_var("ATTENTION", "flashdecoding");
}
let config: Config = config.into();
let quantize = config.quantize;
// Quantization usually means you're even more RAM constrained.
let max_default = 4096;
if let Some(max_position_embeddings) = config.max_position_embeddings {
if max_position_embeddings > max_default {
let max = max_position_embeddings;
if args.max_input_tokens.is_none()
&& args.max_total_tokens.is_none()
&& args.max_batch_prefill_tokens.is_none()
{
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
}
Ok((max_default, quantize))
} else {
Ok((max_position_embeddings, quantize))
} }
max_default
} else { } else {
Err(Box::new(LauncherError::ArgumentValidation( max_position_embeddings
"no max defined".to_string(),
)))
} }
}; } else {
let (max_position_embeddings, quantize): (usize, Option<Quantization>) = max_default
get_max_positions_quantize().unwrap_or((4096, None)); }
} else {
max_default
};
let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
std::env::set_var("ATTENTION", attention);
let max_input_tokens = { let max_input_tokens = {
match (args.max_input_tokens, args.max_input_length) { match (args.max_input_tokens, args.max_input_length) {

View File

@ -33,13 +33,13 @@ export function get_options() {
// rate: 20, // rate: 20,
// timeUnit: '1s', // timeUnit: '1s',
// }, // },
load_test: { // load_test: {
executor: 'constant-arrival-rate', // executor: 'constant-arrival-rate',
duration: '60s', // duration: '60s',
preAllocatedVUs: 100, // preAllocatedVUs: 100,
rate: 1, // rate: 1,
timeUnit: '1s', // timeUnit: '1s',
}, // },
// breakpoint: { // breakpoint: {
// executor: 'ramping-arrival-rate', //Assure load increase if the system slows // executor: 'ramping-arrival-rate', //Assure load increase if the system slows
// preAllocatedVUs: 300, // preAllocatedVUs: 300,
@ -47,12 +47,12 @@ export function get_options() {
// { duration: '60s', target: 100 }, // just slowly ramp-up to a HUGE load // { duration: '60s', target: 100 }, // just slowly ramp-up to a HUGE load
// ], // ],
// }, // },
// throughput: { throughput: {
// executor: 'shared-iterations', executor: 'shared-iterations',
// vus: 100, vus: 100,
// iterations: 200, iterations: 200,
// maxDuration: '40s', maxDuration: '40s',
// }, },
}, },
}; };
} }

View File

@ -137,6 +137,8 @@ message Request {
optional string adapter_id = 11; optional string adapter_id = 11;
/// Prefix length that can be retrieved from the KV cache. /// Prefix length that can be retrieved from the KV cache.
uint32 prefix_len = 12; uint32 prefix_len = 12;
/// Context truncation
bool add_special_tokens = 13;
} }
message Batch { message Batch {

View File

@ -120,10 +120,11 @@ impl Infer {
) -> Result<Option<tokenizers::Encoding>, InferError> { ) -> Result<Option<tokenizers::Encoding>, InferError> {
// Tokenize request // Tokenize request
let inputs = request.inputs; let inputs = request.inputs;
let add_special_tokens = request.add_special_tokens;
let truncate = request.parameters.truncate; let truncate = request.parameters.truncate;
let encoding = self let encoding = self
.validation .validation
.tokenize(inputs, truncate) .tokenize(inputs, add_special_tokens, truncate)
.await .await
.map_err(|err| { .map_err(|err| {
tracing::error!("Tokenization {err}"); tracing::error!("Tokenization {err}");

View File

@ -22,6 +22,16 @@ pub enum Attention {
FlashInfer, FlashInfer,
} }
impl Attention {
pub fn block_size(&self) -> u32 {
match self {
Attention::FlashDecoding => 256,
Attention::FlashInfer => 1,
Attention::Paged => 16,
}
}
}
#[derive(Debug)] #[derive(Debug)]
pub struct ParseError; pub struct ParseError;
@ -1072,6 +1082,16 @@ pub(crate) struct GenerateRequest {
pub inputs: String, pub inputs: String,
#[serde(default = "default_parameters")] #[serde(default = "default_parameters")]
pub parameters: GenerateParameters, pub parameters: GenerateParameters,
/// This is used internally because some requests
/// already contain the templated input therefore
/// we shouldn't add the special tokens.
#[serde(default = "default_true", skip)]
pub add_special_tokens: bool,
}
fn default_true() -> bool {
true
} }
#[derive(Clone, Debug, Deserialize, ToSchema)] #[derive(Clone, Debug, Deserialize, ToSchema)]
@ -1089,6 +1109,7 @@ impl From<CompatGenerateRequest> for GenerateRequest {
fn from(req: CompatGenerateRequest) -> Self { fn from(req: CompatGenerateRequest) -> Self {
Self { Self {
inputs: req.inputs, inputs: req.inputs,
add_special_tokens: true,
parameters: req.parameters, parameters: req.parameters,
} }
} }

View File

@ -158,6 +158,7 @@ async fn get_chat_tokenize(
let generate_request = GenerateRequest { let generate_request = GenerateRequest {
inputs, inputs,
add_special_tokens: false,
parameters: GenerateParameters { parameters: GenerateParameters {
best_of: None, best_of: None,
temperature, temperature,
@ -754,6 +755,7 @@ async fn completions(
.iter() .iter()
.map(|prompt| GenerateRequest { .map(|prompt| GenerateRequest {
inputs: prompt.to_string(), inputs: prompt.to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
best_of: None, best_of: None,
temperature, temperature,
@ -1180,6 +1182,7 @@ async fn chat_completions(
// build the request passing some parameters // build the request passing some parameters
let generate_request = GenerateRequest { let generate_request = GenerateRequest {
inputs: inputs.to_string(), inputs: inputs.to_string(),
add_special_tokens: false,
parameters: GenerateParameters { parameters: GenerateParameters {
best_of: None, best_of: None,
temperature, temperature,
@ -1386,6 +1389,7 @@ async fn vertex_compatibility(
.map(|instance| { .map(|instance| {
let generate_request = GenerateRequest { let generate_request = GenerateRequest {
inputs: instance.inputs.clone(), inputs: instance.inputs.clone(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
do_sample: true, do_sample: true,
max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens), max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),

View File

@ -95,6 +95,7 @@ impl Validation {
pub async fn tokenize( pub async fn tokenize(
&self, &self,
inputs: String, inputs: String,
add_special_tokens: bool,
truncate: Option<usize>, truncate: Option<usize>,
) -> Result<Option<(tokenizers::Encoding, Vec<Chunk>)>, ValidationError> { ) -> Result<Option<(tokenizers::Encoding, Vec<Chunk>)>, ValidationError> {
// If we have a fast tokenizer // If we have a fast tokenizer
@ -104,7 +105,11 @@ impl Validation {
// Send request to the background validation task // Send request to the background validation task
// Unwrap is safe here // Unwrap is safe here
sender sender
.send(((inputs, truncate), response_sender, Span::current())) .send((
(inputs, add_special_tokens, truncate),
response_sender,
Span::current(),
))
.unwrap(); .unwrap();
// Await on response channel // Await on response channel
@ -121,11 +126,15 @@ impl Validation {
async fn validate_input( async fn validate_input(
&self, &self,
inputs: String, inputs: String,
add_special_tokens: bool,
truncate: Option<usize>, truncate: Option<usize>,
max_new_tokens: Option<u32>, max_new_tokens: Option<u32>,
) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> { ) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
// If we have a fast tokenizer // If we have a fast tokenizer
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? { if let Some((encoding, inputs)) = self
.tokenize(inputs.clone(), add_special_tokens, truncate)
.await?
{
// Create response channel // Create response channel
let input_length = if let Some(truncate) = truncate { let input_length = if let Some(truncate) = truncate {
std::cmp::min(encoding.len(), truncate) std::cmp::min(encoding.len(), truncate)
@ -158,7 +167,8 @@ impl Validation {
)); ));
} }
let input_ids = encoding.get_ids()[..input_length].to_owned(); let ids = encoding.get_ids();
let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned();
metrics::histogram!("tgi_request_input_length").record(input_length as f64); metrics::histogram!("tgi_request_input_length").record(input_length as f64);
Ok((inputs, Some(input_ids), input_length, max_new_tokens)) Ok((inputs, Some(input_ids), input_length, max_new_tokens))
@ -324,7 +334,12 @@ impl Validation {
// Validate inputs // Validate inputs
let (inputs, input_ids, input_length, max_new_tokens) = self let (inputs, input_ids, input_length, max_new_tokens) = self
.validate_input(request.inputs, truncate, max_new_tokens) .validate_input(
request.inputs,
request.add_special_tokens,
truncate,
max_new_tokens,
)
.await?; .await?;
// TODO: we should build the FSM here and pass the compiled FSM instead of the grammar // TODO: we should build the FSM here and pass the compiled FSM instead of the grammar
@ -401,6 +416,7 @@ impl Validation {
Ok(ValidGenerateRequest { Ok(ValidGenerateRequest {
inputs, inputs,
input_ids: input_ids.map(Arc::new), input_ids: input_ids.map(Arc::new),
add_special_tokens: request.add_special_tokens,
decoder_input_details, decoder_input_details,
input_length: input_length as u32, input_length: input_length as u32,
truncate: truncate.unwrap_or(self.max_input_length) as u32, truncate: truncate.unwrap_or(self.max_input_length) as u32,
@ -449,12 +465,15 @@ fn tokenizer_worker(
mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
) { ) {
// Loop over requests // Loop over requests
while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() { while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
receiver.blocking_recv()
{
parent_span.in_scope(|| { parent_span.in_scope(|| {
response_tx response_tx
.send(prepare_input( .send(prepare_input(
inputs, inputs,
truncate, truncate,
add_special_tokens,
&tokenizer, &tokenizer,
config.as_ref(), config.as_ref(),
preprocessor_config.as_ref(), preprocessor_config.as_ref(),
@ -591,6 +610,7 @@ fn image_tokens_fixup(config: &Config, text: String) -> String {
fn prepare_input( fn prepare_input(
inputs: String, inputs: String,
_truncate: Option<usize>, _truncate: Option<usize>,
add_special_tokens: bool,
tokenizer: &Tokenizer, tokenizer: &Tokenizer,
config: Option<&Config>, config: Option<&Config>,
preprocessor_config: Option<&HubPreprocessorConfig>, preprocessor_config: Option<&HubPreprocessorConfig>,
@ -628,14 +648,14 @@ fn prepare_input(
// Get the number of tokens in the input // Get the number of tokens in the input
let encoding = tokenizer let encoding = tokenizer
.encode(tokenizer_query, true) .encode(tokenizer_query, add_special_tokens)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?; .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
Ok((encoding, input_chunks)) Ok((encoding, input_chunks))
} }
type TokenizerRequest = ( type TokenizerRequest = (
(String, Option<usize>), (String, bool, Option<usize>),
oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>, oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>,
Span, Span,
); );
@ -720,6 +740,7 @@ pub struct ValidGenerateRequest {
pub input_ids: Option<Arc<Vec<u32>>>, pub input_ids: Option<Arc<Vec<u32>>>,
pub input_length: u32, pub input_length: u32,
pub truncate: u32, pub truncate: u32,
pub add_special_tokens: bool,
pub decoder_input_details: bool, pub decoder_input_details: bool,
pub parameters: ValidParameters, pub parameters: ValidParameters,
pub stopping_parameters: ValidStoppingParameters, pub stopping_parameters: ValidStoppingParameters,
@ -826,7 +847,7 @@ mod tests {
let max_new_tokens = 10; let max_new_tokens = 10;
match validation match validation
.validate_input("Hello".to_string(), None, Some(max_new_tokens)) .validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
.await .await
{ {
// Err(ValidationError::MaxNewTokens(1, 10)) => (), // Err(ValidationError::MaxNewTokens(1, 10)) => (),
@ -861,7 +882,7 @@ mod tests {
let max_new_tokens = 10; let max_new_tokens = 10;
match validation match validation
.validate_input("Hello".to_string(), None, Some(max_new_tokens)) .validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
.await .await
{ {
Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (), Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
@ -895,6 +916,7 @@ mod tests {
match validation match validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
best_of: Some(2), best_of: Some(2),
do_sample: false, do_sample: false,
@ -934,6 +956,7 @@ mod tests {
match validation match validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
top_p: Some(1.0), top_p: Some(1.0),
max_new_tokens: Some(5), max_new_tokens: Some(5),
@ -949,6 +972,7 @@ mod tests {
match validation match validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
top_p: Some(0.99), top_p: Some(0.99),
max_new_tokens: Some(5), max_new_tokens: Some(5),
@ -964,6 +988,7 @@ mod tests {
let valid_request = validation let valid_request = validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
top_p: None, top_p: None,
max_new_tokens: Some(5), max_new_tokens: Some(5),
@ -1002,6 +1027,7 @@ mod tests {
match validation match validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
top_n_tokens: Some(5), top_n_tokens: Some(5),
max_new_tokens: Some(5), max_new_tokens: Some(5),
@ -1017,6 +1043,7 @@ mod tests {
validation validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
top_n_tokens: Some(4), top_n_tokens: Some(4),
max_new_tokens: Some(5), max_new_tokens: Some(5),
@ -1029,6 +1056,7 @@ mod tests {
validation validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
top_n_tokens: Some(0), top_n_tokens: Some(0),
max_new_tokens: Some(5), max_new_tokens: Some(5),
@ -1041,6 +1069,7 @@ mod tests {
let valid_request = validation let valid_request = validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
top_n_tokens: None, top_n_tokens: None,
max_new_tokens: Some(5), max_new_tokens: Some(5),
@ -1089,6 +1118,7 @@ mod tests {
let chunks = match validation let chunks = match validation
.tokenize( .tokenize(
format!("test![](data:image/gif;base64,{})", PIXEL_GIF), format!("test![](data:image/gif;base64,{})", PIXEL_GIF),
true,
None, None,
) )
.await .await
@ -1148,6 +1178,7 @@ mod tests {
"test![](data:image/gif;base64,{})![](data:image/gif;base64,{})", "test![](data:image/gif;base64,{})![](data:image/gif;base64,{})",
PIXEL_GIF, PIXEL_GIF PIXEL_GIF, PIXEL_GIF
), ),
true,
None, None,
) )
.await .await

View File

@ -1,5 +1,5 @@
[toolchain] [toolchain]
# Released on: June 13, 2024 # Released on: June 13, 2024
# https://releases.rs/docs/1.79.0/ # https://releases.rs/docs/1.79.0/
channel = "1.79.0" channel = "1.80.0"
components = ["rustfmt", "clippy"] components = ["rustfmt", "clippy"]

View File

@ -7,6 +7,7 @@ include Makefile-selective-scan
include Makefile-lorax-punica include Makefile-lorax-punica
include Makefile-fbgemm include Makefile-fbgemm
include Makefile-exllamav2 include Makefile-exllamav2
include Makefile-flashinfer
unit-tests: unit-tests:
pytest -s -vv -m "not private" tests pytest -s -vv -m "not private" tests

View File

@ -0,0 +1,2 @@
install-flashinfer:
pip install flashinfer==0.1.5 -i https://flashinfer.ai/whl/cu124/torch2.4

View File

@ -1,7 +1,10 @@
import pytest import pytest
import os
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
os.environ["USE_PREFIX_CACHING"] = "1"
os.environ["ATTENTION"] = "flashinfer"
@pytest.fixture @pytest.fixture
def default_pb_parameters(): def default_pb_parameters():

View File

@ -9,26 +9,46 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
@dataclass @dataclass
class Seqlen: class Seqlen:
input_lengths: torch.Tensor input_lengths: torch.Tensor
prefix_lengths: torch.Tensor
cu_seqlen_q: Optional[torch.Tensor] cu_seqlen_q: Optional[torch.Tensor]
cu_seqlen_k: Optional[torch.Tensor] cu_seqlen_k: Optional[torch.Tensor]
max_q: int
max_k: int
def __init__(self, input_lengths): def __init__(
self,
input_lengths,
prefix_lengths,
cu_seqlen_q=None,
max_q=None,
max_k=None,
):
self.input_lengths = input_lengths self.input_lengths = input_lengths
self.prefix_lengths = prefix_lengths
device = self.input_lengths.device device = self.input_lengths.device
shape = self.input_lengths.shape shape = self.input_lengths.shape
cu_seqlen_q = torch.arange( if cu_seqlen_q is None:
shape[0] + 1, cu_seqlen_q = torch.arange(
device=device, shape[0] + 1,
dtype=torch.int32, device=device,
) dtype=torch.int32,
)
max_q = 1
else:
assert max_q is not None
assert max_k is not None
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32) cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
# cuda graphs don't like this and this is necessary to clamp within mistral # cuda graphs don't like this and this is necessary to clamp within mistral
# Although FA2 might not want the clamping # Although FA2 might not want the clamping
# cu_seqlen_k[0] = 0 # cu_seqlen_k[0] = 0
torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:]) total = self.input_lengths + self.prefix_lengths
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
self.cu_seqlen_q = cu_seqlen_q self.cu_seqlen_q = cu_seqlen_q
self.cu_seqlen_k = cu_seqlen_k self.cu_seqlen_k = cu_seqlen_k
self.max_q = max_q
self.max_k = max_k
def clamp(self, max): def clamp(self, max):
# Flash decoding doesn't need to clamp # Flash decoding doesn't need to clamp
@ -39,6 +59,11 @@ else:
@dataclass @dataclass
class Seqlen: class Seqlen:
input_lengths: torch.Tensor input_lengths: torch.Tensor
prefix_lengths: torch.Tensor
cu_seqlen_q: torch.Tensor
max_q: int
max_k: int
def clamp(self, max): def clamp(self, max):
raise NotImplementedError("Not implemented seqlen for paged")
return Seqlen(torch.clamp(self.input_lengths, max=max)) return Seqlen(torch.clamp(self.input_lengths, max=max))

View File

@ -222,18 +222,15 @@ if ATTENTION == "flashinfer":
def attention( def attention(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
key_cache: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, value_cache: torch.Tensor,
cu_seqlens, seqlen: Seqlen,
max_s, block_tables: torch.Tensor,
softmax_scale, softmax_scale,
window_size_left=-1, window_size_left=-1,
causal=True, causal=True,
softcap=0.0, softcap=0.0,
): ):
assert window_size_left == -1, "Windowing is not supported with flash infer"
from text_generation_server.layers.attention.flashinfer import ( from text_generation_server.layers.attention.flashinfer import (
prefill_with_paged_kv_state, prefill_with_paged_kv_state,
) )
@ -244,18 +241,17 @@ if ATTENTION == "flashinfer":
paged_kv_cache=(key_cache, value_cache), paged_kv_cache=(key_cache, value_cache),
logits_soft_cap=softcap, logits_soft_cap=softcap,
sm_scale=softmax_scale, sm_scale=softmax_scale,
window_left=window_size_left,
) )
elif V2: elif V2:
def attention( def attention(
q, q,
k,
v,
key_cache: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, value_cache: torch.Tensor,
cu_seqlens, seqlen: Seqlen,
max_s, block_tables: torch.Tensor,
softmax_scale, softmax_scale,
window_size_left=-1, window_size_left=-1,
causal=True, causal=True,
@ -266,17 +262,17 @@ elif V2:
raise ValueError("`window_size_left` must be > 0 or -1") raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd( return flash_attn_2_cuda.varlen_fwd(
q, q,
k, key_cache,
v, value_cache,
out, out,
cu_seqlens, seqlen.cu_seqlen_q,
cu_seqlens, seqlen.cu_seqlen_k,
None, None,
None, None,
block_tables,
None, None,
None, seqlen.max_q,
max_s, seqlen.max_k,
max_s,
0.0, 0.0,
softmax_scale, softmax_scale,
False, False,

View File

@ -497,15 +497,14 @@ def get_model(
else -1 else -1
) )
should_use_sliding_window = ( use_sliding_window = sliding_window is not None and sliding_window != -1
sliding_window is not None and sliding_window != -1 and SUPPORTS_WINDOWING needs_sliding_window = (
max_input_tokens is not None and max_input_tokens > sliding_window
) )
if use_sliding_window and needs_sliding_window and not SUPPORTS_WINDOWING:
if should_use_sliding_window: raise ValueError(
if max_input_tokens is not None and max_input_tokens > sliding_window: f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
raise ValueError( )
f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
)
if model_type == DEEPSEEK_V2: if model_type == DEEPSEEK_V2:
if FLASH_ATTENTION: if FLASH_ATTENTION:

View File

@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import ( from text_generation_server.layers import (
@ -264,7 +265,7 @@ class FlashCohereAttention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
@ -296,12 +297,10 @@ class FlashCohereAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
key,
value,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
) )
# Decode # Decode
@ -313,7 +312,7 @@ class FlashCohereAttention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
@ -388,7 +387,7 @@ class FlashCohereLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual) normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -402,7 +401,7 @@ class FlashCohereLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -454,7 +453,7 @@ class FlashCohereModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: torch.Tensor,
max_s: int, max_s: int,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
@ -477,7 +476,7 @@ class FlashCohereModel(torch.nn.Module):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -518,7 +517,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
@ -531,7 +530,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
if lm_head_indices is not None: if lm_head_indices is not None:

View File

@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
FastLinear, FastLinear,
@ -309,7 +310,7 @@ class DbrxAttention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
@ -335,12 +336,10 @@ class DbrxAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
) )
# Decode # Decode
@ -352,7 +351,7 @@ class DbrxAttention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
@ -389,7 +388,7 @@ class DbrxNormAttentionNorm(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
normed_hidden_states, res = self.norm_1(hidden_states, residual) normed_hidden_states, res = self.norm_1(hidden_states, residual)
@ -403,7 +402,7 @@ class DbrxNormAttentionNorm(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -622,7 +621,7 @@ class DbrxLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
# Self Attention # Self Attention
@ -635,7 +634,7 @@ class DbrxLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -679,7 +678,7 @@ class DbrxModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
@ -701,7 +700,7 @@ class DbrxModel(torch.nn.Module):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -734,7 +733,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
@ -747,7 +746,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
if lm_head_indices is not None: if lm_head_indices is not None:

View File

@ -29,8 +29,8 @@ from text_generation_server.layers.attention import (
attention, attention,
paged_attention, paged_attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers.attention.common import Seqlen
from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
@ -298,7 +298,7 @@ class DeepseekV2Attention(torch.nn.Module):
kv_cache: Tuple[torch.Tensor, torch.Tensor], kv_cache: Tuple[torch.Tensor, torch.Tensor],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: Seqlen, seqlen: Seqlen,
max_s: int, max_s: int,
): ):
if self.q_lora_rank is None: if self.q_lora_rank is None:
@ -363,12 +363,10 @@ class DeepseekV2Attention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
key,
value,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
) )
# Decode # Decode
@ -380,7 +378,7 @@ class DeepseekV2Attention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
@ -666,7 +664,7 @@ class DeepseekV2Layer(nn.Module):
kv_cache, kv_cache,
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: Seqlen, seqlen: Seqlen,
max_s: int, max_s: int,
): ):
normed_hidden_states, residual = self.input_layernorm(hidden_states, residual) normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@ -680,7 +678,7 @@ class DeepseekV2Layer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -729,7 +727,7 @@ class DeepseekV2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
@ -751,7 +749,7 @@ class DeepseekV2Model(torch.nn.Module):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -781,7 +779,7 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
@ -794,7 +792,7 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
if lm_head_indices is not None: if lm_head_indices is not None:

View File

@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
@ -213,7 +214,7 @@ class FlashGemma2Attention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
@ -236,12 +237,10 @@ class FlashGemma2Attention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
causal=self.causal, causal=self.causal,
window_size_left=self.window_size, window_size_left=self.window_size,
@ -256,7 +255,7 @@ class FlashGemma2Attention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
softcap=self.softcap, softcap=self.softcap,
) )
@ -343,7 +342,7 @@ class FlashGemma2Layer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual) normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -357,7 +356,7 @@ class FlashGemma2Layer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -408,7 +407,7 @@ class FlashGemma2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = inputs_embeds hidden_states = inputs_embeds
@ -430,7 +429,7 @@ class FlashGemma2Model(torch.nn.Module):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -477,7 +476,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
@ -491,7 +490,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
if lm_head_indices is not None: if lm_head_indices is not None:

View File

@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
@ -207,7 +208,7 @@ class FlashGemmaAttention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
@ -230,12 +231,10 @@ class FlashGemmaAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
causal=self.causal, causal=self.causal,
) )
@ -248,7 +247,7 @@ class FlashGemmaAttention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
@ -320,7 +319,7 @@ class FlashGemmaLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual) normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -334,7 +333,7 @@ class FlashGemmaLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -382,7 +381,7 @@ class FlashGemmaModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = inputs_embeds hidden_states = inputs_embeds
@ -404,7 +403,7 @@ class FlashGemmaModel(torch.nn.Module):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -449,7 +448,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
@ -463,7 +462,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
if lm_head_indices is not None: if lm_head_indices is not None:

View File

@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
@ -213,7 +214,7 @@ class FlashGPT2Attention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
query, key, value = self.query_key_value(hidden_states).split( query, key, value = self.query_key_value(hidden_states).split(
@ -230,12 +231,10 @@ class FlashGPT2Attention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
key,
value,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
) )
# Decode # Decode
@ -247,7 +246,7 @@ class FlashGPT2Attention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
@ -316,7 +315,7 @@ class FlashGPT2Layer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
residual = hidden_states residual = hidden_states
@ -329,7 +328,7 @@ class FlashGPT2Layer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -382,7 +381,7 @@ class FlashGPT2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
true_max_s: int, true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
@ -398,7 +397,7 @@ class FlashGPT2Model(torch.nn.Module):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -435,7 +434,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None, prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
@ -451,7 +450,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
true_max_s=max_s, true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices, prefill_cache_indices=prefill_cache_indices,

View File

@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
@ -167,7 +168,7 @@ class FlashGPTJAttention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
query, key, value = self.query_key_value(hidden_states).split( query, key, value = self.query_key_value(hidden_states).split(
@ -192,10 +193,10 @@ class FlashGPTJAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
key, kv_cache[0],
value, kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
) )
# Decode # Decode
@ -207,7 +208,7 @@ class FlashGPTJAttention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
@ -268,7 +269,7 @@ class FlashGPTJLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states, residual = self.input_layernorm(hidden_states, residual)
@ -281,7 +282,7 @@ class FlashGPTJLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -328,7 +329,7 @@ class FlashGPTJModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
@ -351,7 +352,7 @@ class FlashGPTJModel(torch.nn.Module):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -382,7 +383,7 @@ class FlashGPTJForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None, prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
@ -395,7 +396,7 @@ class FlashGPTJForCausalLM(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices=prefill_cache_indices, prefill_cache_indices=prefill_cache_indices,
) )

View File

@ -32,6 +32,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
@ -194,7 +195,7 @@ class FlashLlamaAttention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
adapter_data, adapter_data,
): ):
@ -218,12 +219,10 @@ class FlashLlamaAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
) )
# Decode # Decode
@ -235,7 +234,7 @@ class FlashLlamaAttention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
@ -375,7 +374,7 @@ class FlashLlamaLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
adapter_data, adapter_data,
): ):
@ -390,7 +389,7 @@ class FlashLlamaLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
adapter_data, adapter_data,
) )
@ -479,7 +478,7 @@ class FlashLlamaModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
true_max_s: int, true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
@ -504,7 +503,7 @@ class FlashLlamaModel(torch.nn.Module):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
adapter_data, adapter_data,
) )
@ -548,7 +547,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None, prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
@ -562,7 +561,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
true_max_s=max_s, true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices, prefill_cache_indices=prefill_cache_indices,

View File

@ -31,6 +31,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
@ -185,7 +186,7 @@ class MistralAttention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
adapter_data, adapter_data,
@ -217,12 +218,10 @@ class MistralAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
window_size_left=self.max_past, window_size_left=self.max_past,
) )
@ -235,7 +234,7 @@ class MistralAttention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
@ -356,7 +355,7 @@ class MistralLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
adapter_data, adapter_data,
@ -372,7 +371,7 @@ class MistralLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
adapter_data, adapter_data,
@ -424,7 +423,7 @@ class MistralModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
true_max_s: int, true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
@ -448,7 +447,7 @@ class MistralModel(torch.nn.Module):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
adapter_data, adapter_data,
@ -499,7 +498,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
@ -512,7 +511,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
elif self.max_past is not None: elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values # kernel requires the true values
input_lengths = input_lengths.clamp(max=self.max_past_tensor) seqlen = seqlen.clamp(max=self.max_past_tensor)
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model( hidden_states = self.model(
@ -522,7 +521,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
true_max_s, true_max_s,
prefill_cache_indices, prefill_cache_indices,

View File

@ -35,6 +35,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
FastLinear, FastLinear,
@ -243,7 +244,7 @@ class MixtralAttention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
): ):
@ -274,12 +275,10 @@ class MixtralAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
window_size_left=self.max_past, window_size_left=self.max_past,
) )
@ -292,7 +291,7 @@ class MixtralAttention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
@ -498,7 +497,7 @@ class MixtralLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
): ):
@ -513,7 +512,7 @@ class MixtralLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
) )
@ -568,7 +567,7 @@ class MixtralModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
true_max_s: int, true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
@ -592,7 +591,7 @@ class MixtralModel(torch.nn.Module):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
) )
@ -627,7 +626,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
@ -640,7 +639,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
elif self.max_past is not None: elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values # kernel requires the true values
input_lengths = input_lengths.clamp(max=self.max_past_tensor) seqlen = seqlen.clamp(max=self.max_past_tensor)
hidden_states = self.model( hidden_states = self.model(
input_ids, input_ids,
@ -649,7 +648,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
true_max_s, true_max_s,
prefill_cache_indices, prefill_cache_indices,

View File

@ -31,6 +31,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
@ -147,7 +148,7 @@ class FlashNeoxAttention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
@ -171,12 +172,10 @@ class FlashNeoxAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
qkv[:, 0], qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
) )
# Decode # Decode
@ -188,7 +187,7 @@ class FlashNeoxAttention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
@ -258,7 +257,7 @@ class FlashNeoXLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
if self.use_parallel_residual: if self.use_parallel_residual:
@ -272,7 +271,7 @@ class FlashNeoXLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -296,7 +295,7 @@ class FlashNeoXLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -350,7 +349,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_in(input_ids) hidden_states = self.embed_in(input_ids)
@ -372,7 +371,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -404,7 +403,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
@ -417,7 +416,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
if lm_head_indices is not None: if lm_head_indices is not None:

View File

@ -19,6 +19,7 @@ from torch import nn
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.custom_modeling.vlm import ( from text_generation_server.models.custom_modeling.vlm import (
load_text_model, load_text_model,
load_vision_model, load_vision_model,
@ -70,7 +71,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None, prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
@ -107,7 +108,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
kv_cache=kv_cache, kv_cache=kv_cache,
block_tables=block_tables, block_tables=block_tables,
slots=slots, slots=slots,
input_lengths=input_lengths, seqlen=seqlen,
max_s=max_s, max_s=max_s,
) )

View File

@ -10,6 +10,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
@ -159,7 +160,7 @@ class FlashPhiAttention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
# Compute query, key, value and split # Compute query, key, value and split
@ -192,12 +193,10 @@ class FlashPhiAttention(torch.nn.Module):
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
attn_output = attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
) )
# Decode # Decode
@ -209,7 +208,7 @@ class FlashPhiAttention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
@ -276,7 +275,7 @@ class FlashPhiLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
hidden_states, res = self.input_layernorm(hidden_states, residual) hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -289,7 +288,7 @@ class FlashPhiLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -341,7 +340,7 @@ class FlashPhiModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
@ -363,7 +362,7 @@ class FlashPhiModel(torch.nn.Module):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -396,7 +395,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
@ -409,7 +408,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
if lm_head_indices is not None: if lm_head_indices is not None:

View File

@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
@ -104,7 +105,7 @@ class Qwen2Attention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
): ):
@ -135,12 +136,10 @@ class Qwen2Attention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
window_size_left=self.max_past, window_size_left=self.max_past,
) )
@ -153,7 +152,7 @@ class Qwen2Attention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
@ -225,7 +224,7 @@ class Qwen2Layer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
): ):
@ -240,7 +239,7 @@ class Qwen2Layer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
) )
@ -296,7 +295,7 @@ class Qwen2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
true_max_s: int, true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
@ -320,7 +319,7 @@ class Qwen2Model(torch.nn.Module):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
) )
@ -361,7 +360,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None, prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
@ -374,7 +373,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
elif self.max_past is not None: elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values # kernel requires the true values
input_lengths = input_lengths.clamp(max=self.max_past_tensor) seqlen = seqlen.clamp(max=self.max_past_tensor)
hidden_states = self.model( hidden_states = self.model(
input_ids, input_ids,
@ -383,7 +382,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
true_max_s, true_max_s,
prefill_cache_indices, prefill_cache_indices,

View File

@ -19,6 +19,7 @@ from text_generation_server.layers.attention import (
attention, attention,
paged_attention, paged_attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
@ -181,7 +182,7 @@ class FlashRWAttention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
@ -206,12 +207,10 @@ class FlashRWAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
) )
# Decode # Decode
@ -223,7 +222,7 @@ class FlashRWAttention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
@ -296,7 +295,7 @@ class FlashRWLargeAttention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
@ -343,7 +342,7 @@ class FlashRWLargeAttention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
@ -429,7 +428,7 @@ class FlashRWLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
if self.parallel_attn: if self.parallel_attn:
@ -443,7 +442,7 @@ class FlashRWLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -465,7 +464,7 @@ class FlashRWLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -552,7 +551,7 @@ class FlashRWLargeLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
# Layer norm. # Layer norm.
@ -567,7 +566,7 @@ class FlashRWLargeLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -628,7 +627,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids) hidden_states = self.word_embeddings(input_ids)
@ -650,7 +649,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -680,7 +679,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
@ -693,7 +692,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
if lm_head_indices is not None: if lm_head_indices is not None:

View File

@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
@ -268,7 +269,7 @@ class FlashMQAttention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
qkv = self.c_attn(hidden_states) qkv = self.c_attn(hidden_states)
@ -291,12 +292,10 @@ class FlashMQAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
) )
# Decode # Decode
@ -308,7 +307,7 @@ class FlashMQAttention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
@ -373,7 +372,7 @@ class Block(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
hidden_states, residual = self.ln_1(hidden_states, residual) hidden_states, residual = self.ln_1(hidden_states, residual)
@ -383,7 +382,7 @@ class Block(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -437,7 +436,7 @@ class FlashSantacoderModel(nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.wte(input_ids) + self.wpe(position_ids) hidden_states = self.wte(input_ids) + self.wpe(position_ids)
@ -454,7 +453,7 @@ class FlashSantacoderModel(nn.Module):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
@ -486,7 +485,7 @@ class FlashSantacoderForCausalLM(nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
@ -499,7 +498,7 @@ class FlashSantacoderForCausalLM(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
if lm_head_indices is not None: if lm_head_indices is not None:

View File

@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
@ -209,7 +210,7 @@ class Starcoder2Attention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
): ):
@ -240,12 +241,10 @@ class Starcoder2Attention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
window_size_left=self.max_past, window_size_left=self.max_past,
) )
@ -258,7 +257,7 @@ class Starcoder2Attention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
@ -381,7 +380,7 @@ class Starcoder2Layer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
): ):
@ -396,7 +395,7 @@ class Starcoder2Layer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
) )
@ -449,7 +448,7 @@ class Starcoder2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
true_max_s: int, true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
@ -473,7 +472,7 @@ class Starcoder2Model(torch.nn.Module):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
) )
@ -521,7 +520,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
@ -534,7 +533,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
elif self.max_past is not None: elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values # kernel requires the true values
input_lengths = input_lengths.clamp(max=self.max_past_tensor) seqlen = seqlen.clamp(max=self.max_past_tensor)
hidden_states = self.model( hidden_states = self.model(
input_ids, input_ids,
@ -543,7 +542,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
true_max_s, true_max_s,
prefill_cache_indices, prefill_cache_indices,

View File

@ -25,6 +25,7 @@ from transformers.activations import ACT2FN
from text_generation_server.models.custom_modeling.vlm import ( from text_generation_server.models.custom_modeling.vlm import (
load_text_model, load_text_model,
) )
from text_generation_server.layers.attention import Seqlen
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from text_generation_server.layers import ( from text_generation_server.layers import (
@ -740,7 +741,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
@ -826,7 +827,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
kv_cache=kv_cache, kv_cache=kv_cache,
block_tables=block_tables, block_tables=block_tables,
slots=slots, slots=slots,
input_lengths=input_lengths, seqlen=seqlen,
max_s=max_s, max_s=max_s,
true_max_s=max_s, true_max_s=max_s,
prefill_cache_indices=None, prefill_cache_indices=None,

View File

@ -23,6 +23,7 @@ from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution from transformers.image_processing_utils import select_best_resolution
from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.custom_modeling.vlm import ( from text_generation_server.models.custom_modeling.vlm import (
load_text_model, load_text_model,
load_vision_model, load_vision_model,
@ -170,7 +171,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
@ -276,7 +277,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
kv_cache=kv_cache, kv_cache=kv_cache,
block_tables=block_tables, block_tables=block_tables,
slots=slots, slots=slots,
input_lengths=input_lengths, seqlen=seqlen,
max_s=max_s, max_s=max_s,
true_max_s=max_s, true_max_s=max_s,
prefill_cache_indices=None, prefill_cache_indices=None,

View File

@ -43,7 +43,7 @@ from text_generation_server.models.globals import (
ATTENTION, ATTENTION,
BLOCK_SIZE, BLOCK_SIZE,
CUDA_GRAPHS, CUDA_GRAPHS,
PREFIX_CACHING, TGI_WIGGLE_ROOM,
get_adapter_to_index, get_adapter_to_index,
) )
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
@ -189,16 +189,21 @@ class FlashCausalLMBatch(Batch):
def batch_tokenized_inputs( def batch_tokenized_inputs(
cls, requests: Iterable[generate_pb2.Request], tokenizer cls, requests: Iterable[generate_pb2.Request], tokenizer
): ):
batch_inputs = [] max_length = 0
max_truncation = 0 all_input_ids = []
batch_size = 0
for r in requests: for r in requests:
batch_inputs.append(concat_text_chunks(r.input_chunks.chunks)) batch_size += 1
max_truncation = max(max_truncation, r.truncate) inputs = concat_text_chunks(r.input_chunks.chunks)
input_ids = tokenizer(
batch_tokenized_inputs = tokenizer( inputs,
batch_inputs, truncation=True, max_length=max_truncation truncation=True,
)["input_ids"] max_length=r.truncate,
return batch_tokenized_inputs add_special_tokens=r.add_special_tokens,
)["input_ids"]
max_length = max(max_length, len(input_ids))
all_input_ids.append(input_ids)
return all_input_ids
@classmethod @classmethod
def from_tokenized( def from_tokenized(
@ -257,22 +262,15 @@ class FlashCausalLMBatch(Batch):
# request id -> idx in list mapping # request id -> idx in list mapping
requests_idx_mapping[r.id] = i requests_idx_mapping[r.id] = i
tokenized_input = tokenized_input[-r.truncate :]
if (
tokenized_input[0] == tokenizer.bos_token_id
and tokenized_input[1] == tokenizer.bos_token_id
):
tokenized_input = tokenized_input[1:]
orig_input_length = len(tokenized_input) orig_input_length = len(tokenized_input)
if PREFIX_CACHING: prefix_len = r.prefix_len
prefix_len = r.prefix_len assert (
if prefix_len == orig_input_length: prefix_len <= orig_input_length
assert prefix_len > 0 ), f"Prefix {prefix_len} vs input {orig_input_length}"
prefix_len -= 1 if prefix_len == orig_input_length:
else: assert prefix_len > 0
prefix_len = 0 prefix_len -= 1
prefix_ids.append(tokenized_input[:prefix_len]) prefix_ids.append(tokenized_input[:prefix_len])
tokenized_input = tokenized_input[prefix_len:] tokenized_input = tokenized_input[prefix_len:]
@ -998,7 +996,7 @@ class FlashCausalLM(Model):
config.sliding_window = None config.sliding_window = None
self.num_layers = config.num_hidden_layers self.num_layers = config.num_hidden_layers
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads // self.process_group.size()
# Validation is done in the model itself # Validation is done in the model itself
if num_kv_heads is None: if num_kv_heads is None:
num_kv_heads = getattr(config, "num_key_value_heads", None) num_kv_heads = getattr(config, "num_key_value_heads", None)
@ -1160,8 +1158,15 @@ class FlashCausalLM(Model):
"block_tables": block_tables, "block_tables": block_tables,
"slots": slots, "slots": slots,
"input_lengths": input_lengths_tensor, "input_lengths": input_lengths_tensor,
"prefix_lengths": prefix_lengths_tensor,
} }
input_lengths_ = Seqlen(input_lengths=input_lengths_tensor) seqlen = Seqlen(
input_lengths=input_lengths_tensor,
prefix_lengths=prefix_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs]["graph"] = graph self.cuda_graphs[bs]["graph"] = graph
@ -1204,7 +1209,7 @@ class FlashCausalLM(Model):
kv_cache=self.kv_cache, kv_cache=self.kv_cache,
block_tables=block_tables, block_tables=block_tables,
slots=slots, slots=slots,
input_lengths=input_lengths_, seqlen=seqlen,
max_s=max_s, max_s=max_s,
prefill_cache_indices=None, prefill_cache_indices=None,
lm_head_indices=None, lm_head_indices=None,
@ -1213,7 +1218,13 @@ class FlashCausalLM(Model):
torch.cuda.synchronize() torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL): with torch.cuda.graph(graph, pool=MEM_POOL):
input_lengths_tensor = Seqlen(input_lengths=input_lengths_tensor) seqlen = Seqlen(
input_lengths=input_lengths_tensor,
prefix_lengths=prefix_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
@ -1221,7 +1232,7 @@ class FlashCausalLM(Model):
kv_cache=self.kv_cache, kv_cache=self.kv_cache,
block_tables=block_tables, block_tables=block_tables,
slots=slots, slots=slots,
input_lengths=input_lengths_tensor, seqlen=seqlen,
max_s=max_s, max_s=max_s,
prefill_cache_indices=None, prefill_cache_indices=None,
lm_head_indices=None, lm_head_indices=None,
@ -1268,7 +1279,7 @@ class FlashCausalLM(Model):
num_blocks = ( num_blocks = (
# Leave 5% for some wiggle room # Leave 5% for some wiggle room
int((free_memory * 0.95) // total_cache_size) int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size)
# Add batch.num_blocks as we allocated it above, so it is included in the peak memory. # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
+ batch_num_blocks + batch_num_blocks
) )
@ -1360,18 +1371,26 @@ class FlashCausalLM(Model):
# Dummy value, some models (starcoder2) don't accept `None`. # Dummy value, some models (starcoder2) don't accept `None`.
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
input_lengths = Seqlen(input_lengths=input_lengths) prefix_lens_tensor = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
cu_seqlen_prefill = torch.tensor(
[0, seqlen], device=self.device, dtype=torch.int32
)
seqlen = Seqlen(
input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=1,
max_k=seqlen,
)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward( self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=torch.tensor( cu_seqlen_prefill=cu_seqlen_prefill,
[0, seqlen], device=self.device, dtype=torch.int32
),
kv_cache=self.kv_cache, kv_cache=self.kv_cache,
block_tables=None, block_tables=None,
input_lengths=input_lengths, seqlen=seqlen,
slots=slots, slots=slots,
max_s=seqlen, max_s=seqlen,
lm_head_indices=None, lm_head_indices=None,
@ -1451,8 +1470,7 @@ class FlashCausalLM(Model):
cuda_graph = None cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None: if cu_seqlen_prefill is not None or cuda_graph is None:
input_lengths = input_lengths + prefix_lens_tensor if ATTENTION == "flashinfer":
if PREFIX_CACHING:
block_tables = block_tables_to_ragged( block_tables = block_tables_to_ragged(
block_tables=block_tables, block_tables=block_tables,
input_lengths=batch.input_lengths, input_lengths=batch.input_lengths,
@ -1462,11 +1480,18 @@ class FlashCausalLM(Model):
block_tables=block_tables, block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths=batch.input_lengths, input_lengths=batch.input_lengths,
input_lengths_tensor=input_lengths, input_lengths_tensor=input_lengths + prefix_lens_tensor,
prefix_lens=batch.prefix_lens, prefix_lens=batch.prefix_lens,
prefix_lens_tensor=prefix_lens_tensor, prefix_lens_tensor=prefix_lens_tensor,
): ):
input_lengths = Seqlen(input_lengths=input_lengths) max_k = (input_lengths + prefix_lens_tensor).max().item()
seqlen = Seqlen(
input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=max_s,
max_k=max_k,
)
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
@ -1474,7 +1499,7 @@ class FlashCausalLM(Model):
kv_cache=kv_cache, kv_cache=kv_cache,
block_tables=block_tables, block_tables=block_tables,
slots=slots, slots=slots,
input_lengths=input_lengths, seqlen=seqlen,
max_s=max_s, max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices, prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices, lm_head_indices=lm_head_indices,

View File

@ -5,19 +5,22 @@ from typing import Dict, Optional
from text_generation_server.utils.log import log_master from text_generation_server.utils.log import log_master
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "0").lower() in {"1", "true"} PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged") ATTENTION = os.getenv("ATTENTION")
_expected = {"paged", "flashdecoding", "flashinfer"} _expected = {"paged", "flashdecoding", "flashinfer"}
assert ( assert (
ATTENTION in _expected ATTENTION in _expected
), f"Attention is not valid {ATTENTION}, expected {_expected}" ), f"Attention is not valid {ATTENTION}, expected {_expected}"
log_master(logger.info, f"Using Attention = {ATTENTION}") log_master(logger.info, f"Using Attention = {ATTENTION}")
if PREFIX_CACHING and ATTENTION != "flashinfer": if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
raise RuntimeError("Prefix caching is only supported with flashinfer") raise RuntimeError("Prefix caching is only supported with flashinfer")
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95"))
assert TGI_WIGGLE_ROOM > 0
assert TGI_WIGGLE_ROOM < 1
# This is overridden by the cli # This is overridden by the cli
BLOCK_SIZE: int BLOCK_SIZE: int

View File

@ -372,7 +372,14 @@ class VlmCausalLM(FlashCausalLM):
prefix_lens=batch.prefix_lens, prefix_lens=batch.prefix_lens,
prefix_lens_tensor=prefix_lens_tensor, prefix_lens_tensor=prefix_lens_tensor,
): ):
input_lengths = Seqlen(input_lengths=input_lengths) max_k = (input_lengths + prefix_lens_tensor).max().item()
seqlen = Seqlen(
input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=max_s,
max_k=max_k,
)
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
@ -380,7 +387,7 @@ class VlmCausalLM(FlashCausalLM):
kv_cache=kv_cache, kv_cache=kv_cache,
block_tables=block_tables, block_tables=block_tables,
slots=slots, slots=slots,
input_lengths=input_lengths, seqlen=seqlen,
max_s=max_s, max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices, prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices, lm_head_indices=lm_head_indices,