feat: prefill chunking (#2600)

* wip

* rollback

* refactor to use prefix/postfix namming + fix all_input_ids_tensor

* maybe patching vlms?

* fix filter and concat

* wip, no filter, no concat

* current

* add prepare_for_prefill

* working

* load tested

* re-create slots

* re-create slots

* fix slot_filtering_indices

* feedback loop

* remove log

* fix benchmarker

* fix vlm and seq2seq

* rename to cache and input lengths

* fix prefill logprobs

* fix launcher

* fix logprobs?

* idk at this point

* max input length

* omfg

* remove debugging lines

* fix tests

* fix mllama

* fix cargo tests

* remove support chunking for paged

* Fixing non blocked attentions

* Fixing dtype + AMD, Ipex targets.

* lint fix.

* rename

* Fix prefix_caching variable, remove defaults in server (confusing a lot
of the times).

* Add simple resolution when user specifies ATTENTION=paged.

* Put back non default simple tests.

* Fix env name

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
OlivierDehaene 2024-10-16 12:49:33 +02:00 committed by GitHub
parent 704a58c807
commit a6a0c97ed9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
46 changed files with 1694 additions and 1123 deletions

View File

@ -327,7 +327,8 @@ ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0 ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
ENV VLLM_MOE_PADDING=0 ENV VLLM_MOE_PADDING=0
ENV ATTENTION=paged ENV ATTENTION=paged
ENV USE_PREFIX_CACHING=0 ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0
ENV ROCM_USE_SKINNY_GEMM=1 ENV ROCM_USE_SKINNY_GEMM=1
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh

View File

@ -218,7 +218,8 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/lo
FROM ${PLATFORM} AS final FROM ${PLATFORM} AS final
ENV ATTENTION=paged ENV ATTENTION=paged
ENV USE_PREFIX_CACHING=0 ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0
ENV CUDA_GRAPHS=0 ENV CUDA_GRAPHS=0
ENTRYPOINT ["text-generation-launcher"] ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"] CMD ["--json-output"]

View File

@ -158,7 +158,8 @@ impl Client {
// Blocks and slots will be set on the server side if we use paged attention // Blocks and slots will be set on the server side if we use paged attention
blocks: vec![], blocks: vec![],
slots: vec![], slots: vec![],
prefix_len: 0, cache_len: 0,
chunk_len: None,
// Set sampling parameters to also take these ops into account in the max memory // Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters { parameters: Some(NextTokenChooserParameters {
temperature: 0.9, temperature: 0.9,
@ -217,8 +218,13 @@ impl Client {
pub async fn prefill( pub async fn prefill(
&mut self, &mut self,
batch: Batch, batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> { ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context(); let request = tonic::Request::new(PrefillRequest {
batch: Some(batch),
cached_batch,
})
.inject_context();
let response = self.stub.prefill(request).await?.into_inner(); let response = self.stub.prefill(request).await?.into_inner();
Ok(( Ok((
response.generations, response.generations,

View File

@ -134,11 +134,12 @@ impl ShardedClient {
pub async fn prefill( pub async fn prefill(
&mut self, &mut self,
batch: Batch, batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> { ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let futures: Vec<_> = self let futures: Vec<_> = self
.clients .clients
.iter_mut() .iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone()))) .map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
.collect(); .collect();
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> = let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
@ -245,7 +246,8 @@ impl Health for ShardedClient {
// Block 0 is reserved for health checks // Block 0 is reserved for health checks
blocks: vec![0], blocks: vec![0],
slots: (0..16).collect(), slots: (0..16).collect(),
prefix_len: 0, cache_len: 0,
chunk_len: None,
adapter_id: None, adapter_id: None,
}; };
let batch = Batch { let batch = Batch {
@ -255,7 +257,7 @@ impl Health for ShardedClient {
max_tokens: 2, max_tokens: 2,
max_blocks: 1, max_blocks: 1,
}; };
self.clone().prefill(batch).await?; self.clone().prefill(batch, None).await?;
Ok(()) Ok(())
} }
} }

View File

@ -6,7 +6,7 @@ use nohash_hasher::IntMap;
use std::sync::Arc; use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest; use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{Attention, FinishReason, PrefillToken, Token}; use text_generation_router::{FinishReason, PrefillToken, Token};
use tokio::sync::mpsc::error::SendError; use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify}; use tokio::sync::{mpsc, Notify};
use tokio::time::Instant; use tokio::time::Instant;
@ -36,18 +36,14 @@ impl BackendV2 {
speculate: u32, speculate: u32,
) -> Self { ) -> Self {
// Infer shared state // Infer shared state
let attention = if let Ok(attention) = std::env::var("ATTENTION") { let attention = std::env::var("ATTENTION").unwrap_or("paged".to_string());
attention let block_size = match attention.as_str() {
.parse() "flashinfer" => 1,
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")) "flashdecoding" => 256,
} else { "paged" => 16,
Attention::Paged _ => unreachable!(),
};
let block_size = if attention == Attention::FlashDecoding {
256
} else {
16
}; };
let queue = Queue::new(requires_padding, block_size, window_size, speculate); let queue = Queue::new(requires_padding, block_size, window_size, speculate);
let batching_task_notifier = Arc::new(Notify::new()); let batching_task_notifier = Arc::new(Notify::new());

View File

@ -1,12 +1,14 @@
use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
/// Batching and inference logic /// Batching and inference logic
use crate::client::{
Batch, CachedBatch, ClientError, Generation, Health, InfoResponse, ShardedClient,
};
use crate::queue::{Entry, Queue}; use crate::queue::{Entry, Queue};
use async_trait::async_trait; use async_trait::async_trait;
use nohash_hasher::IntMap; use nohash_hasher::IntMap;
use std::sync::Arc; use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest; use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{Attention, FinishReason, PrefillToken, Token}; use text_generation_router::{FinishReason, PrefillToken, Token};
use tokio::sync::mpsc::error::SendError; use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify}; use tokio::sync::{mpsc, Notify};
use tokio::time::Instant; use tokio::time::Instant;
@ -31,27 +33,22 @@ impl BackendV3 {
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
max_waiting_tokens: usize, max_waiting_tokens: usize,
max_batch_size: Option<usize>, max_batch_size: Option<usize>,
requires_padding: bool, shard_info: InfoResponse,
window_size: Option<u32>,
speculate: u32,
) -> Self { ) -> Self {
let prefix_caching = if shard_info.support_chunking {
std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var"); tracing::warn!("Model supports prefill chunking. `waiting_served_ratio` and `max_waiting_tokens` will be ignored.");
let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1"); }
let attention: String = std::env::var("ATTENTION").expect("attention env var");
let attention: Attention = attention let block_size = shard_info.block_size;
.parse()
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
let block_size = attention.block_size();
let queue = Queue::new( let queue = Queue::new(
requires_padding, shard_info.requires_padding,
block_size, block_size,
prefix_caching, shard_info.use_prefix_caching,
window_size, shard_info.window_size,
speculate, shard_info.speculate,
max_batch_total_tokens, max_batch_total_tokens,
shard_info.support_chunking,
); );
let batching_task_notifier = Arc::new(Notify::new()); let batching_task_notifier = Arc::new(Notify::new());
@ -63,6 +60,7 @@ impl BackendV3 {
max_batch_total_tokens, max_batch_total_tokens,
max_waiting_tokens, max_waiting_tokens,
max_batch_size, max_batch_size,
shard_info.support_chunking,
queue.clone(), queue.clone(),
batching_task_notifier.clone(), batching_task_notifier.clone(),
)); ));
@ -127,6 +125,7 @@ pub(crate) async fn batching_task(
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
max_waiting_tokens: usize, max_waiting_tokens: usize,
max_batch_size: Option<usize>, max_batch_size: Option<usize>,
support_chunking: bool,
queue: Queue, queue: Queue,
notifier: Arc<Notify>, notifier: Arc<Notify>,
) { ) {
@ -147,7 +146,7 @@ pub(crate) async fn batching_task(
) )
.await .await
{ {
let mut cached_batch = prefill(&mut client, batch, &mut entries) let mut cached_batch = prefill(&mut client, batch, None, &mut entries)
.instrument(span) .instrument(span)
.await; .await;
let mut waiting_tokens = 1; let mut waiting_tokens = 1;
@ -158,10 +157,24 @@ pub(crate) async fn batching_task(
// Get current batch info // Get current batch info
let batch_size = batch.size; let batch_size = batch.size;
let batch_max_tokens = batch.max_tokens; let batch_max_tokens = batch.max_tokens;
let current_tokens = batch.current_tokens;
let mut batches = vec![batch]; let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64); metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64); metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
let (min_size, max_size, prefill_token_budget) = if support_chunking {
// Since the next batch will be concatenated with the current batch,
// the current batch tokens must be subtracted to the prefill budget
let prefill_token_budget =
max_batch_prefill_tokens.saturating_sub(current_tokens);
// We can ignore min_size and max_size
// Models than rely on max_size cannot support chunking
// Regarding min_size, chunking allow us to consistently run at the compute
// bound, making min_size useless.
(None, None, prefill_token_budget)
} else {
let min_size = if waiting_tokens >= max_waiting_tokens { let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try // If we didn't onboard any new requests since >= max_waiting_tokens, we try
// to add a new batch even though its size might be small // to add a new batch even though its size might be small
@ -173,24 +186,34 @@ pub(crate) async fn batching_task(
Some((batch_size as f32 * waiting_served_ratio).floor() as usize) Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
}; };
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
let max_size = let max_size =
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize)); max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
(min_size, max_size, max_batch_prefill_tokens)
};
// Try to get a new batch // Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue if let Some((new_entries, new_batch, span)) = queue
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget) .next_batch(min_size, max_size, prefill_token_budget, token_budget)
.await .await
{ {
// Tracking metrics // Tracking metrics
if min_size.is_some() { if min_size.is_some() {
metrics::counter!("tgi_batch_concat", "reason" => "backpressure") metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
.increment(1); .increment(1);
} else {
let counter = if support_chunking {
metrics::counter!("tgi_batch_concat", "reason" => "chunking")
} else { } else {
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
.increment(1); };
counter.increment(1);
} }
let cached_batch = if support_chunking {
// Concat current batch to the new one
batches.pop()
} else {
// Request are waiting only if we don't support chunking
entries.iter_mut().for_each(|(_, entry)| { entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to add the info that this entry is waiting // Create a new span to add the info that this entry is waiting
// because a new batch is being computed // because a new batch is being computed
@ -201,17 +224,23 @@ pub(crate) async fn batching_task(
// Update entry // Update entry
entry.temp_span = Some(entry_waiting_span); entry.temp_span = Some(entry_waiting_span);
}); });
None
};
entries.extend(new_entries);
// Generate one token for this new batch to have the attention past in cache // Generate one token for this new batch to have the attention past in cache
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries) let new_cached_batch =
prefill(&mut client, new_batch, cached_batch, &mut entries)
.instrument(span) .instrument(span)
.await; .await;
// Reset waiting counter // Reset waiting counter
waiting_tokens = 1; waiting_tokens = 1;
// Extend current batch with the new batch // Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch { if let Some(new_cached_batch) = new_cached_batch {
entries.extend(new_entries);
batches.push(new_cached_batch); batches.push(new_cached_batch);
} else if support_chunking {
// New cached batch is empty, no work left
break;
} }
} }
@ -244,13 +273,14 @@ pub(crate) async fn batching_task(
async fn prefill( async fn prefill(
client: &mut ShardedClient, client: &mut ShardedClient,
batch: Batch, batch: Batch,
cached_batch: Option<CachedBatch>,
entries: &mut IntMap<u64, Entry>, entries: &mut IntMap<u64, Entry>,
) -> Option<CachedBatch> { ) -> Option<CachedBatch> {
let start_time = Instant::now(); let start_time = Instant::now();
let batch_id = batch.id; let batch_id = batch.id;
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1); metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
match client.prefill(batch).await { match client.prefill(batch, cached_batch).await {
Ok((generations, next_batch, timings)) => { Ok((generations, next_batch, timings)) => {
let start_filtering_time = Instant::now(); let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries // Send generated tokens and filter stopped entries
@ -259,6 +289,10 @@ async fn prefill(
// Filter next batch and remove requests that were stopped // Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await; let next_batch = filter_batch(client, next_batch, entries).await;
if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
.record(concat_duration.as_secs_f64());
}
metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill") metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
.record(timings.forward.as_secs_f64()); .record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill") metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")

View File

@ -158,7 +158,8 @@ impl Client {
// Blocks and slots will be set on the server side if we use paged attention // Blocks and slots will be set on the server side if we use paged attention
blocks: vec![], blocks: vec![],
slots: vec![], slots: vec![],
prefix_len: 0, cache_len: 0,
chunk_len: None,
// Set sampling parameters to also take these ops into account in the max memory // Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters { parameters: Some(NextTokenChooserParameters {
temperature: 0.9, temperature: 0.9,
@ -217,13 +218,23 @@ impl Client {
pub async fn prefill( pub async fn prefill(
&mut self, &mut self,
batch: Batch, batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> { ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context(); let request = tonic::Request::new(PrefillRequest {
batch: Some(batch),
cached_batch,
})
.inject_context();
let response = self.stub.prefill(request).await?.into_inner(); let response = self.stub.prefill(request).await?.into_inner();
Ok(( Ok((
response.generations, response.generations,
response.batch, response.batch,
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns), PrefillTimings::new(
response.concat_ns,
response.forward_ns,
response.decode_ns,
response.total_ns,
),
)) ))
} }
@ -252,14 +263,16 @@ impl Client {
} }
pub struct PrefillTimings { pub struct PrefillTimings {
pub concat: Option<Duration>,
pub forward: Duration, pub forward: Duration,
pub decode: Duration, pub decode: Duration,
pub total: Duration, pub total: Duration,
} }
impl PrefillTimings { impl PrefillTimings {
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
Self { Self {
concat: concat_ns.map(Duration::from_nanos),
forward: Duration::from_nanos(forward_ns), forward: Duration::from_nanos(forward_ns),
decode: Duration::from_nanos(decode_ns), decode: Duration::from_nanos(decode_ns),
total: Duration::from_nanos(total_ns), total: Duration::from_nanos(total_ns),

View File

@ -29,15 +29,6 @@ pub trait Health {
async fn model_health(&self) -> Result<()>; async fn model_health(&self) -> Result<()>;
} }
#[derive(Debug)]
pub struct ShardInfo {
pub requires_padding: bool,
pub dtype: String,
pub device_type: String,
pub window_size: Option<u32>,
pub speculate: u32,
}
#[derive(Error, Debug, Clone)] #[derive(Error, Debug, Clone)]
pub enum ClientError { pub enum ClientError {
#[error("Could not connect to Text Generation server: {0}")] #[error("Could not connect to Text Generation server: {0}")]

View File

@ -1,6 +1,6 @@
use crate::client::{ClientError, Result}; use crate::client::Health;
/// Multi shard Client /// Multi shard Client
use crate::client::{Health, ShardInfo}; use crate::client::{ClientError, Result};
use crate::client::grpc_client::{DecodeTimings, PrefillTimings}; use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
use crate::client::{ use crate::client::{
@ -49,13 +49,13 @@ impl ShardedClient {
/// Get the model info /// Get the model info
#[instrument(skip(self))] #[instrument(skip(self))]
pub async fn info(&mut self) -> Result<ShardInfo> { pub async fn info(&mut self) -> Result<InfoResponse> {
let futures: Vec<_> = self let futures: Vec<_> = self
.clients .clients
.iter_mut() .iter_mut()
.map(|client| client.info()) .map(|client| client.info())
.collect(); .collect();
join_all(futures).await.pop().unwrap().map(ShardInfo::from) join_all(futures).await.pop().unwrap()
} }
/// GRPC health check /// GRPC health check
@ -135,11 +135,12 @@ impl ShardedClient {
pub async fn prefill( pub async fn prefill(
&mut self, &mut self,
batch: Batch, batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> { ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let futures: Vec<_> = self let futures: Vec<_> = self
.clients .clients
.iter_mut() .iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone()))) .map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
.collect(); .collect();
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> = let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
@ -194,18 +195,6 @@ impl ShardedClient {
} }
} }
impl From<InfoResponse> for ShardInfo {
fn from(value: InfoResponse) -> Self {
Self {
requires_padding: value.requires_padding,
dtype: value.dtype,
device_type: value.device_type,
window_size: value.window_size,
speculate: value.speculate,
}
}
}
#[async_trait] #[async_trait]
impl Health for ShardedClient { impl Health for ShardedClient {
async fn device_health(&self) -> Result<()> { async fn device_health(&self) -> Result<()> {
@ -246,8 +235,9 @@ impl Health for ShardedClient {
// Block 0 is reserved for health checks // Block 0 is reserved for health checks
blocks: vec![0], blocks: vec![0],
slots: (0..16).collect(), slots: (0..16).collect(),
prefix_len: 0, cache_len: 0,
adapter_id: None, adapter_id: None,
chunk_len: None,
}; };
let batch = Batch { let batch = Batch {
id: u64::MAX, id: u64::MAX,
@ -256,7 +246,7 @@ impl Health for ShardedClient {
max_tokens: 2, max_tokens: 2,
max_blocks: 1, max_blocks: 1,
}; };
self.clone().prefill(batch).await?; self.clone().prefill(batch, None).await?;
Ok(()) Ok(())
} }
} }

View File

@ -29,6 +29,14 @@ pub struct BackendInfo {
pub max_waiting_tokens: usize, pub max_waiting_tokens: usize,
#[schema(nullable = true, example = "null")] #[schema(nullable = true, example = "null")]
pub max_batch_size: Option<usize>, pub max_batch_size: Option<usize>,
#[schema(example = "false")]
pub support_chunking: bool,
#[schema(example = "false")]
pub prefix_caching: bool,
#[schema(example = "flashinfer")]
pub attention_impl: String,
#[schema(example = "1")]
pub block_size: u32,
} }
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
@ -110,6 +118,10 @@ pub async fn connect_backend(
model_device_type: shard_info.device_type.clone(), model_device_type: shard_info.device_type.clone(),
model_dtype: shard_info.dtype.clone(), model_dtype: shard_info.dtype.clone(),
speculate: shard_info.speculate as usize, speculate: shard_info.speculate as usize,
support_chunking: shard_info.support_chunking,
prefix_caching: shard_info.use_prefix_caching,
attention_impl: shard_info.attention_impl.clone(),
block_size: shard_info.block_size,
}; };
let backend = BackendV3::new( let backend = BackendV3::new(
@ -119,9 +131,7 @@ pub async fn connect_backend(
max_batch_total_tokens, max_batch_total_tokens,
max_waiting_tokens, max_waiting_tokens,
max_batch_size, max_batch_size,
shard_info.requires_padding, shard_info,
shard_info.window_size,
shard_info.speculate,
); );
tracing::info!("Using backend V3"); tracing::info!("Using backend V3");

View File

@ -131,25 +131,12 @@ async fn main() -> Result<(), RouterError> {
"`max_input_tokens` must be < `max_total_tokens`".to_string(), "`max_input_tokens` must be < `max_total_tokens`".to_string(),
)); ));
} }
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
}
if validation_workers == 0 { if validation_workers == 0 {
return Err(RouterError::ArgumentValidation( return Err(RouterError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(), "`validation_workers` must be > 0".to_string(),
)); ));
} }
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}
}
if let Some(max_batch_size) = max_batch_size { if let Some(max_batch_size) = max_batch_size {
if max_batch_size == 0 { if max_batch_size == 0 {
return Err(RouterError::ArgumentValidation( return Err(RouterError::ArgumentValidation(
@ -158,7 +145,7 @@ async fn main() -> Result<(), RouterError> {
} }
} }
let (backend, _backend_info) = connect_backend( let (backend, backend_info) = connect_backend(
max_input_tokens, max_input_tokens,
max_total_tokens, max_total_tokens,
master_shard_uds_path, master_shard_uds_path,
@ -170,6 +157,19 @@ async fn main() -> Result<(), RouterError> {
) )
.await?; .await?;
// Validate remaining args now that the backend is known
let support_chunking = backend_info.support_chunking;
let max_batch_total_tokens = backend_info.max_batch_total_tokens;
if max_input_tokens as u32 > max_batch_prefill_tokens && !support_chunking {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
}
if max_batch_prefill_tokens > max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}
// Run server // Run server
server::run( server::run(
backend, backend,

View File

@ -4,7 +4,7 @@ use crate::client::{
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
}; };
use nohash_hasher::{BuildNoHashHasher, IntMap}; use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::{max, min}; use std::cmp::max;
use std::collections::VecDeque; use std::collections::VecDeque;
use text_generation_router::infer::InferError; use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse; use text_generation_router::infer::InferStreamResponse;
@ -50,6 +50,7 @@ impl Queue {
window_size: Option<u32>, window_size: Option<u32>,
speculate: u32, speculate: u32,
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
support_chunking: bool,
) -> Self { ) -> Self {
// Create channel // Create channel
let (queue_sender, queue_receiver) = mpsc::unbounded_channel(); let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@ -62,6 +63,7 @@ impl Queue {
window_size, window_size,
speculate, speculate,
max_batch_total_tokens, max_batch_total_tokens,
support_chunking,
queue_receiver, queue_receiver,
)); ));
@ -87,6 +89,10 @@ impl Queue {
prefill_token_budget: u32, prefill_token_budget: u32,
token_budget: u32, token_budget: u32,
) -> Option<NextBatch> { ) -> Option<NextBatch> {
if prefill_token_budget == 0 || token_budget == 0 {
return None;
};
// Create response channel // Create response channel
let (response_sender, response_receiver) = oneshot::channel(); let (response_sender, response_receiver) = oneshot::channel();
// Send next batch command to the background task managing the state // Send next batch command to the background task managing the state
@ -108,6 +114,7 @@ impl Queue {
} }
// Background task responsible of the queue state // Background task responsible of the queue state
#[allow(clippy::too_many_arguments)]
async fn queue_task( async fn queue_task(
requires_padding: bool, requires_padding: bool,
block_size: u32, block_size: u32,
@ -115,6 +122,7 @@ async fn queue_task(
window_size: Option<u32>, window_size: Option<u32>,
speculate: u32, speculate: u32,
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
support_chunking: bool,
mut receiver: mpsc::UnboundedReceiver<QueueCommand>, mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) { ) {
let mut state = State::new( let mut state = State::new(
@ -124,6 +132,7 @@ async fn queue_task(
window_size, window_size,
speculate, speculate,
max_batch_total_tokens, max_batch_total_tokens,
support_chunking,
); );
while let Some(cmd) = receiver.recv().await { while let Some(cmd) = receiver.recv().await {
@ -166,12 +175,14 @@ struct State {
/// Paged Attention block size /// Paged Attention block size
block_size: u32, block_size: u32,
/// Sliding window
window_size: Option<u32>,
/// Speculation amount /// Speculation amount
speculate: u32, speculate: u32,
/// Whether the model allow the prefill chunking
/// If it does, the last request in the batch will be split to exactly match the prefill
/// token budget
support_chunking: bool,
/// Paged Attention Block Allocation /// Paged Attention Block Allocation
block_allocator: Option<BlockAllocator>, block_allocator: Option<BlockAllocator>,
} }
@ -184,6 +195,7 @@ impl State {
window_size: Option<u32>, window_size: Option<u32>,
speculate: u32, speculate: u32,
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
support_chunking: bool,
) -> Self { ) -> Self {
let block_allocator = (!requires_padding).then(|| { let block_allocator = (!requires_padding).then(|| {
BlockAllocator::new( BlockAllocator::new(
@ -199,8 +211,8 @@ impl State {
next_id: 0, next_id: 0,
next_batch_id: 0, next_batch_id: 0,
block_size, block_size,
window_size,
speculate, speculate,
support_chunking,
block_allocator, block_allocator,
} }
} }
@ -287,32 +299,7 @@ impl State {
} }
None None
} }
Some(_block_allocator) => { Some(block_allocator) => {
prefill_tokens += entry.request.input_length;
let max_new_tokens = match self.window_size {
None => entry.request.stopping_parameters.max_new_tokens,
Some(window_size) => min(
window_size.saturating_sub(entry.request.input_length),
entry.request.stopping_parameters.max_new_tokens,
),
};
decode_tokens += max_new_tokens;
if prefill_tokens > prefill_token_budget
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
{
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
self.entries.push_front((id, entry));
break;
}
let tokens = entry.request.input_length
+ entry.request.stopping_parameters.max_new_tokens
+ self.speculate
- 1;
// If users wants the prefill logprobs, we cannot reuse the cache. // If users wants the prefill logprobs, we cannot reuse the cache.
// So no input_ids for the radix tree. // So no input_ids for the radix tree.
let input_ids = if entry.request.decoder_input_details { let input_ids = if entry.request.decoder_input_details {
@ -321,10 +308,73 @@ impl State {
entry.request.input_ids.clone() entry.request.input_ids.clone()
}; };
Some((tokens, input_ids)) let tokens = entry.request.input_length
+ entry.request.stopping_parameters.max_new_tokens
+ self.speculate
- 1;
tracing::debug!("Allocating {tokens} with {input_ids:?}");
let block_allocation = match block_allocator.allocate(tokens, input_ids).await {
None => {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: not enough free blocks");
self.entries.push_front((id, entry));
break 'entry_loop;
}
Some(mut block_allocation) => {
tracing::debug!("Allocation: {block_allocation:?}");
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
if block_allocation.prefix_len == entry.request.input_length {
// The whole request was found in the radix trie
// However, for the transformer forward to work, we need to
// have at least one token of postfix.
block_allocation.prefix_len -= 1;
}
block_allocation
} }
}; };
batch.push((id, entry, block_allocation));
let postfix_len = entry.request.input_length - block_allocation.prefix_len;
if prefill_tokens + postfix_len > prefill_token_budget {
// Entry is over budget
if self.support_chunking {
// We support chunking, just set postfix_len to exactly match prefill_token_budget
let chunk_len = prefill_token_budget.saturating_sub(prefill_tokens);
if chunk_len > 0 {
// Push this entry inside the batch
batch.push((id, entry, Some(block_allocation), Some(chunk_len)));
} else {
// We cannot prefill even one token for this entry
// Add it back to the queue
self.entries.push_front((id, entry));
}
tracing::debug!(
"Matched budget: prefill_tokens={} == {prefill_token_budget}",
prefill_tokens + postfix_len
);
break 'entry_loop;
} else {
// We don't support chunking, this entry needs to go back to the buffer
// Add it back to the front
tracing::debug!(
"Over budget: prefill_tokens={} > {prefill_token_budget}",
prefill_tokens + postfix_len
);
self.entries.push_front((id, entry));
break 'entry_loop;
}
}
prefill_tokens += postfix_len;
Some(block_allocation)
}
};
batch.push((id, entry, block_allocation, None));
if Some(batch.len()) == max_size { if Some(batch.len()) == max_size {
break; break;
} }
@ -342,7 +392,7 @@ impl State {
// Batch is too small // Batch is too small
if batch.len() < min_size { if batch.len() < min_size {
// Add back entries to the queue in the correct order // Add back entries to the queue in the correct order
for (id, entry, _) in batch.into_iter().rev() { for (id, entry, _, _) in batch.into_iter().rev() {
self.entries.push_front((id, entry)); self.entries.push_front((id, entry));
} }
return None; return None;
@ -353,29 +403,7 @@ impl State {
let mut batch_entries = let mut batch_entries =
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default()); IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
for (id, mut entry, block_allocation) in batch { for (id, mut entry, block_allocation, chunk_len) in batch {
let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) =
(block_allocation, &self.block_allocator)
{
tracing::debug!("Allocating {tokens} with {input_ids:?}");
match block_allocator.allocate(tokens, input_ids).await {
None => {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: not enough free blocks");
self.entries.push_front((id, entry));
continue;
}
Some(block_allocation) => {
tracing::debug!("Allocation: {block_allocation:?}");
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
Some(block_allocation)
}
}
} else {
None
};
tracing::debug!("Accepting entry");
// Create a new span to link the batch back to this entry // Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer"); let entry_batch_span = info_span!(parent: &entry.span, "infer");
// Add relationships // Add relationships
@ -427,8 +455,9 @@ impl State {
top_n_tokens: entry.request.top_n_tokens, top_n_tokens: entry.request.top_n_tokens,
blocks, blocks,
slots, slots,
prefix_len, cache_len: prefix_len,
adapter_id: entry.request.adapter_id.clone(), adapter_id: entry.request.adapter_id.clone(),
chunk_len,
}); });
// Set batch_time // Set batch_time
entry.batch_time = Some(Instant::now()); entry.batch_time = Some(Instant::now());
@ -436,12 +465,6 @@ impl State {
batch_entries.insert(id, entry); batch_entries.insert(id, entry);
} }
// Empty batch
if batch_requests.is_empty() {
tracing::debug!("Filterered out all entries");
return None;
}
// Final batch size // Final batch size
let size = batch_requests.len() as u32; let size = batch_requests.len() as u32;
next_batch_span.record("batch_size", size); next_batch_span.record("batch_size", size);
@ -531,7 +554,7 @@ mod tests {
request: ValidGenerateRequest { request: ValidGenerateRequest {
inputs: vec![], inputs: vec![],
input_ids: Some(Arc::new(vec![])), input_ids: Some(Arc::new(vec![])),
input_length: 0, input_length: 1,
add_special_tokens: true, add_special_tokens: true,
truncate: 0, truncate: 0,
decoder_input_details: false, decoder_input_details: false,
@ -567,7 +590,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_append() { async fn test_append() {
let mut state = State::new(false, 1, false, None, 0, 16); let mut state = State::new(false, 1, false, None, 0, 16, false);
let (entry, _guard) = default_entry(); let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0); assert_eq!(state.next_id, 0);
@ -583,7 +606,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_next_batch_empty() { async fn test_next_batch_empty() {
let mut state = State::new(false, 1, false, None, 0, 16); let mut state = State::new(false, 1, false, None, 0, 16, false);
assert!(state.next_batch(None, None, 1, 1).await.is_none()); assert!(state.next_batch(None, None, 1, 1).await.is_none());
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none()); assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
@ -591,7 +614,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_next_batch_min_size() { async fn test_next_batch_min_size() {
let mut state = State::new(false, 1, false, None, 0, 16); let mut state = State::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
@ -623,7 +646,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_next_batch_max_size() { async fn test_next_batch_max_size() {
let mut state = State::new(false, 1, false, None, 0, 16); let mut state = State::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
@ -643,7 +666,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_next_batch_token_budget() { async fn test_next_batch_token_budget() {
let mut state = State::new(false, 1, false, None, 0, 2); let mut state = State::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
@ -676,14 +699,14 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_append() { async fn test_queue_append() {
let queue = Queue::new(false, 1, false, None, 0, 16); let queue = Queue::new(false, 1, false, None, 0, 16, false);
let (entry, _guard) = default_entry(); let (entry, _guard) = default_entry();
queue.append(entry); queue.append(entry);
} }
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_empty() { async fn test_queue_next_batch_empty() {
let queue = Queue::new(false, 1, false, None, 0, 16); let queue = Queue::new(false, 1, false, None, 0, 16, false);
assert!(queue.next_batch(None, None, 1, 1).await.is_none()); assert!(queue.next_batch(None, None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
@ -691,7 +714,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_min_size() { async fn test_queue_next_batch_min_size() {
let queue = Queue::new(false, 1, false, None, 0, 16); let queue = Queue::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -724,7 +747,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_max_size() { async fn test_queue_next_batch_max_size() {
let queue = Queue::new(false, 1, false, None, 0, 16); let queue = Queue::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -740,7 +763,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_token_budget() { async fn test_queue_next_batch_token_budget() {
let queue = Queue::new(false, 1, false, None, 0, 16); let queue = Queue::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -765,7 +788,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_token_speculate() { async fn test_queue_next_batch_token_speculate() {
let queue = Queue::new(false, 1, false, None, 2, 16); let queue = Queue::new(true, 1, false, None, 2, 16, false);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -784,7 +807,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_dropped_receiver() { async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new(false, 1, false, None, 0, 16); let queue = Queue::new(false, 1, false, None, 0, 16, false);
let (entry, _) = default_entry(); let (entry, _) = default_entry();
queue.append(entry); queue.append(entry);

View File

@ -158,7 +158,8 @@ async fn prefill(
top_n_tokens: top_n_tokens.unwrap_or(0), top_n_tokens: top_n_tokens.unwrap_or(0),
blocks: vec![], blocks: vec![],
slots: vec![], slots: vec![],
prefix_len: 0, cache_len: 0,
chunk_len: None,
adapter_id: None, adapter_id: None,
}) })
.collect(); .collect();
@ -173,7 +174,7 @@ async fn prefill(
// Run prefill // Run prefill
let start_time = Instant::now(); let start_time = Instant::now();
let (_, decode_batch, _) = client.prefill(batch.clone()).await?; let (_, decode_batch, _) = client.prefill(batch.clone(), None).await?;
// Get latency // Get latency
let latency = start_time.elapsed(); let latency = start_time.elapsed();

View File

@ -178,6 +178,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.clear_cache(None) .clear_cache(None)
.await .await
.expect("Unable to clear cache"); .expect("Unable to clear cache");
tracing::info!("Connected"); tracing::info!("Connected");
// Run app // Run app

View File

@ -9,13 +9,16 @@ import subprocess
import sys import sys
import tempfile import tempfile
import time import time
from typing import Dict, List, Optional
import docker import docker
import pytest import pytest
import base64
from pathlib import Path
from typing import Dict, List, Optional
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
from docker.errors import NotFound from docker.errors import NotFound
from syrupy.extensions.json import JSONSnapshotExtension from syrupy.extensions.json import JSONSnapshotExtension
from text_generation import AsyncClient from text_generation import AsyncClient
from text_generation.types import ( from text_generation.types import (
BestOfSequence, BestOfSequence,
@ -403,6 +406,7 @@ def launcher(event_loop):
print(" ".join(args), file=sys.stderr) print(" ".join(args), file=sys.stderr)
env["LOG_LEVEL"] = "info,text_generation_router=debug" env["LOG_LEVEL"] = "info,text_generation_router=debug"
env["PREFILL_CHUNKING"] = "1"
if not use_flash_attention: if not use_flash_attention:
env["USE_FLASH_ATTENTION"] = "false" env["USE_FLASH_ATTENTION"] = "false"
@ -501,6 +505,7 @@ def launcher(event_loop):
env = { env = {
"LOG_LEVEL": "info,text_generation_router=debug", "LOG_LEVEL": "info,text_generation_router=debug",
"PREFILL_CHUNKING": "1",
} }
if not use_flash_attention: if not use_flash_attention:
env["USE_FLASH_ATTENTION"] = "false" env["USE_FLASH_ATTENTION"] = "false"
@ -642,3 +647,22 @@ def generate_multi():
return responses return responses
return generate_load_inner return generate_load_inner
# TODO fix the server parsser to count inline image tokens correctly
@pytest.fixture
def chicken():
path = Path(__file__).parent / "images" / "chicken_on_money.png"
with open(path, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.fixture
def cow_beach():
path = Path(__file__).parent / "images" / "cow_beach.png"
with open(path, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"

View File

@ -1,5 +1,4 @@
import pytest import pytest
import base64
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
@ -20,24 +19,11 @@ async def flash_pali_gemma(flash_pali_gemma_handle):
return flash_pali_gemma_handle.client return flash_pali_gemma_handle.client
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
def get_cow_beach():
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.mark.release @pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot): async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot, cow_beach):
cow = get_cow_beach() inputs = f"![]({cow_beach})Where is the cow standing?\n"
inputs = f"![]({cow})Where is the cow standing?\n"
response = await flash_pali_gemma.generate(inputs, max_new_tokens=20) response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)
assert response.generated_text == "beach" assert response.generated_text == "beach"
@ -47,9 +33,9 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
@pytest.mark.release @pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot): async def test_flash_pali_gemma_two_images(
chicken = get_chicken() flash_pali_gemma, response_snapshot, chicken, cow_beach
cow_beach = get_cow_beach() ):
response = await flash_pali_gemma.generate( response = await flash_pali_gemma.generate(
f"caption![]({chicken})![]({cow_beach})\n", f"caption![]({chicken})![]({cow_beach})\n",
max_new_tokens=20, max_new_tokens=20,

View File

@ -25,7 +25,6 @@ async def llama_grammar(llama_grammar_handle):
@pytest.mark.release @pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot): async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):
class Weather(BaseModel): class Weather(BaseModel):
unit: str unit: str
temperature: List[int] temperature: List[int]

View File

@ -1,5 +1,4 @@
import pytest import pytest
import base64
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
@ -16,22 +15,8 @@ async def idefics(idefics_handle):
return idefics_handle.client return idefics_handle.client
# TODO fix the server parsser to count inline image tokens correctly
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
def get_cow_beach():
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_idefics(idefics, response_snapshot): async def test_idefics(idefics, response_snapshot, chicken):
chicken = get_chicken()
response = await idefics.generate( response = await idefics.generate(
f"User:![]({chicken})Can you tell me a very short story based on the image?", f"User:![]({chicken})Can you tell me a very short story based on the image?",
max_new_tokens=10, max_new_tokens=10,
@ -48,9 +33,7 @@ async def test_idefics(idefics, response_snapshot):
@pytest.mark.release @pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_idefics_two_images(idefics, response_snapshot): async def test_idefics_two_images(idefics, response_snapshot, chicken, cow_beach):
chicken = get_chicken()
cow_beach = get_cow_beach()
response = await idefics.generate( response = await idefics.generate(
f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:", f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
max_new_tokens=20, max_new_tokens=20,
@ -63,8 +46,7 @@ async def test_idefics_two_images(idefics, response_snapshot):
@pytest.mark.release @pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_idefics_load(idefics, generate_load, response_snapshot): async def test_idefics_load(idefics, generate_load, response_snapshot, chicken):
chicken = get_chicken()
responses = await generate_load( responses = await generate_load(
idefics, idefics,
f"User:![]({chicken})Can you tell me a very short story based on the image?", f"User:![]({chicken})Can you tell me a very short story based on the image?",

View File

@ -1,18 +1,4 @@
import pytest import pytest
import base64
# TODO fix the server parsser to count inline image tokens correctly
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
def get_cow_beach():
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
@ -31,8 +17,9 @@ async def flash_idefics2_next(flash_idefics2_next_handle):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot): async def test_flash_idefics2_next_simple(
chicken = get_chicken() flash_idefics2_next, response_snapshot, chicken
):
response = await flash_idefics2_next.generate( response = await flash_idefics2_next.generate(
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:", f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
max_new_tokens=10, max_new_tokens=10,
@ -46,9 +33,9 @@ async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot): async def test_flash_idefics2_two_images(
chicken = get_chicken() flash_idefics2_next, response_snapshot, chicken, cow_beach
cow_beach = get_cow_beach() ):
response = await flash_idefics2_next.generate( response = await flash_idefics2_next.generate(
f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:", f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
max_new_tokens=20, max_new_tokens=20,
@ -87,9 +74,8 @@ async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snap
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_idefics2_next_load( async def test_flash_idefics2_next_load(
flash_idefics2_next, generate_load, response_snapshot flash_idefics2_next, generate_load, response_snapshot, chicken
): ):
chicken = get_chicken()
responses = await generate_load( responses = await generate_load(
flash_idefics2_next, flash_idefics2_next,
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:", f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",

View File

@ -1,12 +1,4 @@
import pytest import pytest
import base64
# TODO fix the server parsser to count inline image tokens correctly
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
@ -29,8 +21,7 @@ async def flash_llava_next(flash_llava_next_handle):
@pytest.mark.release @pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llava_next_simple(flash_llava_next, response_snapshot): async def test_flash_llava_next_simple(flash_llava_next, response_snapshot, chicken):
chicken = get_chicken()
response = await flash_llava_next.generate( response = await flash_llava_next.generate(
f"User:![]({chicken})Can you tell me a very short story based on the image?", f"User:![]({chicken})Can you tell me a very short story based on the image?",
max_new_tokens=10, max_new_tokens=10,
@ -70,9 +61,8 @@ async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llava_next_load( async def test_flash_llava_next_load(
flash_llava_next, generate_load, response_snapshot flash_llava_next, generate_load, response_snapshot, chicken
): ):
chicken = get_chicken()
responses = await generate_load( responses = await generate_load(
flash_llava_next, flash_llava_next,
f"User:![]({chicken})Can you tell me a very short story based on the image?", f"User:![]({chicken})Can you tell me a very short story based on the image?",

View File

@ -1,5 +1,4 @@
import pytest import pytest
import base64
import asyncio import asyncio
@ -15,22 +14,8 @@ async def mllama(mllama_handle):
return mllama_handle.client return mllama_handle.client
# TODO fix the server parsser to count inline image tokens correctly
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
def get_cow_beach():
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_mllama_simpl(mllama, response_snapshot): async def test_mllama_simpl(mllama, response_snapshot):
# chicken = get_chicken()
response = await mllama.chat( response = await mllama.chat(
max_tokens=10, max_tokens=10,
temperature=0.0, temperature=0.0,

View File

@ -68,7 +68,7 @@ fn get_config(
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) { fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
let compute_capability = gpu::get_cuda_capability(); let compute_capability = gpu::get_cuda_capability();
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok(); let mut prefix_caching: Option<String> = std::env::var("PREFIX_CACHING").ok();
let mut attention: Option<String> = std::env::var("ATTENTION").ok(); let mut attention: Option<String> = std::env::var("ATTENTION").ok();
if let Some(config) = config { if let Some(config) = config {
if prefix_caching.is_none() { if prefix_caching.is_none() {
@ -124,6 +124,10 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
} }
} }
} }
if attention == Some("paged".to_string()) && prefix_caching.is_none() {
tracing::info!("Disabling prefix caching on paged attention");
prefix_caching = Some("0".to_string());
}
let attention = attention.unwrap_or("flashinfer".to_string()); let attention = attention.unwrap_or("flashinfer".to_string());
let prefix_caching = prefix_caching.unwrap_or("true".to_string()); let prefix_caching = prefix_caching.unwrap_or("true".to_string());
@ -1678,7 +1682,7 @@ fn main() -> Result<(), LauncherError> {
}; };
let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters); let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}"); tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
std::env::set_var("USE_PREFIX_CACHING", prefix_caching); std::env::set_var("PREFIX_CACHING", prefix_caching);
std::env::set_var("ATTENTION", attention); std::env::set_var("ATTENTION", attention);
let max_input_tokens = { let max_input_tokens = {
@ -1729,12 +1733,6 @@ fn main() -> Result<(), LauncherError> {
"`max_input_tokens must be < `max_total_tokens`".to_string(), "`max_input_tokens must be < `max_total_tokens`".to_string(),
)); ));
} }
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
max_batch_prefill_tokens, max_input_tokens
)));
}
if matches!(args.quantize, Some(Quantization::Bitsandbytes)) { if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases."); tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
@ -1788,12 +1786,6 @@ fn main() -> Result<(), LauncherError> {
} }
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens { if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
max_batch_prefill_tokens, max_batch_total_tokens
)));
}
if max_total_tokens as u32 > *max_batch_total_tokens { if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!( return Err(LauncherError::ArgumentValidation(format!(
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",

View File

@ -34,6 +34,10 @@ message InfoResponse {
string device_type = 3; string device_type = 3;
optional uint32 window_size = 4; optional uint32 window_size = 4;
uint32 speculate = 5; uint32 speculate = 5;
bool support_chunking = 6;
bool use_prefix_caching = 7;
string attention_impl = 8;
uint32 block_size = 9;
} }
/// Empty request /// Empty request
@ -135,10 +139,14 @@ message Request {
repeated uint32 slots = 10; repeated uint32 slots = 10;
/// LORA adapter index /// LORA adapter index
optional string adapter_id = 11; optional string adapter_id = 11;
/// Prefix length that can be retrieved from the KV cache. /// Tokens that can be retrieved from the KV cache.
uint32 prefix_len = 12; /// This value is set for the first prefill and never reset
uint32 cache_len = 12;
/// Context truncation /// Context truncation
bool add_special_tokens = 13; bool add_special_tokens = 13;
/// Chunk of tokens that must be computed for the first prefill
/// This value is set for the first prefill and never reset
optional uint32 chunk_len = 14;
} }
message Batch { message Batch {
@ -163,6 +171,8 @@ message CachedBatch {
uint32 size = 3; uint32 size = 3;
/// Maximum number of tokens this batch will grow to /// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4; uint32 max_tokens = 4;
/// Number of tokens in the next forward
uint32 current_tokens = 5;
} }
enum FinishReason { enum FinishReason {
@ -220,6 +230,8 @@ message FilterBatchResponse {
message PrefillRequest { message PrefillRequest {
/// Batch /// Batch
Batch batch = 1; Batch batch = 1;
/// Optional cached batch
CachedBatch cached_batch = 2;
} }
message PrefillResponse { message PrefillResponse {
@ -233,6 +245,8 @@ message PrefillResponse {
uint64 decode_ns = 4; uint64 decode_ns = 4;
/// Total elapsed time in nanoseconds /// Total elapsed time in nanoseconds
uint64 total_ns = 5; uint64 total_ns = 5;
/// Concatenate elapsed time in nanoseconds
optional uint64 concat_ns = 6;
} }
message DecodeRequest { message DecodeRequest {

View File

@ -18,45 +18,6 @@ use tracing::warn;
use utoipa::ToSchema; use utoipa::ToSchema;
use validation::Validation; use validation::Validation;
#[derive(PartialEq)]
pub enum Attention {
Paged,
FlashDecoding,
FlashInfer,
}
impl Attention {
pub fn block_size(&self) -> u32 {
match self {
Attention::FlashDecoding => 256,
Attention::FlashInfer => 1,
Attention::Paged => 16,
}
}
}
#[derive(Debug)]
pub struct ParseError;
impl std::fmt::Display for ParseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Cannot parse attention value")
}
}
impl std::error::Error for ParseError {}
impl std::str::FromStr for Attention {
type Err = ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"paged" => Ok(Attention::Paged),
"flashdecoding" => Ok(Attention::FlashDecoding),
"flashinfer" => Ok(Attention::FlashInfer),
_ => Err(ParseError),
}
}
}
/// Hub type /// Hub type
#[derive(Clone, Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
pub struct HubModelInfo { pub struct HubModelInfo {

View File

@ -2,7 +2,7 @@ import pytest
import os import os
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
os.environ["USE_PREFIX_CACHING"] = "1" os.environ["PREFIX_CACHING"] = "1"
os.environ["ATTENTION"] = "flashinfer" os.environ["ATTENTION"] = "flashinfer"

View File

@ -9,6 +9,9 @@ from typing import Callable, Any
class ExceptionInterceptor(AsyncServerInterceptor): class ExceptionInterceptor(AsyncServerInterceptor):
def __init__(self, shutdown_callback):
self.shutdown_callback = shutdown_callback
async def intercept( async def intercept(
self, self,
method: Callable, method: Callable,
@ -25,7 +28,7 @@ class ExceptionInterceptor(AsyncServerInterceptor):
# Runtime Error cannot be recovered from # Runtime Error cannot be recovered from
if isinstance(err, RuntimeError): if isinstance(err, RuntimeError):
exit(1) self.shutdown_callback()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@ -1,16 +1,12 @@
from dataclasses import dataclass from dataclasses import dataclass
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
import torch import torch
from typing import Optional from typing import Optional
if ATTENTION in {"flashinfer", "flashdecoding"}: @dataclass
class Seqlen:
@dataclass
class Seqlen:
input_lengths: torch.Tensor input_lengths: torch.Tensor
prefix_lengths: torch.Tensor cache_lengths: torch.Tensor
cu_seqlen_q: Optional[torch.Tensor] cu_seqlen_q: Optional[torch.Tensor]
cu_seqlen_k: Optional[torch.Tensor] cu_seqlen_k: Optional[torch.Tensor]
max_q: int max_q: int
@ -19,13 +15,13 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
def __init__( def __init__(
self, self,
input_lengths, input_lengths,
prefix_lengths, cache_lengths,
cu_seqlen_q=None, cu_seqlen_q=None,
max_q=None, max_q=None,
max_k=None, max_k=None,
): ):
self.input_lengths = input_lengths self.input_lengths = input_lengths
self.prefix_lengths = prefix_lengths self.cache_lengths = cache_lengths
device = self.input_lengths.device device = self.input_lengths.device
shape = self.input_lengths.shape shape = self.input_lengths.shape
if cu_seqlen_q is None: if cu_seqlen_q is None:
@ -43,7 +39,7 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
# cuda graphs don't like this and this is necessary to clamp within mistral # cuda graphs don't like this and this is necessary to clamp within mistral
# Although FA2 might not want the clamping # Although FA2 might not want the clamping
# cu_seqlen_k[0] = 0 # cu_seqlen_k[0] = 0
total = self.input_lengths + self.prefix_lengths total = self.input_lengths + self.cache_lengths
torch.cumsum(total, -1, out=cu_seqlen_k[1:]) torch.cumsum(total, -1, out=cu_seqlen_k[1:])
self.cu_seqlen_q = cu_seqlen_q self.cu_seqlen_q = cu_seqlen_q
@ -54,19 +50,3 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
def clamp(self, max): def clamp(self, max):
# Flash decoding doesn't need to clamp # Flash decoding doesn't need to clamp
return self return self
else:
@dataclass
class Seqlen:
input_lengths: torch.Tensor
prefix_lengths: torch.Tensor
cu_seqlen_q: torch.Tensor
max_q: int
max_k: int
def clamp(self, max):
if SYSTEM == "rocm":
return self
self.input_lengths = torch.clamp(self.input_lengths, max=max)
return self

View File

@ -123,7 +123,7 @@ def paged_attention(
else: else:
if softcap is not None: if softcap is not None:
raise RuntimeError("Paged attention doesn't support softcapping") raise RuntimeError("Paged attention doesn't support softcapping")
input_lengths = seqlen.input_lengths input_lengths = seqlen.input_lengths + seqlen.cache_lengths
from vllm._C import ops from vllm._C import ops
out = torch.empty_like(query) out = torch.empty_like(query)
@ -244,7 +244,8 @@ if ATTENTION == "flashinfer":
window_left=window_size_left, window_left=window_size_left,
) )
elif V2: elif ATTENTION == "flashdecoding":
if V2:
def attention( def attention(
q, q,
@ -284,7 +285,7 @@ elif V2:
None, None,
)[0] )[0]
else: else:
def attention( def attention(
q: torch.Tensor, q: torch.Tensor,
@ -302,7 +303,9 @@ else:
"window_size_left is only available with flash attn v2" "window_size_left is only available with flash attn v2"
) )
if softcap is not None: if softcap is not None:
raise NotImplementedError("softcap is only available with flash attn v2") raise NotImplementedError(
"softcap is only available with flash attn v2"
)
# Flash attention v1 requires q, k and v to have the same number of heads # Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]: if k.shape[1] != q.shape[1]:
@ -350,11 +353,123 @@ else:
) )
return out return out
elif ATTENTION == "paged":
if V2:
def attention(
q,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
out = torch.empty_like(q)
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
q,
key_cache,
value_cache,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None,
None,
None, # block_tables,
None,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
window_size_left,
0,
softcap,
False,
None,
)[0]
else:
def attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap=None,
):
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
)
if softcap is not None:
raise NotImplementedError(
"softcap is only available with flash attn v2"
)
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
out = torch.empty_like(q)
flash_attn_cuda.fwd(
q,
k,
v,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
False,
0,
None,
)
return out
else:
raise RuntimeError(f"Unknwon attention {ATTENTION}")
# Prefill in the cache with every kind of attention, unless we # Prefill in the cache with every kind of attention, unless we
# have a configuration that requires flash-attention v1, which # have a configuration that requires flash-attention v1, which
# does not support block tables. # does not support block tables.
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2 PREFILL_IN_KV_CACHE = ATTENTION == "flashinfer" or (ATTENTION == "flashdecoding" and V2)
__all__ = [ __all__ = [
"PREFILL_IN_KV_CACHE", "PREFILL_IN_KV_CACHE",

View File

@ -699,7 +699,6 @@ def check_args(
class _attention(torch.autograd.Function): class _attention(torch.autograd.Function):
@staticmethod @staticmethod
def forward( def forward(
ctx, ctx,

View File

@ -66,6 +66,7 @@ def paged_attention(
softcap: Optional[float] = None, softcap: Optional[float] = None,
): ):
out = torch.empty_like(query) out = torch.empty_like(query)
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out, out,
query, query,
@ -74,7 +75,7 @@ def paged_attention(
kv_head_mapping, kv_head_mapping,
softmax_scale, softmax_scale,
block_tables, block_tables,
seqlen.input_lengths, input_lengths,
BLOCK_SIZE, BLOCK_SIZE,
max_s, max_s,
None, None,

View File

@ -104,7 +104,7 @@ def paged_attention(
_PARTITION_SIZE = _PARTITION_SIZE_CUSTOM _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = seqlen.input_lengths input_lengths = seqlen.input_lengths + seqlen.cache_lengths
out = torch.empty_like(query) out = torch.empty_like(query)

View File

@ -76,6 +76,7 @@ class CausalLMBatch(Batch):
request_ids=[r.id for r in self.requests], request_ids=[r.id for r in self.requests],
size=len(self), size=len(self),
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
current_tokens=len(self.input_ids),
) )
@classmethod @classmethod

View File

@ -493,9 +493,14 @@ class MllamaVisionModel(nn.Module):
aspect_ratio_ids: torch.Tensor, aspect_ratio_ids: torch.Tensor,
attention_mask: torch.Tensor, attention_mask: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
batch_size, num_concurrent_media, num_tiles, num_channels, height, width = ( (
pixel_values.shape batch_size,
) num_concurrent_media,
num_tiles,
num_channels,
height,
width,
) = pixel_values.shape
pixel_values = pixel_values.reshape( pixel_values = pixel_values.reshape(
batch_size * num_concurrent_media * num_tiles, num_channels, height, width batch_size * num_concurrent_media * num_tiles, num_channels, height, width

File diff suppressed because it is too large Load Diff

View File

@ -5,9 +5,14 @@ from typing import Dict, Optional
from text_generation_server.utils.log import log_master from text_generation_server.utils.log import log_master
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"} ATTENTION = os.environ["ATTENTION"]
# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
"1",
"true",
}
PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "0").lower() in {"1", "true"}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
ATTENTION = os.getenv("ATTENTION")
_expected = {"paged", "flashdecoding", "flashinfer"} _expected = {"paged", "flashdecoding", "flashinfer"}
assert ( assert (
ATTENTION in _expected ATTENTION in _expected
@ -18,7 +23,7 @@ if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
raise RuntimeError("Prefix caching is only supported with flashinfer") raise RuntimeError("Prefix caching is only supported with flashinfer")
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95")) TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
assert TGI_WIGGLE_ROOM > 0 assert TGI_WIGGLE_ROOM > 0
assert TGI_WIGGLE_ROOM < 1 assert TGI_WIGGLE_ROOM < 1

View File

@ -83,6 +83,7 @@ class IdeficsCausalLMBatch(Batch):
request_ids=[r.id for r in self.requests], request_ids=[r.id for r in self.requests],
size=len(self), size=len(self),
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
current_tokens=len(self),
) )
@classmethod @classmethod

View File

@ -116,6 +116,7 @@ class MambaBatch(Batch):
request_ids=[r.id for r in self.requests], request_ids=[r.id for r in self.requests],
size=len(self), size=len(self),
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
current_tokens=len(self),
) )
@classmethod @classmethod

View File

@ -1,14 +1,17 @@
from io import BytesIO
from PIL import Image
import torch import torch
import numpy as np
from typing import Iterable, Optional, Tuple, List, Dict from typing import Iterable, Optional, Tuple, List, Dict
from text_generation_server.pb.generate_pb2 import Request from text_generation_server.pb.generate_pb2 import Request
from io import BytesIO
from PIL import Image
from dataclasses import dataclass from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import ( from transformers import (
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
) )
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import ( from text_generation_server.models.flash_causal_lm import (
@ -167,6 +170,13 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp( batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(
max=config.text_config.vocab_size - 1 max=config.text_config.vocab_size - 1
) )
if isinstance(batch.input_ids, list):
if len(batch) > 1:
input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
else:
input_ids = batch.input_ids[0]
batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1) batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
if image_inputs is not None: if image_inputs is not None:
@ -190,7 +200,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
class MllamaCausalLM(VlmCausalLM): class MllamaCausalLM(VlmCausalLM):
def forward( def forward(
self, self,
batch: VlmCausalLMBatch, batch: MllamaCausalLMBatch,
adapter_data: Optional[Dict[str, torch.Tensor]] = None, adapter_data: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward # Model Forward
@ -202,7 +212,7 @@ class MllamaCausalLM(VlmCausalLM):
block_tables = batch.block_tables_tensor block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices] slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids speculative_ids = batch.speculative_ids
@ -221,8 +231,8 @@ class MllamaCausalLM(VlmCausalLM):
input_lengths = ( input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1) ).view(-1)
prefix_lens_tensor = ( cache_lengths_tensor = (
batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
).reshape(-1) ).reshape(-1)
# Add Copy the block tables for all members # Add Copy the block tables for all members
@ -244,8 +254,8 @@ class MllamaCausalLM(VlmCausalLM):
block_tables = batch.block_tables_tensor block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices] slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor input_lengths = batch.input_lengths_tensor
prefix_lens_tensor = batch.prefix_lens_tensor cache_lengths_tensor = batch.cache_lengths_tensor
max_s = batch.max_seqlen max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices lm_head_indices = batch.prefill_head_indices
if cu_seqlen_prefill is None and self.max_past() is not None: if cu_seqlen_prefill is None and self.max_past() is not None:
@ -254,7 +264,6 @@ class MllamaCausalLM(VlmCausalLM):
# This makes sure the max_s for the decode pass is correct. # This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s) max_s = min(self.max_past(), max_s)
bs = input_ids.shape[0]
# Try to find an associated cuda graph # Try to find an associated cuda graph
bs = input_ids.shape[0] bs = input_ids.shape[0]
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
@ -269,26 +278,24 @@ class MllamaCausalLM(VlmCausalLM):
# Only run cuda graphs when there's no images. # Only run cuda graphs when there's no images.
or batch.cross_attention_states is not None or batch.cross_attention_states is not None
): ):
input_lengths = input_lengths + prefix_lens_tensor
if PREFIX_CACHING: if PREFIX_CACHING:
block_tables = block_tables_to_ragged( block_tables = block_tables_to_ragged(
block_tables=block_tables, block_tables=block_tables,
input_lengths=batch.input_lengths, input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens, cache_lengths=batch.cache_lengths,
) )
with self._forward_context( with self._forward_context(
block_tables=block_tables, block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths_tensor=input_lengths, input_lengths_tensor=input_lengths,
prefix_lens_tensor=prefix_lens_tensor, cache_lengths_tensor=cache_lengths_tensor,
): ):
max_k = (input_lengths + prefix_lens_tensor).max().item()
seqlen = Seqlen( seqlen = Seqlen(
input_lengths=input_lengths, input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor, cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill, cu_seqlen_q=cu_seqlen_prefill,
max_q=max_s, max_q=batch.max_input_length,
max_k=max_k, max_k=batch.max_current_length,
) )
if batch.pixel_values is not None: if batch.pixel_values is not None:
@ -330,20 +337,32 @@ class MllamaCausalLM(VlmCausalLM):
block_tables = block_tables_to_ragged( block_tables = block_tables_to_ragged(
block_tables=block_tables, block_tables=block_tables,
input_lengths=batch.input_lengths, input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens, cache_lengths=batch.cache_lengths,
) )
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
else: else:
cuda_graph["block_tables"][ cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1] : block_tables.shape[0], : block_tables.shape[1]
] = block_tables ] = block_tables
# XXX: This is working only because block 0 is reserved for the healthcheck
# so it doesn't matter if we override it with bogus values.
cuda_graph["slots"].fill_(0) cuda_graph["slots"].fill_(0)
cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_() cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
input_lengths + prefix_lens_tensor cuda_graph["cache_lengths"].zero_()
) cuda_graph["cache_lengths"][
: cache_lengths_tensor.shape[0]
] = cache_lengths_tensor
with self._forward_context(
block_tables=cuda_graph["block_tables"],
cu_seqlen_prefill=None,
input_lengths_tensor=cuda_graph["input_lengths"],
cache_lengths_tensor=cuda_graph["cache_lengths"],
state=cuda_graph["state"],
):
# Replay the graph # Replay the graph
cuda_graph["graph"].replay() cuda_graph["graph"].replay()

View File

@ -5,8 +5,17 @@ from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type, Dict from typing import List, Tuple, Optional, TypeVar, Type, Dict
from collections import defaultdict from collections import defaultdict
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from loguru import logger
from text_generation_server.models.globals import (
ATTENTION,
PREFIX_CACHING,
BLOCK_SIZE,
PREFILL_CHUNKING,
)
from text_generation_server.models.types import Batch, Generation from text_generation_server.models.types import Batch, Generation
from text_generation_server.utils.log import log_master
from text_generation_server.utils.prefill_chunking import set_support_chunking
from text_generation_server.utils.speculate import get_speculate from text_generation_server.utils.speculate import get_speculate
from text_generation_server.pb.generate_pb2 import InfoResponse from text_generation_server.pb.generate_pb2 import InfoResponse
from text_generation_server.adapters.weights import LayerAdapterWeights from text_generation_server.adapters.weights import LayerAdapterWeights
@ -31,6 +40,7 @@ class Model(ABC):
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
speculate: Optional[int] = None, speculate: Optional[int] = None,
adapter_id: str = BASE_MODEL_ADAPTER_ID, adapter_id: str = BASE_MODEL_ADAPTER_ID,
support_chunking: bool = False,
): ):
self.model_id = model_id self.model_id = model_id
self.model = model.eval() self.model = model.eval()
@ -60,6 +70,29 @@ class Model(ABC):
speculate = get_speculate() speculate = get_speculate()
self.speculate = speculate self.speculate = speculate
support_chunking = support_chunking and PREFILL_CHUNKING
if speculate != 0 and support_chunking:
log_master(
logger.warning,
"Prefill chunking does not support speculation yet. "
"Prefill chunking will be turned off",
)
support_chunking = False
if ATTENTION not in ["flashinfer", "flashdecoding"] and support_chunking:
log_master(
logger.warning,
"Prefill chunking is only supported with `flashinfer` or `flashdecoding` attention types.",
)
support_chunking = False
log_master(
logger.info, f"Using experimental prefill chunking = {support_chunking}"
)
self.support_chunking = support_chunking
set_support_chunking(support_chunking)
self.has_position_ids = ( self.has_position_ids = (
inspect.signature(model.forward).parameters.get("position_ids", None) inspect.signature(model.forward).parameters.get("position_ids", None)
is not None is not None
@ -78,6 +111,10 @@ class Model(ABC):
device_type=self.device.type, device_type=self.device.type,
window_size=self.sliding_window, window_size=self.sliding_window,
speculate=self.speculate, speculate=self.speculate,
support_chunking=self.support_chunking,
use_prefix_caching=PREFIX_CACHING,
attention_impl=ATTENTION,
block_size=BLOCK_SIZE,
) )
@property @property

View File

@ -80,6 +80,7 @@ class Seq2SeqLMBatch(Batch):
request_ids=[r.id for r in self.requests], request_ids=[r.id for r in self.requests],
size=len(self), size=len(self),
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
current_tokens=len(self.decoder_input_ids),
) )
@classmethod @classmethod

View File

@ -74,6 +74,14 @@ class Tokens:
def __len__(self): def __len__(self):
return len(self.token_ids) return len(self.token_ids)
def __add__(self, other: "Tokens") -> "Tokens":
return Tokens(
self.token_ids + other.token_ids,
self.logprobs + other.logprobs,
self.texts + other.texts,
self.is_special + other.is_special,
)
@dataclass @dataclass
class Generation: class Generation:

View File

@ -271,6 +271,8 @@ class VlmCausalLM(FlashCausalLM):
model_id=model_id, model_id=model_id,
revision=revision, revision=revision,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
# FIXME: VLM do not work with context chunking yet
support_chunking=False,
**kwargs, **kwargs,
) )
@ -295,7 +297,7 @@ class VlmCausalLM(FlashCausalLM):
block_tables = batch.block_tables_tensor block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices] slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids speculative_ids = batch.speculative_ids
@ -314,8 +316,8 @@ class VlmCausalLM(FlashCausalLM):
input_lengths = ( input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1) ).view(-1)
prefix_lens_tensor = ( cache_lengths_tensor = (
batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
).reshape(-1) ).reshape(-1)
# Add Copy the block tables for all members # Add Copy the block tables for all members
@ -337,8 +339,8 @@ class VlmCausalLM(FlashCausalLM):
block_tables = batch.block_tables_tensor block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices] slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor input_lengths = batch.input_lengths_tensor
prefix_lens_tensor = batch.prefix_lens_tensor cache_lengths_tensor = batch.cache_lengths_tensor
max_s = batch.max_seqlen max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices lm_head_indices = batch.prefill_head_indices
if cu_seqlen_prefill is None and self.max_past() is not None: if cu_seqlen_prefill is None and self.max_past() is not None:
@ -347,7 +349,6 @@ class VlmCausalLM(FlashCausalLM):
# This makes sure the max_s for the decode pass is correct. # This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s) max_s = min(self.max_past(), max_s)
bs = input_ids.shape[0]
# Try to find an associated cuda graph # Try to find an associated cuda graph
bs = input_ids.shape[0] bs = input_ids.shape[0]
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
@ -357,26 +358,24 @@ class VlmCausalLM(FlashCausalLM):
else: else:
cuda_graph = None cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None: if cu_seqlen_prefill is not None or cuda_graph is None:
input_lengths = input_lengths + prefix_lens_tensor if ATTENTION == "flashinfer":
if PREFIX_CACHING:
block_tables = block_tables_to_ragged( block_tables = block_tables_to_ragged(
block_tables=block_tables, block_tables=block_tables,
input_lengths=batch.input_lengths, input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens, cache_lengths=batch.cache_lengths,
) )
with self._forward_context( with self._forward_context(
block_tables=block_tables, block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths_tensor=input_lengths, input_lengths_tensor=input_lengths,
prefix_lens_tensor=prefix_lens_tensor, cache_lengths_tensor=cache_lengths_tensor,
): ):
max_k = (input_lengths + prefix_lens_tensor).max().item()
seqlen = Seqlen( seqlen = Seqlen(
input_lengths=input_lengths, input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor, cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill, cu_seqlen_q=cu_seqlen_prefill,
max_q=max_s, max_q=batch.max_input_length,
max_k=max_k, max_k=batch.max_current_length,
) )
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
@ -411,20 +410,32 @@ class VlmCausalLM(FlashCausalLM):
block_tables = block_tables_to_ragged( block_tables = block_tables_to_ragged(
block_tables=block_tables, block_tables=block_tables,
input_lengths=batch.input_lengths, input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens, cache_lengths=batch.cache_lengths,
) )
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
else: else:
cuda_graph["block_tables"][ cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1] : block_tables.shape[0], : block_tables.shape[1]
] = block_tables ] = block_tables
cuda_graph["slots"].fill_(-1)
# XXX: This is working only because block 0 is reserved for the healthcheck
# so it doesn't matter if we override it with bogus values.
cuda_graph["slots"].fill_(0)
cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_() cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
input_lengths + prefix_lens_tensor cuda_graph["cache_lengths"].zero_()
) cuda_graph["cache_lengths"][
: cache_lengths_tensor.shape[0]
] = cache_lengths_tensor
with self._forward_context(
block_tables=cuda_graph["block_tables"],
cu_seqlen_prefill=None,
input_lengths_tensor=cuda_graph["input_lengths"],
cache_lengths_tensor=cuda_graph["cache_lengths"],
state=cuda_graph["state"],
):
# Replay the graph # Replay the graph
cuda_graph["graph"].replay() cuda_graph["graph"].replay()

View File

@ -15,6 +15,7 @@ from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model_with_lora_adapters from text_generation_server.models import Model, get_model_with_lora_adapters
from text_generation_server.utils.adapter import AdapterInfo from text_generation_server.utils.adapter import AdapterInfo
from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
try: try:
from text_generation_server.models.pali_gemma import PaliGemmaBatch from text_generation_server.models.pali_gemma import PaliGemmaBatch
@ -46,9 +47,12 @@ class SignalHandler:
signal.signal(signal.SIGINT, self.exit_gracefully) signal.signal(signal.SIGINT, self.exit_gracefully)
signal.signal(signal.SIGTERM, self.exit_gracefully) signal.signal(signal.SIGTERM, self.exit_gracefully)
def set_keep_processing(self, value: bool):
self.KEEP_PROCESSING = value
def exit_gracefully(self, signum, frame): def exit_gracefully(self, signum, frame):
print(f"Exiting gracefully: Signal {signum}") print(f"Exiting gracefully: Signal {signum}")
self.KEEP_PROCESSING = False self.set_keep_processing(False)
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
@ -96,6 +100,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
async def Warmup(self, request, context): async def Warmup(self, request, context):
set_max_prefill_tokens(request.max_prefill_tokens)
if self.quantize in {"exl2", "gptq"}: if self.quantize in {"exl2", "gptq"}:
try: try:
# When using GPTQ, Exllama kernels need some global kernels # When using GPTQ, Exllama kernels need some global kernels
@ -150,6 +156,18 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
request.batch, self.model.tokenizer, self.model.dtype, self.model.device request.batch, self.model.tokenizer, self.model.dtype, self.model.device
) )
concat_ns = None
if self.model.support_chunking:
if request.HasField("cached_batch"):
cached_batch = self.cache.pop(request.cached_batch.id)
if cached_batch is None:
raise ValueError(
f"Batch ID {request.cached_batch.id} not found in cache."
)
start_concat = time.time_ns()
batch = self.model.batch_type.concatenate([cached_batch, batch])
concat_ns = time.time_ns() - start_concat
generations, next_batch, timings = self.model.generate_token(batch) generations, next_batch, timings = self.model.generate_token(batch)
self.cache.set(next_batch) self.cache.set(next_batch)
@ -159,6 +177,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
forward_ns=timings[0], forward_ns=timings[0],
decode_ns=timings[1], decode_ns=timings[1],
total_ns=time.time_ns() - start, total_ns=time.time_ns() - start,
concat_ns=concat_ns,
) )
async def Decode(self, request, context): async def Decode(self, request, context):
@ -252,10 +271,12 @@ def serve(
logger.exception("Error when initializing model") logger.exception("Error when initializing model")
raise raise
signal_handler = SignalHandler()
set_adapter_to_index(adapter_to_index) set_adapter_to_index(adapter_to_index)
server = aio.server( server = aio.server(
interceptors=[ interceptors=[
ExceptionInterceptor(), ExceptionInterceptor(lambda: signal_handler.set_keep_processing(False)),
UDSOpenTelemetryAioServerInterceptor(), UDSOpenTelemetryAioServerInterceptor(),
], ],
options=[ options=[
@ -276,7 +297,6 @@ def serve(
await server.start() await server.start()
logger.info("Server started at {}".format(local_url)) logger.info("Server started at {}".format(local_url))
signal_handler = SignalHandler()
while signal_handler.KEEP_PROCESSING: while signal_handler.KEEP_PROCESSING:
await asyncio.sleep(0.5) await asyncio.sleep(0.5)

View File

@ -120,8 +120,12 @@ def _load_and_merge(
if adapter.id == BASE_MODEL_ADAPTER_ID: if adapter.id == BASE_MODEL_ADAPTER_ID:
raise ValueError("Base model adapter cannot be merged.") raise ValueError("Base model adapter cannot be merged.")
module_map, adapter_config, adapter_weight_names, adapter_tokenizer = ( (
load_module_map( module_map,
adapter_config,
adapter_weight_names,
adapter_tokenizer,
) = load_module_map(
model_id, model_id,
adapter.revision, adapter.revision,
adapter.id, adapter.id,
@ -129,7 +133,6 @@ def _load_and_merge(
weight_names, weight_names,
trust_remote_code, trust_remote_code,
) )
)
adapters_to_merge.append((module_map, adapter_config)) adapters_to_merge.append((module_map, adapter_config))
merged_weight_names = merged_weight_names.union(adapter_weight_names) merged_weight_names = merged_weight_names.union(adapter_weight_names)

View File

@ -0,0 +1,24 @@
from typing import Optional
SUPPORT_CHUNKING: Optional[bool] = None
MAX_PREFILL_TOKENS: Optional[int] = None
def set_support_chunking(support_chunking: bool):
global SUPPORT_CHUNKING
SUPPORT_CHUNKING = support_chunking
def get_support_chunking() -> bool:
global SUPPORT_CHUNKING
return SUPPORT_CHUNKING
def set_max_prefill_tokens(max_prefill_tokens: int):
global MAX_PREFILL_TOKENS
MAX_PREFILL_TOKENS = max_prefill_tokens
def get_max_prefill_tokens() -> int:
global MAX_PREFILL_TOKENS
return MAX_PREFILL_TOKENS

View File

@ -7,6 +7,7 @@ from typing import List, Tuple, Union
import torch import torch
# FIXME: this should be optimized
def find_segments( def find_segments(
adapter_indices: Union[torch.Tensor, List[int]] adapter_indices: Union[torch.Tensor, List[int]]
) -> Tuple[List[int], List[int]]: ) -> Tuple[List[int], List[int]]: