From a6a0c97ed92b46592572f15b1cd954c789205447 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Wed, 16 Oct 2024 12:49:33 +0200 Subject: [PATCH] feat: prefill chunking (#2600) * wip * rollback * refactor to use prefix/postfix namming + fix all_input_ids_tensor * maybe patching vlms? * fix filter and concat * wip, no filter, no concat * current * add prepare_for_prefill * working * load tested * re-create slots * re-create slots * fix slot_filtering_indices * feedback loop * remove log * fix benchmarker * fix vlm and seq2seq * rename to cache and input lengths * fix prefill logprobs * fix launcher * fix logprobs? * idk at this point * max input length * omfg * remove debugging lines * fix tests * fix mllama * fix cargo tests * remove support chunking for paged * Fixing non blocked attentions * Fixing dtype + AMD, Ipex targets. * lint fix. * rename * Fix prefix_caching variable, remove defaults in server (confusing a lot of the times). * Add simple resolution when user specifies ATTENTION=paged. * Put back non default simple tests. * Fix env name --------- Co-authored-by: Nicolas Patry --- Dockerfile_amd | 3 +- Dockerfile_intel | 3 +- backends/client/src/v3/client.rs | 10 +- backends/client/src/v3/sharded_client.rs | 8 +- backends/v2/src/backend.rs | 20 +- backends/v3/src/backend.rs | 134 +- backends/v3/src/client/grpc_client.rs | 21 +- backends/v3/src/client/mod.rs | 9 - backends/v3/src/client/sharded_client.rs | 28 +- backends/v3/src/lib.rs | 16 +- backends/v3/src/main.rs | 28 +- backends/v3/src/queue.rs | 177 +- benchmark/src/generation.rs | 5 +- benchmark/src/main.rs | 1 + integration-tests/conftest.py | 28 +- .../models/test_flash_pali_gemma.py | 24 +- .../test_grammar_response_format_llama.py | 1 - integration-tests/models/test_idefics.py | 24 +- integration-tests/models/test_idefics2.py | 28 +- integration-tests/models/test_llava_next.py | 14 +- integration-tests/models/test_mllama.py | 15 - launcher/src/main.rs | 20 +- proto/v3/generate.proto | 18 +- router/src/lib.rs | 39 - server/tests/conftest.py | 2 +- server/text_generation_server/interceptor.py | 5 +- .../layers/attention/common.py | 106 +- .../layers/attention/cuda.py | 319 ++-- .../layers/attention/flash_attn_triton.py | 1 - .../layers/attention/ipex.py | 3 +- .../layers/attention/rocm.py | 2 +- .../models/causal_lm.py | 1 + .../models/custom_modeling/mllama.py | 11 +- .../models/flash_causal_lm.py | 1444 ++++++++++------- .../text_generation_server/models/globals.py | 11 +- .../models/idefics_causal_lm.py | 1 + server/text_generation_server/models/mamba.py | 1 + .../models/mllama_causal_lm.py | 65 +- server/text_generation_server/models/model.py | 37 + .../models/seq2seq_lm.py | 1 + server/text_generation_server/models/types.py | 8 + .../models/vlm_causal_lm.py | 53 +- server/text_generation_server/server.py | 26 +- .../text_generation_server/utils/adapter.py | 21 +- .../utils/prefill_chunking.py | 24 + .../text_generation_server/utils/segments.py | 1 + 46 files changed, 1694 insertions(+), 1123 deletions(-) create mode 100644 server/text_generation_server/utils/prefill_chunking.py diff --git a/Dockerfile_amd b/Dockerfile_amd index 4bb6407a..b84d4edd 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -327,7 +327,8 @@ ENV ROCM_USE_CUSTOM_PAGED_ATTN=1 ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0 ENV VLLM_MOE_PADDING=0 ENV ATTENTION=paged -ENV USE_PREFIX_CACHING=0 +ENV PREFIX_CACHING=0 +ENV PREFILL_CHUNKING=0 ENV ROCM_USE_SKINNY_GEMM=1 COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh diff --git a/Dockerfile_intel 
b/Dockerfile_intel index 9b5dd20a..96f24248 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -218,7 +218,8 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/lo FROM ${PLATFORM} AS final ENV ATTENTION=paged -ENV USE_PREFIX_CACHING=0 +ENV PREFIX_CACHING=0 +ENV PREFILL_CHUNKING=0 ENV CUDA_GRAPHS=0 ENTRYPOINT ["text-generation-launcher"] CMD ["--json-output"] diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs index 479d31bf..d43f789e 100644 --- a/backends/client/src/v3/client.rs +++ b/backends/client/src/v3/client.rs @@ -158,7 +158,8 @@ impl Client { // Blocks and slots will be set on the server side if we use paged attention blocks: vec![], slots: vec![], - prefix_len: 0, + cache_len: 0, + chunk_len: None, // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, @@ -217,8 +218,13 @@ impl Client { pub async fn prefill( &mut self, batch: Batch, + cached_batch: Option, ) -> Result<(Vec, Option, PrefillTimings)> { - let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context(); + let request = tonic::Request::new(PrefillRequest { + batch: Some(batch), + cached_batch, + }) + .inject_context(); let response = self.stub.prefill(request).await?.into_inner(); Ok(( response.generations, diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs index 645c076a..854a5895 100644 --- a/backends/client/src/v3/sharded_client.rs +++ b/backends/client/src/v3/sharded_client.rs @@ -134,11 +134,12 @@ impl ShardedClient { pub async fn prefill( &mut self, batch: Batch, + cached_batch: Option, ) -> Result<(Vec, Option, PrefillTimings)> { let futures: Vec<_> = self .clients .iter_mut() - .map(|client| Box::pin(client.prefill(batch.clone()))) + .map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone()))) .collect(); #[allow(clippy::type_complexity)] let results: Result, Option, PrefillTimings)>> = @@ -245,7 +246,8 @@ impl Health for ShardedClient { // Block 0 is reserved for health checks blocks: vec![0], slots: (0..16).collect(), - prefix_len: 0, + cache_len: 0, + chunk_len: None, adapter_id: None, }; let batch = Batch { @@ -255,7 +257,7 @@ impl Health for ShardedClient { max_tokens: 2, max_blocks: 1, }; - self.clone().prefill(batch).await?; + self.clone().prefill(batch, None).await?; Ok(()) } } diff --git a/backends/v2/src/backend.rs b/backends/v2/src/backend.rs index 086fc6dc..bc264138 100644 --- a/backends/v2/src/backend.rs +++ b/backends/v2/src/backend.rs @@ -6,7 +6,7 @@ use nohash_hasher::IntMap; use std::sync::Arc; use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; use text_generation_router::validation::ValidGenerateRequest; -use text_generation_router::{Attention, FinishReason, PrefillToken, Token}; +use text_generation_router::{FinishReason, PrefillToken, Token}; use tokio::sync::mpsc::error::SendError; use tokio::sync::{mpsc, Notify}; use tokio::time::Instant; @@ -36,18 +36,14 @@ impl BackendV2 { speculate: u32, ) -> Self { // Infer shared state - let attention = if let Ok(attention) = std::env::var("ATTENTION") { - attention - .parse() - .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")) - } else { - Attention::Paged - }; - let block_size = if attention == Attention::FlashDecoding { - 256 - } else { - 16 + let attention = std::env::var("ATTENTION").unwrap_or("paged".to_string()); + let block_size = match 
attention.as_str() { + "flashinfer" => 1, + "flashdecoding" => 256, + "paged" => 16, + _ => unreachable!(), }; + let queue = Queue::new(requires_padding, block_size, window_size, speculate); let batching_task_notifier = Arc::new(Notify::new()); diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index f8a10ca2..a5c0f512 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -1,12 +1,14 @@ -use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient}; /// Batching and inference logic +use crate::client::{ + Batch, CachedBatch, ClientError, Generation, Health, InfoResponse, ShardedClient, +}; use crate::queue::{Entry, Queue}; use async_trait::async_trait; use nohash_hasher::IntMap; use std::sync::Arc; use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; use text_generation_router::validation::ValidGenerateRequest; -use text_generation_router::{Attention, FinishReason, PrefillToken, Token}; +use text_generation_router::{FinishReason, PrefillToken, Token}; use tokio::sync::mpsc::error::SendError; use tokio::sync::{mpsc, Notify}; use tokio::time::Instant; @@ -31,27 +33,22 @@ impl BackendV3 { max_batch_total_tokens: u32, max_waiting_tokens: usize, max_batch_size: Option, - requires_padding: bool, - window_size: Option, - speculate: u32, + shard_info: InfoResponse, ) -> Self { - let prefix_caching = - std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var"); - let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1"); - let attention: String = std::env::var("ATTENTION").expect("attention env var"); + if shard_info.support_chunking { + tracing::warn!("Model supports prefill chunking. `waiting_served_ratio` and `max_waiting_tokens` will be ignored."); + } - let attention: Attention = attention - .parse() - .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")); - let block_size = attention.block_size(); + let block_size = shard_info.block_size; let queue = Queue::new( - requires_padding, + shard_info.requires_padding, block_size, - prefix_caching, - window_size, - speculate, + shard_info.use_prefix_caching, + shard_info.window_size, + shard_info.speculate, max_batch_total_tokens, + shard_info.support_chunking, ); let batching_task_notifier = Arc::new(Notify::new()); @@ -63,6 +60,7 @@ impl BackendV3 { max_batch_total_tokens, max_waiting_tokens, max_batch_size, + shard_info.support_chunking, queue.clone(), batching_task_notifier.clone(), )); @@ -127,6 +125,7 @@ pub(crate) async fn batching_task( max_batch_total_tokens: u32, max_waiting_tokens: usize, max_batch_size: Option, + support_chunking: bool, queue: Queue, notifier: Arc, ) { @@ -147,7 +146,7 @@ pub(crate) async fn batching_task( ) .await { - let mut cached_batch = prefill(&mut client, batch, &mut entries) + let mut cached_batch = prefill(&mut client, batch, None, &mut entries) .instrument(span) .await; let mut waiting_tokens = 1; @@ -158,28 +157,44 @@ pub(crate) async fn batching_task( // Get current batch info let batch_size = batch.size; let batch_max_tokens = batch.max_tokens; + let current_tokens = batch.current_tokens; let mut batches = vec![batch]; metrics::gauge!("tgi_batch_current_size").set(batch_size as f64); metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64); - let min_size = if waiting_tokens >= max_waiting_tokens { - // If we didn't onboard any new requests since >= max_waiting_tokens, we try - // to add a new batch even though its size might be 
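The `block_size` selection above is a plain lookup on the ATTENTION value, the same mapping the `Attention::block_size()` helper deleted from router/src/lib.rs further down used to encode (flashinfer → 1, flashdecoding → 256, paged → 16). A one-line restatement, illustrative only:

    BLOCK_SIZE = {"flashinfer": 1, "flashdecoding": 256, "paged": 16}

    def block_size(attention: str) -> int:
        # Any other value is a configuration error, like the `unreachable!()` arm above.
        return BLOCK_SIZE[attention]
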
small - None + let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); + + let (min_size, max_size, prefill_token_budget) = if support_chunking { + // Since the next batch will be concatenated with the current batch, + // the current batch tokens must be subtracted to the prefill budget + let prefill_token_budget = + max_batch_prefill_tokens.saturating_sub(current_tokens); + // We can ignore min_size and max_size + // Models than rely on max_size cannot support chunking + // Regarding min_size, chunking allow us to consistently run at the compute + // bound, making min_size useless. + (None, None, prefill_token_budget) } else { - // Minimum batch size - // TODO: temporarily disable to avoid incorrect deallocation + - // reallocation when using prefix caching. - Some((batch_size as f32 * waiting_served_ratio).floor() as usize) + let min_size = if waiting_tokens >= max_waiting_tokens { + // If we didn't onboard any new requests since >= max_waiting_tokens, we try + // to add a new batch even though its size might be small + None + } else { + // Minimum batch size + // TODO: temporarily disable to avoid incorrect deallocation + + // reallocation when using prefix caching. + Some((batch_size as f32 * waiting_served_ratio).floor() as usize) + }; + + let max_size = + max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize)); + + (min_size, max_size, max_batch_prefill_tokens) }; - let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); - let max_size = - max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize)); - // Try to get a new batch - if let Some((mut new_entries, new_batch, span)) = queue - .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget) + if let Some((new_entries, new_batch, span)) = queue + .next_batch(min_size, max_size, prefill_token_budget, token_budget) .await { // Tracking metrics @@ -187,31 +202,45 @@ pub(crate) async fn batching_task( metrics::counter!("tgi_batch_concat", "reason" => "backpressure") .increment(1); } else { - metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") - .increment(1); + let counter = if support_chunking { + metrics::counter!("tgi_batch_concat", "reason" => "chunking") + } else { + metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") + }; + counter.increment(1); } - - entries.iter_mut().for_each(|(_, entry)| { - // Create a new span to add the info that this entry is waiting - // because a new batch is being computed - let entry_waiting_span = info_span!(parent: &entry.span, "waiting"); - // Add relationships - span.follows_from(&entry_waiting_span); - entry_waiting_span.follows_from(&span); - // Update entry - entry.temp_span = Some(entry_waiting_span); - }); + let cached_batch = if support_chunking { + // Concat current batch to the new one + batches.pop() + } else { + // Request are waiting only if we don't support chunking + entries.iter_mut().for_each(|(_, entry)| { + // Create a new span to add the info that this entry is waiting + // because a new batch is being computed + let entry_waiting_span = info_span!(parent: &entry.span, "waiting"); + // Add relationships + span.follows_from(&entry_waiting_span); + entry_waiting_span.follows_from(&span); + // Update entry + entry.temp_span = Some(entry_waiting_span); + }); + None + }; + entries.extend(new_entries); // Generate one token for this new batch to have the attention past in cache - let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries) - 
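For the chunking branch above: because the new batch will be concatenated with the batch already running, the tokens scheduled for the current forward pass are subtracted from the prefill budget. A small sketch of that arithmetic with made-up numbers (the helper name is invented, not part of the patch):

    def prefill_budget(max_batch_prefill_tokens: int, current_tokens: int, support_chunking: bool) -> int:
        # Without chunking, new requests get the full prefill budget.
        if not support_chunking:
            return max_batch_prefill_tokens
        # With chunking, tokens already scheduled in the running batch eat into it,
        # saturating at zero like Rust's `saturating_sub`.
        return max(max_batch_prefill_tokens - current_tokens, 0)

    assert prefill_budget(4096, 1024, True) == 3072
    assert prefill_budget(4096, 8192, True) == 0
    assert prefill_budget(4096, 1024, False) == 4096
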
.instrument(span) - .await; + let new_cached_batch = + prefill(&mut client, new_batch, cached_batch, &mut entries) + .instrument(span) + .await; // Reset waiting counter waiting_tokens = 1; // Extend current batch with the new batch if let Some(new_cached_batch) = new_cached_batch { - entries.extend(new_entries); batches.push(new_cached_batch); + } else if support_chunking { + // New cached batch is empty, no work left + break; } } @@ -244,13 +273,14 @@ pub(crate) async fn batching_task( async fn prefill( client: &mut ShardedClient, batch: Batch, + cached_batch: Option, entries: &mut IntMap, ) -> Option { let start_time = Instant::now(); let batch_id = batch.id; metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1); - match client.prefill(batch).await { + match client.prefill(batch, cached_batch).await { Ok((generations, next_batch, timings)) => { let start_filtering_time = Instant::now(); // Send generated tokens and filter stopped entries @@ -259,6 +289,10 @@ async fn prefill( // Filter next batch and remove requests that were stopped let next_batch = filter_batch(client, next_batch, entries).await; + if let Some(concat_duration) = timings.concat { + metrics::histogram!("tgi_batch_concat_duration", "method" => "decode") + .record(concat_duration.as_secs_f64()); + } metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill") .record(timings.forward.as_secs_f64()); metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill") diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs index 648662db..fe810f24 100644 --- a/backends/v3/src/client/grpc_client.rs +++ b/backends/v3/src/client/grpc_client.rs @@ -158,7 +158,8 @@ impl Client { // Blocks and slots will be set on the server side if we use paged attention blocks: vec![], slots: vec![], - prefix_len: 0, + cache_len: 0, + chunk_len: None, // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, @@ -217,13 +218,23 @@ impl Client { pub async fn prefill( &mut self, batch: Batch, + cached_batch: Option, ) -> Result<(Vec, Option, PrefillTimings)> { - let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context(); + let request = tonic::Request::new(PrefillRequest { + batch: Some(batch), + cached_batch, + }) + .inject_context(); let response = self.stub.prefill(request).await?.into_inner(); Ok(( response.generations, response.batch, - PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns), + PrefillTimings::new( + response.concat_ns, + response.forward_ns, + response.decode_ns, + response.total_ns, + ), )) } @@ -252,14 +263,16 @@ impl Client { } pub struct PrefillTimings { + pub concat: Option, pub forward: Duration, pub decode: Duration, pub total: Duration, } impl PrefillTimings { - fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { + fn new(concat_ns: Option, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { Self { + concat: concat_ns.map(Duration::from_nanos), forward: Duration::from_nanos(forward_ns), decode: Duration::from_nanos(decode_ns), total: Duration::from_nanos(total_ns), diff --git a/backends/v3/src/client/mod.rs b/backends/v3/src/client/mod.rs index 755431f4..d4ac50c9 100644 --- a/backends/v3/src/client/mod.rs +++ b/backends/v3/src/client/mod.rs @@ -29,15 +29,6 @@ pub trait Health { async fn model_health(&self) -> Result<()>; } -#[derive(Debug)] -pub struct ShardInfo { - pub 
requires_padding: bool, - pub dtype: String, - pub device_type: String, - pub window_size: Option, - pub speculate: u32, -} - #[derive(Error, Debug, Clone)] pub enum ClientError { #[error("Could not connect to Text Generation server: {0}")] diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs index ea77a696..e181cd28 100644 --- a/backends/v3/src/client/sharded_client.rs +++ b/backends/v3/src/client/sharded_client.rs @@ -1,6 +1,6 @@ -use crate::client::{ClientError, Result}; +use crate::client::Health; /// Multi shard Client -use crate::client::{Health, ShardInfo}; +use crate::client::{ClientError, Result}; use crate::client::grpc_client::{DecodeTimings, PrefillTimings}; use crate::client::{ @@ -49,13 +49,13 @@ impl ShardedClient { /// Get the model info #[instrument(skip(self))] - pub async fn info(&mut self) -> Result { + pub async fn info(&mut self) -> Result { let futures: Vec<_> = self .clients .iter_mut() .map(|client| client.info()) .collect(); - join_all(futures).await.pop().unwrap().map(ShardInfo::from) + join_all(futures).await.pop().unwrap() } /// GRPC health check @@ -135,11 +135,12 @@ impl ShardedClient { pub async fn prefill( &mut self, batch: Batch, + cached_batch: Option, ) -> Result<(Vec, Option, PrefillTimings)> { let futures: Vec<_> = self .clients .iter_mut() - .map(|client| Box::pin(client.prefill(batch.clone()))) + .map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone()))) .collect(); #[allow(clippy::type_complexity)] let results: Result, Option, PrefillTimings)>> = @@ -194,18 +195,6 @@ impl ShardedClient { } } -impl From for ShardInfo { - fn from(value: InfoResponse) -> Self { - Self { - requires_padding: value.requires_padding, - dtype: value.dtype, - device_type: value.device_type, - window_size: value.window_size, - speculate: value.speculate, - } - } -} - #[async_trait] impl Health for ShardedClient { async fn device_health(&self) -> Result<()> { @@ -246,8 +235,9 @@ impl Health for ShardedClient { // Block 0 is reserved for health checks blocks: vec![0], slots: (0..16).collect(), - prefix_len: 0, + cache_len: 0, adapter_id: None, + chunk_len: None, }; let batch = Batch { id: u64::MAX, @@ -256,7 +246,7 @@ impl Health for ShardedClient { max_tokens: 2, max_blocks: 1, }; - self.clone().prefill(batch).await?; + self.clone().prefill(batch, None).await?; Ok(()) } } diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs index af66b21e..7daf9eae 100644 --- a/backends/v3/src/lib.rs +++ b/backends/v3/src/lib.rs @@ -29,6 +29,14 @@ pub struct BackendInfo { pub max_waiting_tokens: usize, #[schema(nullable = true, example = "null")] pub max_batch_size: Option, + #[schema(example = "false")] + pub support_chunking: bool, + #[schema(example = "false")] + pub prefix_caching: bool, + #[schema(example = "flashinfer")] + pub attention_impl: String, + #[schema(example = "1")] + pub block_size: u32, } #[allow(clippy::too_many_arguments)] @@ -110,6 +118,10 @@ pub async fn connect_backend( model_device_type: shard_info.device_type.clone(), model_dtype: shard_info.dtype.clone(), speculate: shard_info.speculate as usize, + support_chunking: shard_info.support_chunking, + prefix_caching: shard_info.use_prefix_caching, + attention_impl: shard_info.attention_impl.clone(), + block_size: shard_info.block_size, }; let backend = BackendV3::new( @@ -119,9 +131,7 @@ pub async fn connect_backend( max_batch_total_tokens, max_waiting_tokens, max_batch_size, - shard_info.requires_padding, - shard_info.window_size, - 
shard_info.speculate, + shard_info, ); tracing::info!("Using backend V3"); diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs index 471ddb5a..b4751bd5 100644 --- a/backends/v3/src/main.rs +++ b/backends/v3/src/main.rs @@ -131,25 +131,12 @@ async fn main() -> Result<(), RouterError> { "`max_input_tokens` must be < `max_total_tokens`".to_string(), )); } - if max_input_tokens as u32 > max_batch_prefill_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}"))); - } if validation_workers == 0 { return Err(RouterError::ArgumentValidation( "`validation_workers` must be > 0".to_string(), )); } - - if let Some(ref max_batch_total_tokens) = max_batch_total_tokens { - if max_batch_prefill_tokens > *max_batch_total_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); - } - if max_total_tokens as u32 > *max_batch_total_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"))); - } - } - if let Some(max_batch_size) = max_batch_size { if max_batch_size == 0 { return Err(RouterError::ArgumentValidation( @@ -158,7 +145,7 @@ async fn main() -> Result<(), RouterError> { } } - let (backend, _backend_info) = connect_backend( + let (backend, backend_info) = connect_backend( max_input_tokens, max_total_tokens, master_shard_uds_path, @@ -170,6 +157,19 @@ async fn main() -> Result<(), RouterError> { ) .await?; + // Validate remaining args now that the backend is known + let support_chunking = backend_info.support_chunking; + let max_batch_total_tokens = backend_info.max_batch_total_tokens; + if max_input_tokens as u32 > max_batch_prefill_tokens && !support_chunking { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}"))); + } + if max_batch_prefill_tokens > max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); + } + if max_total_tokens as u32 > max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. 
Given: {max_total_tokens} and {max_batch_total_tokens}"))); + } + // Run server server::run( backend, diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index f8123b57..6662b8de 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -4,7 +4,7 @@ use crate::client::{ Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; use nohash_hasher::{BuildNoHashHasher, IntMap}; -use std::cmp::{max, min}; +use std::cmp::max; use std::collections::VecDeque; use text_generation_router::infer::InferError; use text_generation_router::infer::InferStreamResponse; @@ -50,6 +50,7 @@ impl Queue { window_size: Option, speculate: u32, max_batch_total_tokens: u32, + support_chunking: bool, ) -> Self { // Create channel let (queue_sender, queue_receiver) = mpsc::unbounded_channel(); @@ -62,6 +63,7 @@ impl Queue { window_size, speculate, max_batch_total_tokens, + support_chunking, queue_receiver, )); @@ -87,6 +89,10 @@ impl Queue { prefill_token_budget: u32, token_budget: u32, ) -> Option { + if prefill_token_budget == 0 || token_budget == 0 { + return None; + }; + // Create response channel let (response_sender, response_receiver) = oneshot::channel(); // Send next batch command to the background task managing the state @@ -108,6 +114,7 @@ impl Queue { } // Background task responsible of the queue state +#[allow(clippy::too_many_arguments)] async fn queue_task( requires_padding: bool, block_size: u32, @@ -115,6 +122,7 @@ async fn queue_task( window_size: Option, speculate: u32, max_batch_total_tokens: u32, + support_chunking: bool, mut receiver: mpsc::UnboundedReceiver, ) { let mut state = State::new( @@ -124,6 +132,7 @@ async fn queue_task( window_size, speculate, max_batch_total_tokens, + support_chunking, ); while let Some(cmd) = receiver.recv().await { @@ -166,12 +175,14 @@ struct State { /// Paged Attention block size block_size: u32, - /// Sliding window - window_size: Option, - /// Speculation amount speculate: u32, + /// Whether the model allow the prefill chunking + /// If it does, the last request in the batch will be split to exactly match the prefill + /// token budget + support_chunking: bool, + /// Paged Attention Block Allocation block_allocator: Option, } @@ -184,6 +195,7 @@ impl State { window_size: Option, speculate: u32, max_batch_total_tokens: u32, + support_chunking: bool, ) -> Self { let block_allocator = (!requires_padding).then(|| { BlockAllocator::new( @@ -199,8 +211,8 @@ impl State { next_id: 0, next_batch_id: 0, block_size, - window_size, speculate, + support_chunking, block_allocator, } } @@ -287,32 +299,7 @@ impl State { } None } - Some(_block_allocator) => { - prefill_tokens += entry.request.input_length; - let max_new_tokens = match self.window_size { - None => entry.request.stopping_parameters.max_new_tokens, - Some(window_size) => min( - window_size.saturating_sub(entry.request.input_length), - entry.request.stopping_parameters.max_new_tokens, - ), - }; - decode_tokens += max_new_tokens; - - if prefill_tokens > prefill_token_budget - || (prefill_tokens + decode_tokens + self.speculate) > token_budget - { - // Entry is over budget - // Add it back to the front - tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate); - self.entries.push_front((id, entry)); - break; - } - - let tokens = entry.request.input_length - + entry.request.stopping_parameters.max_new_tokens - + self.speculate - - 1; - + 
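The deferred validation above amounts to: the `max_batch_prefill_tokens >= max_input_tokens` constraint only applies when the backend cannot chunk prefills, while the two `max_batch_total_tokens` bounds always apply. A condensed restatement, illustrative only (messages abbreviated from the code above):

    def validate_args(max_input_tokens, max_total_tokens, max_batch_prefill_tokens,
                      max_batch_total_tokens, support_chunking):
        errors = []
        if max_input_tokens > max_batch_prefill_tokens and not support_chunking:
            errors.append("`max_batch_prefill_tokens` must be >= `max_input_tokens`")
        if max_batch_prefill_tokens > max_batch_total_tokens:
            errors.append("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`")
        if max_total_tokens > max_batch_total_tokens:
            errors.append("`max_total_tokens` must be <= `max_batch_total_tokens`")
        return errors
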
Some(block_allocator) => { // If users wants the prefill logprobs, we cannot reuse the cache. // So no input_ids for the radix tree. let input_ids = if entry.request.decoder_input_details { @@ -321,10 +308,73 @@ impl State { entry.request.input_ids.clone() }; - Some((tokens, input_ids)) + let tokens = entry.request.input_length + + entry.request.stopping_parameters.max_new_tokens + + self.speculate + - 1; + tracing::debug!("Allocating {tokens} with {input_ids:?}"); + + let block_allocation = match block_allocator.allocate(tokens, input_ids).await { + None => { + // Entry is over budget + // Add it back to the front + tracing::debug!("Over budget: not enough free blocks"); + self.entries.push_front((id, entry)); + break 'entry_loop; + } + Some(mut block_allocation) => { + tracing::debug!("Allocation: {block_allocation:?}"); + max_blocks = max(max_blocks, block_allocation.blocks.len() as u32); + + if block_allocation.prefix_len == entry.request.input_length { + // The whole request was found in the radix trie + // However, for the transformer forward to work, we need to + // have at least one token of postfix. + block_allocation.prefix_len -= 1; + } + + block_allocation + } + }; + + let postfix_len = entry.request.input_length - block_allocation.prefix_len; + + if prefill_tokens + postfix_len > prefill_token_budget { + // Entry is over budget + if self.support_chunking { + // We support chunking, just set postfix_len to exactly match prefill_token_budget + let chunk_len = prefill_token_budget.saturating_sub(prefill_tokens); + if chunk_len > 0 { + // Push this entry inside the batch + batch.push((id, entry, Some(block_allocation), Some(chunk_len))); + } else { + // We cannot prefill even one token for this entry + // Add it back to the queue + self.entries.push_front((id, entry)); + } + tracing::debug!( + "Matched budget: prefill_tokens={} == {prefill_token_budget}", + prefill_tokens + postfix_len + ); + break 'entry_loop; + } else { + // We don't support chunking, this entry needs to go back to the buffer + // Add it back to the front + tracing::debug!( + "Over budget: prefill_tokens={} > {prefill_token_budget}", + prefill_tokens + postfix_len + ); + self.entries.push_front((id, entry)); + break 'entry_loop; + } + } + + prefill_tokens += postfix_len; + + Some(block_allocation) } }; - batch.push((id, entry, block_allocation)); + batch.push((id, entry, block_allocation, None)); if Some(batch.len()) == max_size { break; } @@ -342,7 +392,7 @@ impl State { // Batch is too small if batch.len() < min_size { // Add back entries to the queue in the correct order - for (id, entry, _) in batch.into_iter().rev() { + for (id, entry, _, _) in batch.into_iter().rev() { self.entries.push_front((id, entry)); } return None; @@ -353,29 +403,7 @@ impl State { let mut batch_entries = IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default()); - for (id, mut entry, block_allocation) in batch { - let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) = - (block_allocation, &self.block_allocator) - { - tracing::debug!("Allocating {tokens} with {input_ids:?}"); - match block_allocator.allocate(tokens, input_ids).await { - None => { - // Entry is over budget - // Add it back to the front - tracing::debug!("Over budget: not enough free blocks"); - self.entries.push_front((id, entry)); - continue; - } - Some(block_allocation) => { - tracing::debug!("Allocation: {block_allocation:?}"); - max_blocks = max(max_blocks, block_allocation.blocks.len() as u32); - 
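The queue change above is the heart of prefill chunking: once the prefix-cached part of the prompt is discounted, the remaining `postfix_len` either fits the prefill budget, gets truncated to a `chunk_len` that lands exactly on the budget, or sends the entry back to the queue. A rough Python transcription of that decision (the names are illustrative, not the server's API):

    def plan_entry(prefill_tokens, postfix_len, prefill_token_budget, support_chunking):
        """Return (chunk_len, accepted, stop_filling) for one queue entry. Illustrative only."""
        if prefill_tokens + postfix_len <= prefill_token_budget:
            # The whole remaining prompt fits; no chunking needed, keep filling the batch.
            return None, True, False
        if not support_chunking:
            # Over budget and chunking unsupported: requeue the entry and stop.
            return None, False, True
        # Chunk so the batch matches the prefill budget exactly, then stop:
        # the budget is exhausted either way.
        chunk_len = max(prefill_token_budget - prefill_tokens, 0)
        return (chunk_len, True, True) if chunk_len > 0 else (None, False, True)
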
Some(block_allocation) - } - } - } else { - None - }; - tracing::debug!("Accepting entry"); + for (id, mut entry, block_allocation, chunk_len) in batch { // Create a new span to link the batch back to this entry let entry_batch_span = info_span!(parent: &entry.span, "infer"); // Add relationships @@ -427,8 +455,9 @@ impl State { top_n_tokens: entry.request.top_n_tokens, blocks, slots, - prefix_len, + cache_len: prefix_len, adapter_id: entry.request.adapter_id.clone(), + chunk_len, }); // Set batch_time entry.batch_time = Some(Instant::now()); @@ -436,12 +465,6 @@ impl State { batch_entries.insert(id, entry); } - // Empty batch - if batch_requests.is_empty() { - tracing::debug!("Filterered out all entries"); - return None; - } - // Final batch size let size = batch_requests.len() as u32; next_batch_span.record("batch_size", size); @@ -531,7 +554,7 @@ mod tests { request: ValidGenerateRequest { inputs: vec![], input_ids: Some(Arc::new(vec![])), - input_length: 0, + input_length: 1, add_special_tokens: true, truncate: 0, decoder_input_details: false, @@ -567,7 +590,7 @@ mod tests { #[tokio::test] async fn test_append() { - let mut state = State::new(false, 1, false, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16, false); let (entry, _guard) = default_entry(); assert_eq!(state.next_id, 0); @@ -583,7 +606,7 @@ mod tests { #[tokio::test] async fn test_next_batch_empty() { - let mut state = State::new(false, 1, false, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16, false); assert!(state.next_batch(None, None, 1, 1).await.is_none()); assert!(state.next_batch(Some(1), None, 1, 1).await.is_none()); @@ -591,7 +614,7 @@ mod tests { #[tokio::test] async fn test_next_batch_min_size() { - let mut state = State::new(false, 1, false, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -623,7 +646,7 @@ mod tests { #[tokio::test] async fn test_next_batch_max_size() { - let mut state = State::new(false, 1, false, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -643,7 +666,7 @@ mod tests { #[tokio::test] async fn test_next_batch_token_budget() { - let mut state = State::new(false, 1, false, None, 0, 2); + let mut state = State::new(false, 1, false, None, 0, 16, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -676,14 +699,14 @@ mod tests { #[tokio::test] async fn test_queue_append() { - let queue = Queue::new(false, 1, false, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16, false); let (entry, _guard) = default_entry(); queue.append(entry); } #[tokio::test] async fn test_queue_next_batch_empty() { - let queue = Queue::new(false, 1, false, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16, false); assert!(queue.next_batch(None, None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none()); @@ -691,7 +714,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_min_size() { - let queue = Queue::new(false, 1, false, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -724,7 +747,7 @@ mod tests { #[tokio::test] async fn 
test_queue_next_batch_max_size() { - let queue = Queue::new(false, 1, false, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -740,7 +763,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_budget() { - let queue = Queue::new(false, 1, false, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -765,7 +788,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_speculate() { - let queue = Queue::new(false, 1, false, None, 2, 16); + let queue = Queue::new(true, 1, false, None, 2, 16, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -784,7 +807,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_dropped_receiver() { - let queue = Queue::new(false, 1, false, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16, false); let (entry, _) = default_entry(); queue.append(entry); diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index 789c7b51..63fc7808 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -158,7 +158,8 @@ async fn prefill( top_n_tokens: top_n_tokens.unwrap_or(0), blocks: vec![], slots: vec![], - prefix_len: 0, + cache_len: 0, + chunk_len: None, adapter_id: None, }) .collect(); @@ -173,7 +174,7 @@ async fn prefill( // Run prefill let start_time = Instant::now(); - let (_, decode_batch, _) = client.prefill(batch.clone()).await?; + let (_, decode_batch, _) = client.prefill(batch.clone(), None).await?; // Get latency let latency = start_time.elapsed(); diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs index 2ee3d7c5..2e2e9a11 100644 --- a/benchmark/src/main.rs +++ b/benchmark/src/main.rs @@ -178,6 +178,7 @@ fn main() -> Result<(), Box> { .clear_cache(None) .await .expect("Unable to clear cache"); + tracing::info!("Connected"); // Run app diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index f24fc079..356fa5e3 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -9,13 +9,16 @@ import subprocess import sys import tempfile import time -from typing import Dict, List, Optional - import docker import pytest +import base64 + +from pathlib import Path +from typing import Dict, List, Optional from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError from docker.errors import NotFound from syrupy.extensions.json import JSONSnapshotExtension + from text_generation import AsyncClient from text_generation.types import ( BestOfSequence, @@ -403,6 +406,7 @@ def launcher(event_loop): print(" ".join(args), file=sys.stderr) env["LOG_LEVEL"] = "info,text_generation_router=debug" + env["PREFILL_CHUNKING"] = "1" if not use_flash_attention: env["USE_FLASH_ATTENTION"] = "false" @@ -501,6 +505,7 @@ def launcher(event_loop): env = { "LOG_LEVEL": "info,text_generation_router=debug", + "PREFILL_CHUNKING": "1", } if not use_flash_attention: env["USE_FLASH_ATTENTION"] = "false" @@ -642,3 +647,22 @@ def generate_multi(): return responses return generate_load_inner + + +# TODO fix the server parsser to count inline image tokens correctly +@pytest.fixture +def chicken(): + path = Path(__file__).parent / "images" / "chicken_on_money.png" + + with open(path, "rb") as image_file: + encoded_string = 
base64.b64encode(image_file.read()) + return f"data:image/png;base64,{encoded_string.decode('utf-8')}" + + +@pytest.fixture +def cow_beach(): + path = Path(__file__).parent / "images" / "cow_beach.png" + + with open(path, "rb") as image_file: + encoded_string = base64.b64encode(image_file.read()) + return f"data:image/png;base64,{encoded_string.decode('utf-8')}" diff --git a/integration-tests/models/test_flash_pali_gemma.py b/integration-tests/models/test_flash_pali_gemma.py index 52ecaed4..93962eb3 100644 --- a/integration-tests/models/test_flash_pali_gemma.py +++ b/integration-tests/models/test_flash_pali_gemma.py @@ -1,5 +1,4 @@ import pytest -import base64 @pytest.fixture(scope="module") @@ -20,24 +19,11 @@ async def flash_pali_gemma(flash_pali_gemma_handle): return flash_pali_gemma_handle.client -def get_chicken(): - with open("integration-tests/images/chicken_on_money.png", "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()) - return f"data:image/png;base64,{encoded_string.decode('utf-8')}" - - -def get_cow_beach(): - with open("integration-tests/images/cow_beach.png", "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()) - return f"data:image/png;base64,{encoded_string.decode('utf-8')}" - - @pytest.mark.release @pytest.mark.asyncio @pytest.mark.private -async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot): - cow = get_cow_beach() - inputs = f"![]({cow})Where is the cow standing?\n" +async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot, cow_beach): + inputs = f"![]({cow_beach})Where is the cow standing?\n" response = await flash_pali_gemma.generate(inputs, max_new_tokens=20) assert response.generated_text == "beach" @@ -47,9 +33,9 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot): @pytest.mark.release @pytest.mark.asyncio @pytest.mark.private -async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot): - chicken = get_chicken() - cow_beach = get_cow_beach() +async def test_flash_pali_gemma_two_images( + flash_pali_gemma, response_snapshot, chicken, cow_beach +): response = await flash_pali_gemma.generate( f"caption![]({chicken})![]({cow_beach})\n", max_new_tokens=20, diff --git a/integration-tests/models/test_grammar_response_format_llama.py b/integration-tests/models/test_grammar_response_format_llama.py index 25bf9d98..eb3268ce 100644 --- a/integration-tests/models/test_grammar_response_format_llama.py +++ b/integration-tests/models/test_grammar_response_format_llama.py @@ -25,7 +25,6 @@ async def llama_grammar(llama_grammar_handle): @pytest.mark.release @pytest.mark.asyncio async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot): - class Weather(BaseModel): unit: str temperature: List[int] diff --git a/integration-tests/models/test_idefics.py b/integration-tests/models/test_idefics.py index eb573385..e5d08bb7 100644 --- a/integration-tests/models/test_idefics.py +++ b/integration-tests/models/test_idefics.py @@ -1,5 +1,4 @@ import pytest -import base64 @pytest.fixture(scope="module") @@ -16,22 +15,8 @@ async def idefics(idefics_handle): return idefics_handle.client -# TODO fix the server parsser to count inline image tokens correctly -def get_chicken(): - with open("integration-tests/images/chicken_on_money.png", "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()) - return f"data:image/png;base64,{encoded_string.decode('utf-8')}" - - -def get_cow_beach(): - with open("integration-tests/images/cow_beach.png", 
"rb") as image_file: - encoded_string = base64.b64encode(image_file.read()) - return f"data:image/png;base64,{encoded_string.decode('utf-8')}" - - @pytest.mark.asyncio -async def test_idefics(idefics, response_snapshot): - chicken = get_chicken() +async def test_idefics(idefics, response_snapshot, chicken): response = await idefics.generate( f"User:![]({chicken})Can you tell me a very short story based on the image?", max_new_tokens=10, @@ -48,9 +33,7 @@ async def test_idefics(idefics, response_snapshot): @pytest.mark.release @pytest.mark.asyncio @pytest.mark.private -async def test_idefics_two_images(idefics, response_snapshot): - chicken = get_chicken() - cow_beach = get_cow_beach() +async def test_idefics_two_images(idefics, response_snapshot, chicken, cow_beach): response = await idefics.generate( f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken? \nAssistant:", max_new_tokens=20, @@ -63,8 +46,7 @@ async def test_idefics_two_images(idefics, response_snapshot): @pytest.mark.release @pytest.mark.asyncio -async def test_idefics_load(idefics, generate_load, response_snapshot): - chicken = get_chicken() +async def test_idefics_load(idefics, generate_load, response_snapshot, chicken): responses = await generate_load( idefics, f"User:![]({chicken})Can you tell me a very short story based on the image?", diff --git a/integration-tests/models/test_idefics2.py b/integration-tests/models/test_idefics2.py index c5f48da3..881e37f9 100644 --- a/integration-tests/models/test_idefics2.py +++ b/integration-tests/models/test_idefics2.py @@ -1,18 +1,4 @@ import pytest -import base64 - - -# TODO fix the server parsser to count inline image tokens correctly -def get_chicken(): - with open("integration-tests/images/chicken_on_money.png", "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()) - return f"data:image/png;base64,{encoded_string.decode('utf-8')}" - - -def get_cow_beach(): - with open("integration-tests/images/cow_beach.png", "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()) - return f"data:image/png;base64,{encoded_string.decode('utf-8')}" @pytest.fixture(scope="module") @@ -31,8 +17,9 @@ async def flash_idefics2_next(flash_idefics2_next_handle): @pytest.mark.asyncio @pytest.mark.private -async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot): - chicken = get_chicken() +async def test_flash_idefics2_next_simple( + flash_idefics2_next, response_snapshot, chicken +): response = await flash_idefics2_next.generate( f"User:![]({chicken})Write me a short story \nAssistant:", max_new_tokens=10, @@ -46,9 +33,9 @@ async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot @pytest.mark.asyncio @pytest.mark.private -async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot): - chicken = get_chicken() - cow_beach = get_cow_beach() +async def test_flash_idefics2_two_images( + flash_idefics2_next, response_snapshot, chicken, cow_beach +): response = await flash_idefics2_next.generate( f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken? 
\nAssistant:", max_new_tokens=20, @@ -87,9 +74,8 @@ async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snap @pytest.mark.asyncio @pytest.mark.private async def test_flash_idefics2_next_load( - flash_idefics2_next, generate_load, response_snapshot + flash_idefics2_next, generate_load, response_snapshot, chicken ): - chicken = get_chicken() responses = await generate_load( flash_idefics2_next, f"User:![]({chicken})Write me a short story \nAssistant:", diff --git a/integration-tests/models/test_llava_next.py b/integration-tests/models/test_llava_next.py index ea277d71..1ac8f172 100644 --- a/integration-tests/models/test_llava_next.py +++ b/integration-tests/models/test_llava_next.py @@ -1,12 +1,4 @@ import pytest -import base64 - - -# TODO fix the server parsser to count inline image tokens correctly -def get_chicken(): - with open("integration-tests/images/chicken_on_money.png", "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()) - return f"data:image/png;base64,{encoded_string.decode('utf-8')}" @pytest.fixture(scope="module") @@ -29,8 +21,7 @@ async def flash_llava_next(flash_llava_next_handle): @pytest.mark.release @pytest.mark.asyncio @pytest.mark.private -async def test_flash_llava_next_simple(flash_llava_next, response_snapshot): - chicken = get_chicken() +async def test_flash_llava_next_simple(flash_llava_next, response_snapshot, chicken): response = await flash_llava_next.generate( f"User:![]({chicken})Can you tell me a very short story based on the image?", max_new_tokens=10, @@ -70,9 +61,8 @@ async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot): @pytest.mark.asyncio @pytest.mark.private async def test_flash_llava_next_load( - flash_llava_next, generate_load, response_snapshot + flash_llava_next, generate_load, response_snapshot, chicken ): - chicken = get_chicken() responses = await generate_load( flash_llava_next, f"User:![]({chicken})Can you tell me a very short story based on the image?", diff --git a/integration-tests/models/test_mllama.py b/integration-tests/models/test_mllama.py index 1b4264aa..02781707 100644 --- a/integration-tests/models/test_mllama.py +++ b/integration-tests/models/test_mllama.py @@ -1,5 +1,4 @@ import pytest -import base64 import asyncio @@ -15,22 +14,8 @@ async def mllama(mllama_handle): return mllama_handle.client -# TODO fix the server parsser to count inline image tokens correctly -def get_chicken(): - with open("integration-tests/images/chicken_on_money.png", "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()) - return f"data:image/png;base64,{encoded_string.decode('utf-8')}" - - -def get_cow_beach(): - with open("integration-tests/images/cow_beach.png", "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()) - return f"data:image/png;base64,{encoded_string.decode('utf-8')}" - - @pytest.mark.asyncio async def test_mllama_simpl(mllama, response_snapshot): - # chicken = get_chicken() response = await mllama.chat( max_tokens=10, temperature=0.0, diff --git a/launcher/src/main.rs b/launcher/src/main.rs index ee259e43..0d7af66d 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -68,7 +68,7 @@ fn get_config( fn resolve_attention(config: &Option, lora_adapters: &Option) -> (String, String) { let compute_capability = gpu::get_cuda_capability(); - let mut prefix_caching: Option = std::env::var("USE_PREFIX_CACHING").ok(); + let mut prefix_caching: Option = std::env::var("PREFIX_CACHING").ok(); let mut attention: Option = 
std::env::var("ATTENTION").ok(); if let Some(config) = config { if prefix_caching.is_none() { @@ -124,6 +124,10 @@ fn resolve_attention(config: &Option, lora_adapters: &Option) -> } } } + if attention == Some("paged".to_string()) && prefix_caching.is_none() { + tracing::info!("Disabling prefix caching on paged attention"); + prefix_caching = Some("0".to_string()); + } let attention = attention.unwrap_or("flashinfer".to_string()); let prefix_caching = prefix_caching.unwrap_or("true".to_string()); @@ -1678,7 +1682,7 @@ fn main() -> Result<(), LauncherError> { }; let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters); tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}"); - std::env::set_var("USE_PREFIX_CACHING", prefix_caching); + std::env::set_var("PREFIX_CACHING", prefix_caching); std::env::set_var("ATTENTION", attention); let max_input_tokens = { @@ -1729,12 +1733,6 @@ fn main() -> Result<(), LauncherError> { "`max_input_tokens must be < `max_total_tokens`".to_string(), )); } - if max_input_tokens as u32 > max_batch_prefill_tokens { - return Err(LauncherError::ArgumentValidation(format!( - "`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}", - max_batch_prefill_tokens, max_input_tokens - ))); - } if matches!(args.quantize, Some(Quantization::Bitsandbytes)) { tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases."); @@ -1788,12 +1786,6 @@ fn main() -> Result<(), LauncherError> { } if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens { - if max_batch_prefill_tokens > *max_batch_total_tokens { - return Err(LauncherError::ArgumentValidation(format!( - "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", - max_batch_prefill_tokens, max_batch_total_tokens - ))); - } if max_total_tokens as u32 > *max_batch_total_tokens { return Err(LauncherError::ArgumentValidation(format!( "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index 34894bda..c91e7cc4 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -34,6 +34,10 @@ message InfoResponse { string device_type = 3; optional uint32 window_size = 4; uint32 speculate = 5; + bool support_chunking = 6; + bool use_prefix_caching = 7; + string attention_impl = 8; + uint32 block_size = 9; } /// Empty request @@ -135,10 +139,14 @@ message Request { repeated uint32 slots = 10; /// LORA adapter index optional string adapter_id = 11; - /// Prefix length that can be retrieved from the KV cache. - uint32 prefix_len = 12; + /// Tokens that can be retrieved from the KV cache. 
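The launcher default above can be read as a two-step resolution: an explicit `ATTENTION=paged` quietly turns prefix caching off unless `PREFIX_CACHING` was set, and anything unset falls back to flashinfer with prefix caching on. A compact restatement (illustrative, and ignoring the model-config and LoRA heuristics the full function also applies):

    from typing import Optional, Tuple

    def resolve(attention: Optional[str], prefix_caching: Optional[str]) -> Tuple[str, str]:
        if attention == "paged" and prefix_caching is None:
            # User forced paged attention: disable prefix caching by default.
            prefix_caching = "0"
        return attention or "flashinfer", prefix_caching or "true"

    assert resolve("paged", None) == ("paged", "0")
    assert resolve(None, None) == ("flashinfer", "true")
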
+ /// This value is set for the first prefill and never reset + uint32 cache_len = 12; /// Context truncation bool add_special_tokens = 13; + /// Chunk of tokens that must be computed for the first prefill + /// This value is set for the first prefill and never reset + optional uint32 chunk_len = 14; } message Batch { @@ -163,6 +171,8 @@ message CachedBatch { uint32 size = 3; /// Maximum number of tokens this batch will grow to uint32 max_tokens = 4; + /// Number of tokens in the next forward + uint32 current_tokens = 5; } enum FinishReason { @@ -220,6 +230,8 @@ message FilterBatchResponse { message PrefillRequest { /// Batch Batch batch = 1; + /// Optional cached batch + CachedBatch cached_batch = 2; } message PrefillResponse { @@ -233,6 +245,8 @@ message PrefillResponse { uint64 decode_ns = 4; /// Total elapsed time in nanoseconds uint64 total_ns = 5; + /// Concatenate elapsed time in nanoseconds + optional uint64 concat_ns = 6; } message DecodeRequest { diff --git a/router/src/lib.rs b/router/src/lib.rs index b29c9395..fdbd931e 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -18,45 +18,6 @@ use tracing::warn; use utoipa::ToSchema; use validation::Validation; -#[derive(PartialEq)] -pub enum Attention { - Paged, - FlashDecoding, - FlashInfer, -} - -impl Attention { - pub fn block_size(&self) -> u32 { - match self { - Attention::FlashDecoding => 256, - Attention::FlashInfer => 1, - Attention::Paged => 16, - } - } -} - -#[derive(Debug)] -pub struct ParseError; - -impl std::fmt::Display for ParseError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Cannot parse attention value") - } -} -impl std::error::Error for ParseError {} - -impl std::str::FromStr for Attention { - type Err = ParseError; - fn from_str(s: &str) -> Result { - match s { - "paged" => Ok(Attention::Paged), - "flashdecoding" => Ok(Attention::FlashDecoding), - "flashinfer" => Ok(Attention::FlashInfer), - _ => Err(ParseError), - } - } -} - /// Hub type #[derive(Clone, Debug, Deserialize)] pub struct HubModelInfo { diff --git a/server/tests/conftest.py b/server/tests/conftest.py index d99771f8..98222792 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -2,7 +2,7 @@ import pytest import os from text_generation_server.pb import generate_pb2 -os.environ["USE_PREFIX_CACHING"] = "1" +os.environ["PREFIX_CACHING"] = "1" os.environ["ATTENTION"] = "flashinfer" diff --git a/server/text_generation_server/interceptor.py b/server/text_generation_server/interceptor.py index 57df1725..a5c023e4 100644 --- a/server/text_generation_server/interceptor.py +++ b/server/text_generation_server/interceptor.py @@ -9,6 +9,9 @@ from typing import Callable, Any class ExceptionInterceptor(AsyncServerInterceptor): + def __init__(self, shutdown_callback): + self.shutdown_callback = shutdown_callback + async def intercept( self, method: Callable, @@ -25,7 +28,7 @@ class ExceptionInterceptor(AsyncServerInterceptor): # Runtime Error cannot be recovered from if isinstance(err, RuntimeError): - exit(1) + self.shutdown_callback() if torch.cuda.is_available(): torch.cuda.empty_cache() diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py index c8ac0c2a..a3b919ee 100644 --- a/server/text_generation_server/layers/attention/common.py +++ b/server/text_generation_server/layers/attention/common.py @@ -1,72 +1,52 @@ from dataclasses import dataclass -from text_generation_server.utils.import_utils import SYSTEM -from 
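As a rough picture of what the `cache_len`/`chunk_len` fields and the `cached_batch` argument enable: a prompt longer than the prefill budget is now computed over several prefill passes instead of being rejected or forced into one. The helper below is invented for illustration only and assumes the whole budget belongs to a single request:

    def prefill_chunks(prompt_len: int, cache_len: int, prefill_budget: int):
        """Yield (start, end) token ranges computed by successive prefill passes."""
        start = cache_len  # tokens already present in the KV cache are skipped
        while start < prompt_len:
            end = min(start + prefill_budget, prompt_len)
            yield start, end
            start = end

    # 10_000-token prompt, 2_000 tokens already cached, 4_096-token prefill budget:
    assert list(prefill_chunks(10_000, 2_000, 4_096)) == [(2_000, 6_096), (6_096, 10_000)]
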
text_generation_server.models.globals import ATTENTION import torch from typing import Optional -if ATTENTION in {"flashinfer", "flashdecoding"}: +@dataclass +class Seqlen: + input_lengths: torch.Tensor + cache_lengths: torch.Tensor + cu_seqlen_q: Optional[torch.Tensor] + cu_seqlen_k: Optional[torch.Tensor] + max_q: int + max_k: int - @dataclass - class Seqlen: - input_lengths: torch.Tensor - prefix_lengths: torch.Tensor - cu_seqlen_q: Optional[torch.Tensor] - cu_seqlen_k: Optional[torch.Tensor] - max_q: int - max_k: int + def __init__( + self, + input_lengths, + cache_lengths, + cu_seqlen_q=None, + max_q=None, + max_k=None, + ): + self.input_lengths = input_lengths + self.cache_lengths = cache_lengths + device = self.input_lengths.device + shape = self.input_lengths.shape + if cu_seqlen_q is None: + cu_seqlen_q = torch.arange( + shape[0] + 1, + device=device, + dtype=torch.int32, + ) + max_q = 1 + else: + assert max_q is not None + assert max_k is not None + cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32) - def __init__( - self, - input_lengths, - prefix_lengths, - cu_seqlen_q=None, - max_q=None, - max_k=None, - ): - self.input_lengths = input_lengths - self.prefix_lengths = prefix_lengths - device = self.input_lengths.device - shape = self.input_lengths.shape - if cu_seqlen_q is None: - cu_seqlen_q = torch.arange( - shape[0] + 1, - device=device, - dtype=torch.int32, - ) - max_q = 1 - else: - assert max_q is not None - assert max_k is not None - cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32) + # cuda graphs don't like this and this is necessary to clamp within mistral + # Although FA2 might not want the clamping + # cu_seqlen_k[0] = 0 + total = self.input_lengths + self.cache_lengths + torch.cumsum(total, -1, out=cu_seqlen_k[1:]) - # cuda graphs don't like this and this is necessary to clamp within mistral - # Although FA2 might not want the clamping - # cu_seqlen_k[0] = 0 - total = self.input_lengths + self.prefix_lengths - torch.cumsum(total, -1, out=cu_seqlen_k[1:]) + self.cu_seqlen_q = cu_seqlen_q + self.cu_seqlen_k = cu_seqlen_k + self.max_q = max_q + self.max_k = max_k - self.cu_seqlen_q = cu_seqlen_q - self.cu_seqlen_k = cu_seqlen_k - self.max_q = max_q - self.max_k = max_k - - def clamp(self, max): - # Flash decoding doesn't need to clamp - return self - -else: - - @dataclass - class Seqlen: - input_lengths: torch.Tensor - prefix_lengths: torch.Tensor - cu_seqlen_q: torch.Tensor - max_q: int - max_k: int - - def clamp(self, max): - if SYSTEM == "rocm": - return self - self.input_lengths = torch.clamp(self.input_lengths, max=max) - return self + def clamp(self, max): + # Flash decoding doesn't need to clamp + return self diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index cd3ea369..265a8ae4 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -123,7 +123,7 @@ def paged_attention( else: if softcap is not None: raise RuntimeError("Paged attention doesn't support softcapping") - input_lengths = seqlen.input_lengths + input_lengths = seqlen.input_lengths + seqlen.cache_lengths from vllm._C import ops out = torch.empty_like(query) @@ -244,117 +244,232 @@ if ATTENTION == "flashinfer": window_left=window_size_left, ) -elif V2: +elif ATTENTION == "flashdecoding": + if V2: - def attention( - q, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - seqlen: Seqlen, - 
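The unified Seqlen above folds the cached and freshly computed tokens of each request into one set of key-side cumulative lengths. A standalone check of that construction with toy values (outside the class, CPU tensors):

    import torch

    input_lengths = torch.tensor([3, 5, 2], dtype=torch.int32)   # new tokens this pass
    cache_lengths = torch.tensor([7, 0, 4], dtype=torch.int32)   # tokens already in the KV cache

    total = input_lengths + cache_lengths
    cu_seqlen_k = torch.zeros(total.shape[0] + 1, dtype=torch.int32)
    cu_seqlen_k[1:] = torch.cumsum(total, -1, dtype=torch.int32)

    print(cu_seqlen_k)  # tensor([ 0, 10, 15, 21], dtype=torch.int32)
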
block_tables: torch.Tensor, - softmax_scale, - window_size_left=-1, - causal=True, - softcap=0.0, - ): - out = torch.empty_like(q) - if window_size_left <= 0 and window_size_left != -1: - raise ValueError("`window_size_left` must be > 0 or -1") - return flash_attn_2_cuda.varlen_fwd( + def attention( q, - key_cache, - value_cache, - out, - seqlen.cu_seqlen_q, - seqlen.cu_seqlen_k, - None, - None, - block_tables, - None, - seqlen.max_q, - seqlen.max_k, - 0.0, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + seqlen: Seqlen, + block_tables: torch.Tensor, softmax_scale, - False, - causal, - window_size_left, - 0, - softcap, - False, - None, - )[0] + window_size_left=-1, + causal=True, + softcap=0.0, + ): + out = torch.empty_like(q) + if window_size_left <= 0 and window_size_left != -1: + raise ValueError("`window_size_left` must be > 0 or -1") + return flash_attn_2_cuda.varlen_fwd( + q, + key_cache, + value_cache, + out, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_k, + None, + None, + block_tables, + None, + seqlen.max_q, + seqlen.max_k, + 0.0, + softmax_scale, + False, + causal, + window_size_left, + 0, + softcap, + False, + None, + )[0] + + else: + + def attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seqlen: Seqlen, + block_tables: torch.Tensor, + softmax_scale: float, + window_size_left: int = -1, + causal: bool = True, + softcap=None, + ): + if window_size_left != -1: + raise NotImplementedError( + "window_size_left is only available with flash attn v2" + ) + if softcap is not None: + raise NotImplementedError( + "softcap is only available with flash attn v2" + ) + + # Flash attention v1 requires q, k and v to have the same number of heads + if k.shape[1] != q.shape[1]: + # MQA expand + if k.shape[1] == 1: + k = k.expand(-1, q.shape[1], -1) + # Grouped attention reshape + else: + original_shape = k.shape + k = ( + k.unsqueeze(2) + .expand(-1, -1, q.shape[1] // k.shape[1], -1) + .reshape(original_shape[0], -1, original_shape[2]) + ) + if v.shape[1] != q.shape[1]: + # MQA expand + if v.shape[1] == 1: + v = v.expand(-1, q.shape[1], -1) + # Grouped attention reshape + else: + original_shape = v.shape + v = ( + v.unsqueeze(2) + .expand(-1, -1, q.shape[1] // v.shape[1], -1) + .reshape(original_shape[0], -1, original_shape[2]) + ) + + out = torch.empty_like(q) + flash_attn_cuda.fwd( + q, + k, + v, + out, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_q, + seqlen.max_q, + seqlen.max_k, + 0.0, + softmax_scale, + False, + causal, + False, + 0, + None, + ) + return out + +elif ATTENTION == "paged": + if V2: + + def attention( + q, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + seqlen: Seqlen, + block_tables: torch.Tensor, + softmax_scale, + window_size_left=-1, + causal=True, + softcap=0.0, + ): + out = torch.empty_like(q) + if window_size_left <= 0 and window_size_left != -1: + raise ValueError("`window_size_left` must be > 0 or -1") + return flash_attn_2_cuda.varlen_fwd( + q, + key_cache, + value_cache, + out, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_k, + None, + None, + None, # block_tables, + None, + seqlen.max_q, + seqlen.max_k, + 0.0, + softmax_scale, + False, + causal, + window_size_left, + 0, + softcap, + False, + None, + )[0] + + else: + + def attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seqlen: Seqlen, + block_tables: torch.Tensor, + softmax_scale: float, + window_size_left: int = -1, + causal: bool = True, + softcap=None, + ): + if window_size_left != -1: + raise NotImplementedError( + "window_size_left is only available with flash 
attn v2" + ) + if softcap is not None: + raise NotImplementedError( + "softcap is only available with flash attn v2" + ) + + # Flash attention v1 requires q, k and v to have the same number of heads + if k.shape[1] != q.shape[1]: + # MQA expand + if k.shape[1] == 1: + k = k.expand(-1, q.shape[1], -1) + # Grouped attention reshape + else: + original_shape = k.shape + k = ( + k.unsqueeze(2) + .expand(-1, -1, q.shape[1] // k.shape[1], -1) + .reshape(original_shape[0], -1, original_shape[2]) + ) + if v.shape[1] != q.shape[1]: + # MQA expand + if v.shape[1] == 1: + v = v.expand(-1, q.shape[1], -1) + # Grouped attention reshape + else: + original_shape = v.shape + v = ( + v.unsqueeze(2) + .expand(-1, -1, q.shape[1] // v.shape[1], -1) + .reshape(original_shape[0], -1, original_shape[2]) + ) + + out = torch.empty_like(q) + flash_attn_cuda.fwd( + q, + k, + v, + out, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_q, + seqlen.max_q, + seqlen.max_k, + 0.0, + softmax_scale, + False, + causal, + False, + 0, + None, + ) + return out else: - - def attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, - softmax_scale: float, - window_size_left: int = -1, - causal: bool = True, - softcap=None, - ): - if window_size_left != -1: - raise NotImplementedError( - "window_size_left is only available with flash attn v2" - ) - if softcap is not None: - raise NotImplementedError("softcap is only available with flash attn v2") - - # Flash attention v1 requires q, k and v to have the same number of heads - if k.shape[1] != q.shape[1]: - # MQA expand - if k.shape[1] == 1: - k = k.expand(-1, q.shape[1], -1) - # Grouped attention reshape - else: - original_shape = k.shape - k = ( - k.unsqueeze(2) - .expand(-1, -1, q.shape[1] // k.shape[1], -1) - .reshape(original_shape[0], -1, original_shape[2]) - ) - if v.shape[1] != q.shape[1]: - # MQA expand - if v.shape[1] == 1: - v = v.expand(-1, q.shape[1], -1) - # Grouped attention reshape - else: - original_shape = v.shape - v = ( - v.unsqueeze(2) - .expand(-1, -1, q.shape[1] // v.shape[1], -1) - .reshape(original_shape[0], -1, original_shape[2]) - ) - - out = torch.empty_like(q) - flash_attn_cuda.fwd( - q, - k, - v, - out, - seqlen.cu_seqlen_q, - seqlen.cu_seqlen_q, - seqlen.max_q, - seqlen.max_k, - 0.0, - softmax_scale, - False, - causal, - False, - 0, - None, - ) - return out + raise RuntimeError(f"Unknwon attention {ATTENTION}") # Prefill in the cache with every kind of attention, unless we # have a configuration that requires flash-attention v1, which # does not support block tables. 
-PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2 +PREFILL_IN_KV_CACHE = ATTENTION == "flashinfer" or (ATTENTION == "flashdecoding" and V2) __all__ = [ "PREFILL_IN_KV_CACHE", diff --git a/server/text_generation_server/layers/attention/flash_attn_triton.py b/server/text_generation_server/layers/attention/flash_attn_triton.py index 3a6f9a73..fd180f0f 100644 --- a/server/text_generation_server/layers/attention/flash_attn_triton.py +++ b/server/text_generation_server/layers/attention/flash_attn_triton.py @@ -699,7 +699,6 @@ def check_args( class _attention(torch.autograd.Function): - @staticmethod def forward( ctx, diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index 131c9bb0..17f6a7f1 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -66,6 +66,7 @@ def paged_attention( softcap: Optional[float] = None, ): out = torch.empty_like(query) + input_lengths = seqlen.input_lengths + seqlen.cache_lengths ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( out, query, @@ -74,7 +75,7 @@ def paged_attention( kv_head_mapping, softmax_scale, block_tables, - seqlen.input_lengths, + input_lengths, BLOCK_SIZE, max_s, None, diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 01d4685a..27e7638a 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -104,7 +104,7 @@ def paged_attention( _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE - input_lengths = seqlen.input_lengths + input_lengths = seqlen.input_lengths + seqlen.cache_lengths out = torch.empty_like(query) diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index ef46cb8c..bd8176be 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -76,6 +76,7 @@ class CausalLMBatch(Batch): request_ids=[r.id for r in self.requests], size=len(self), max_tokens=self.max_tokens, + current_tokens=len(self.input_ids), ) @classmethod diff --git a/server/text_generation_server/models/custom_modeling/mllama.py b/server/text_generation_server/models/custom_modeling/mllama.py index 6e091a74..be0a4b5d 100644 --- a/server/text_generation_server/models/custom_modeling/mllama.py +++ b/server/text_generation_server/models/custom_modeling/mllama.py @@ -493,9 +493,14 @@ class MllamaVisionModel(nn.Module): aspect_ratio_ids: torch.Tensor, attention_mask: torch.Tensor, ) -> torch.Tensor: - batch_size, num_concurrent_media, num_tiles, num_channels, height, width = ( - pixel_values.shape - ) + ( + batch_size, + num_concurrent_media, + num_tiles, + num_channels, + height, + width, + ) = pixel_values.shape pixel_values = pixel_values.reshape( batch_size * num_concurrent_media * num_tiles, num_channels, height, width diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 33fe30a8..c9b7decd 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -16,7 +16,17 @@ from transformers import ( AutoTokenizer, GenerationConfig, ) -from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict +from typing import ( + 
Any, + ContextManager, + Iterable, + Optional, + Tuple, + List, + Type, + Dict, + Union, +) from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE @@ -24,6 +34,10 @@ from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models import Model from text_generation_server.utils.log import log_master +from text_generation_server.utils.prefill_chunking import ( + get_support_chunking, + get_max_prefill_tokens, +) from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.speculate import get_speculate from text_generation_server.utils import ( @@ -60,7 +74,6 @@ from text_generation_server.utils.import_utils import ( tracer = trace.get_tracer(__name__) - # Will be set in init SLIDING_WINDOW: Optional[int] = None @@ -117,45 +130,48 @@ class FlashCausalLMBatch(Batch): requests_idx_mapping: Dict[int, int] # Decoder values - input_ids: torch.Tensor - position_ids: torch.Tensor + # Can be a list for easy filtering + # If `input_ids` is a list, it needs to be materialized to a tensor first + input_ids: Union[torch.Tensor, List[List[int]]] + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + position_ids: Optional[torch.Tensor] speculative_ids: Optional[torch.Tensor] - # Flash Attention values - - # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill - cu_seqlen_prefill: Optional[torch.Tensor] - # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers - # as we only keep SLIDING_WINDOW values instead of the whole tensor - prefill_cache_indices: Optional[torch.Tensor] - - # Paged Attention values - # Set when creating the batch - # CPU tensor of length b indicating the start of each sequence in slots - start_slots: torch.Tensor # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode - slot_indices: torch.Tensor + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + slot_indices: Optional[torch.Tensor] # list of length b of list of length s_i // block_size block_tables: List[List[int]] # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences block_tables_tensor: torch.Tensor # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences - slots: torch.Tensor - # size [b], containing the number of blocks that can be retrieved from the cache - prefix_lens: List[int] - prefix_lens_tensor: torch.Tensor + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + slots: Optional[torch.Tensor] - max_seqlen: int + max_input_length: int + max_current_length: int + + # Whether this batch contains at least one request that is prefilling + prefilling: bool + # Whether each request is prefilling + prefilling_mask: List[bool] # Prefill metadata tensors to efficiently compute logprobs + # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill + cu_seqlen_prefill: Optional[torch.Tensor] + # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers + # as we only keep SLIDING_WINDOW values instead 
of the whole tensor + prefill_cache_indices: Optional[torch.Tensor] + # Will be set by `generate_token` and reset after each prefill forward prefill_head_indices: Optional[torch.Tensor] + # Will be set by `generate_token` and reset after each prefill forward prefill_next_token_indices: Optional[torch.tensor] + # Will be set by `generate_token` and reset after each prefill forward prefill_cu_outlens: Optional[List[int]] - - # Prefixes - prefix_ids: List[List[int]] + # Will be set by `generate_token` and reset after each prefill forward + prefill_logprob_tokens: List[Optional[Tokens]] # All tokens all_input_ids: List[List[int]] @@ -163,7 +179,14 @@ class FlashCausalLMBatch(Batch): # Lengths of all generations present in the batch input_lengths: List[int] - input_lengths_tensor: torch.Tensor + # size [b], containing the number of blocks that can be retrieved from the cache + cache_lengths: List[int] + prompt_lengths: List[int] + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + input_lengths_tensor: Optional[torch.Tensor] + cache_lengths_tensor: Optional[torch.Tensor] + prompt_lengths_tensor: torch.Tensor + prefix_offsets: List[Optional[int]] read_offsets: List[Optional[int]] @@ -174,7 +197,8 @@ class FlashCausalLMBatch(Batch): top_n_tokens_tensor: torch.Tensor # Adapter metadata for each request - adapter_meta: AdapterBatchMetadata + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + adapter_meta: Optional[AdapterBatchMetadata] # Number of blocks in this batch num_blocks: int @@ -187,6 +211,11 @@ class FlashCausalLMBatch(Batch): request_ids=[r.id for r in self.requests], size=len(self), max_tokens=self.num_blocks * BLOCK_SIZE, + current_tokens=( + sum([len(i) for i in self.input_ids]) + if isinstance(self.input_ids, list) + else len(self.input_ids) + ), ) @classmethod @@ -218,46 +247,28 @@ class FlashCausalLMBatch(Batch): dtype: torch.dtype, device: torch.device, ) -> "FlashCausalLMBatch": - sliding_window = get_sliding_windows() - position_ids = [] - cu_seqlen_prefill = [0] - start_slots = [] - slot_indices = [] - prefill_cache_indices = [] + speculate = get_speculate() + cache_lengths = [] input_lengths = [] + prompt_lengths = [] prefix_offsets = [] read_offsets = [] all_input_ids = [] - prefix_ids = [] + all_postfix_ids = [] requests_idx_mapping = {} - all_prefill_logprobs = True - no_prefill_logprobs = True - prefill_head_indices = [] - prefill_next_token_indices = [] - prefill_cu_outlens = [0] - next_token_chooser_parameters = [] stopping_criterias = [] top_n_tokens = [] - adapter_indices_list = [] - adapter_set = set() - - # Cumulative length - cumulative_length = 0 - cumulative_slot_tokens = 0 - prefill_out_cumulative_length = 0 - num_blocks = 0 - max_seqlen = 0 + max_input_length = 0 + max_current_length = 0 max_length = 0 max_blocks = 0 block_tables = [] - slots = [] - prefix_lens = [] # Parse batch for i, (r, tokenized_input) in enumerate( @@ -266,38 +277,47 @@ class FlashCausalLMBatch(Batch): # request id -> idx in list mapping requests_idx_mapping[r.id] = i - orig_input_length = len(tokenized_input) + prompt_length = len(tokenized_input) + prompt_lengths.append(prompt_length) + + cache_length = r.cache_len - prefix_len = r.prefix_len assert ( - prefix_len <= orig_input_length - ), f"Prefix {prefix_len} vs input {orig_input_length}" - if prefix_len == orig_input_length: - assert prefix_len > 0 - prefix_len -= 1 + cache_length <= prompt_length + ), f"Prefix {cache_length} vs input 
{prompt_length}" + if cache_length == prompt_length: + assert False, "unreachable" - # Commented as it's costly. - # log_master(logger.debug, "Tokenized input ids {tokenized_input}") - prefix_ids.append(tokenized_input[:prefix_len]) - tokenized_input = tokenized_input[prefix_len:] + # `chunk_len` is an optional field in the protobuf + # It is only set if the model support chunking + if r.HasField("chunk_len"): + input_length = r.chunk_len + + if cache_length + input_length < prompt_length: + # FIXME: speculate is not supported for context chunking at the moment + assert speculate == 0 + assert get_support_chunking() + assert input_length > 0 + + postfix_ids = tokenized_input[ + cache_length : cache_length + input_length + ] + assert ( + len(postfix_ids) == input_length + ), "Rust and Python tokenizers are not aligned" + else: + # Use all the remaining ids + postfix_ids = tokenized_input[cache_length:] + input_length = len(postfix_ids) - input_length = len(tokenized_input) input_lengths.append(input_length) - prefix_offsets.append(input_length - 5) - read_offsets.append(input_length) + prefix_offsets.append(prompt_length - 5) + read_offsets.append(prompt_length) + all_postfix_ids.append(postfix_ids) all_input_ids.append(tokenized_input) - # Position ids - request_position_ids = torch.arange( - prefix_len, orig_input_length, dtype=torch.int32 - ) - position_ids.append(request_position_ids) - - # Add cumulative lengths of all previous inputs - cu_seqlen_prefill.append(cumulative_length + input_length) - next_token_chooser_parameters.append(r.parameters) stopping_criteria = StoppingCriteria.from_pb( @@ -307,22 +327,13 @@ class FlashCausalLMBatch(Batch): stopping_criterias.append(stopping_criteria) top_n_tokens.append(r.top_n_tokens) - ADAPTER_TO_INDEX = get_adapter_to_index() - adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0) - adapter_indices_list.append(torch.full((input_length,), adapter_index)) - adapter_set.add(adapter_index) - # Paged attention # Remove one as the first token des not have a past speculative_length = get_speculate() speculative_length = 0 if speculative_length is None else speculative_length # Tokens that need to be mapped to blocks. - block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length - - # Tokens that need to be mapped to slots. We don't need slots for the - # cached prefix (if present). 
- slot_tokens = input_length + max_new_tokens - 1 + speculative_length + block_tokens = prompt_length + max_new_tokens - 1 + speculative_length # blocks and slots can be empty (for example in warmup) if not r.blocks: @@ -330,77 +341,26 @@ class FlashCausalLMBatch(Batch): request_blocks = [ b for b in range(num_blocks, num_blocks + needed_blocks) ] - request_slots = [ - s - for b in request_blocks - for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE) - ] else: request_blocks = r.blocks - request_slots = r.slots[ - prefix_len: #: orig_input_length + max_new_tokens + speculative_length - ] block_tables.append(request_blocks) - slots.extend(request_slots) - prefix_lens.append(prefix_len) + cache_lengths.append(cache_length) num_blocks += len(request_blocks) - start_slots.append(cumulative_slot_tokens) - - request_slot_indices = torch.arange( - cumulative_slot_tokens, - cumulative_slot_tokens + input_length, - dtype=torch.int64, - ) - slot_indices.append(request_slot_indices) - - # Create tensor to slice into the kv tensor in prefill - if sliding_window is not None: - request_prefill_cache_indices = torch.arange( - cumulative_length + max(0, input_length - sliding_window), - cumulative_length + input_length, - dtype=torch.int64, - ) - prefill_cache_indices.append(request_prefill_cache_indices) - - all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs - no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs - - if r.prefill_logprobs: - prefill_head_indices.append(request_position_ids + cumulative_length) - prefill_next_token_indices.append( - prefill_out_cumulative_length + input_length - 1 - ) - prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) - prefill_out_cumulative_length += input_length - else: - prefill_head_indices.append( - torch.tensor( - [cumulative_length + input_length - 1], dtype=torch.int32 - ) - ) - prefill_next_token_indices.append(prefill_out_cumulative_length) - prefill_cu_outlens.append(prefill_out_cumulative_length + 1) - prefill_out_cumulative_length += 1 # Update - cumulative_length += input_length - cumulative_slot_tokens += slot_tokens - max_seqlen = max(max_seqlen, input_length) max_blocks = max(max_blocks, len(request_blocks)) + max_input_length = max(max_input_length, input_length) + max_current_length = max(max_current_length, cache_length + input_length) max_length = max( - max_length, input_length + max_new_tokens + speculative_length + max_length, + prompt_length + max_new_tokens + speculative_length, ) - adapter_indices = torch.cat(adapter_indices_list).to( - dtype=torch.int64, device=device - ) - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, dtype, device, tokenizer ) - start_slots = torch.tensor(start_slots, dtype=torch.int64) # Padded all_input_ids_tensor all_input_ids_tensor = np.zeros( @@ -414,103 +374,59 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor, dtype=torch.int64, device=device ) - if len(pb.requests) > 1: - input_ids = np.concatenate(all_input_ids, dtype=np.int64) - position_ids = torch.cat(position_ids) - slot_indices = torch.cat(slot_indices) - if sliding_window is not None: - prefill_cache_indices = torch.cat(prefill_cache_indices) - else: - input_ids = all_input_ids[0] - position_ids = position_ids[0] - slot_indices = slot_indices[0] - if sliding_window is not None: - prefill_cache_indices = prefill_cache_indices[0] - - cu_seqlen_prefill = torch.tensor( - cu_seqlen_prefill, device=device, dtype=torch.int32 - ) - position_ids = 
position_ids.to(device) - slot_indices = slot_indices.to(device) - prefill_cache_indices = ( - prefill_cache_indices.to(device) if sliding_window is not None else None - ) - input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) - input_lengths_tensor = torch.tensor( - input_lengths, dtype=torch.int32, device=device - ) - - adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - adapter_segments = torch.tensor( - adapter_segments, dtype=torch.int32, device=device - ) - - if all_prefill_logprobs: - prefill_head_indices = None - prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 - elif no_prefill_logprobs: - prefill_head_indices = cu_seqlen_prefill[1:] - 1 - prefill_next_token_indices = None - else: - prefill_head_indices = torch.tensor( - torch.cat(prefill_head_indices), dtype=torch.int64, device=device - ) - prefill_next_token_indices = torch.tensor( - prefill_next_token_indices, dtype=torch.int64, device=device - ) top_n_tokens_tensor = torch.tensor( top_n_tokens, device=device, dtype=torch.int64 ) - slots = torch.tensor(slots, dtype=torch.int64, device=device) - block_tables_tensor = torch.zeros( (len(block_tables), max_blocks), dtype=torch.int32, device="cpu" ) for i, request_blocks in enumerate(block_tables): block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) block_tables_tensor = block_tables_tensor.to(device) - prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device) + prompt_lengths_tensor = torch.tensor( + prompt_lengths, dtype=torch.int32, device=device + ) return cls( batch_id=pb.id, requests=pb.requests, requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - prefill_cache_indices=prefill_cache_indices, - start_slots=start_slots, - slot_indices=slot_indices, + input_ids=all_postfix_ids, block_tables=block_tables, block_tables_tensor=block_tables_tensor, - slots=slots, - prefix_lens=prefix_lens, - prefix_lens_tensor=prefix_lens_tensor, - max_seqlen=max_seqlen, - prefill_head_indices=prefill_head_indices, - prefill_next_token_indices=prefill_next_token_indices, - prefill_cu_outlens=prefill_cu_outlens, + cache_lengths=cache_lengths, + max_input_length=max_input_length, + max_current_length=max_current_length, + prefilling=True, + prefilling_mask=[True] * len(pb.requests), + prefill_logprob_tokens=[None] * len(pb.requests), input_lengths=input_lengths, - input_lengths_tensor=input_lengths_tensor, + prompt_lengths=prompt_lengths, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, - prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, top_n_tokens_tensor=top_n_tokens_tensor, num_blocks=num_blocks, max_blocks=max_blocks, - adapter_meta=AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_segment_indices, - ), speculative_ids=None, + prompt_lengths_tensor=prompt_lengths_tensor, + # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` + position_ids=None, + cu_seqlen_prefill=None, + prefill_cache_indices=None, + slot_indices=None, + slots=None, + prefill_head_indices=None, + prefill_next_token_indices=None, + prefill_cu_outlens=None, + cache_lengths_tensor=None, + input_lengths_tensor=None, + adapter_meta=None, ) @classmethod @@ -533,7 +449,7 @@ class 
FlashCausalLMBatch(Batch): if len(request_ids) == len(self): return self - device = self.input_ids.device + device = self.block_tables_tensor.device # New values after filtering requests_idx_mapping = {} @@ -548,19 +464,23 @@ class FlashCausalLMBatch(Batch): # Create on CPU to only move to GPU once instead of at every copy slot_indices = torch.empty(len(request_ids), dtype=torch.int64) - max_seqlen = 0 + max_input_length = 0 + max_current_length = 0 requests = [] - start_slots = [] block_tables = [] all_input_ids = [] - prefix_ids = [] + input_ids = [] + prompt_lengths = [] input_lengths = [] - prefix_lens = [] + cache_lengths = [] prefix_offsets = [] read_offsets = [] + prefilling_mask = [] + prefill_logprob_tokens = [] + stopping_criterias = [] top_n_tokens = [] adapter_set = set() @@ -577,16 +497,23 @@ class FlashCausalLMBatch(Batch): requests.append(self.requests[idx]) + # Prefilling + request_prefilling = self.prefilling_mask[idx] + prefilling_mask.append(request_prefilling) + # Get length request_input_length = self.input_lengths[idx] - prefix_len = self.prefix_lens[idx] - max_seqlen = max(max_seqlen, request_input_length) + request_cache_length = self.cache_lengths[idx] + max_input_length = max(max_input_length, request_input_length) + max_current_length = max( + max_current_length, request_cache_length + request_input_length + ) all_input_ids.append(self.all_input_ids[idx]) - prefix_ids.append(self.prefix_ids[idx]) + prompt_lengths.append(self.prompt_lengths[idx]) input_lengths.append(request_input_length) - prefix_lens.append(prefix_len) + cache_lengths.append(request_cache_length) prefix_offsets.append(self.prefix_offsets[idx]) read_offsets.append(self.read_offsets[idx]) @@ -594,60 +521,79 @@ class FlashCausalLMBatch(Batch): stopping_criterias.append(stopping_criteria) top_n_tokens.append(self.top_n_tokens[idx]) + prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx]) ADAPTER_TO_INDEX = get_adapter_to_index() adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0) adapter_set.add(adapter_index) - remaining_tokens = ( - stopping_criteria.max_new_tokens - stopping_criteria.current_tokens - ) - request_block_table = self.block_tables[idx] num_blocks += len(request_block_table) block_tables.append(request_block_table) - start_slots.append(cumulative_max_length) - # Copy to tensor (CPU) - slot_indices[i] = cumulative_max_length + request_input_length - 1 + # Input ids if the request was part of a prefilling batch + # If the batch was decoding we can index into the tensor directly later + if self.prefilling: + input_ids.append(self.input_ids[idx]) + else: + # Copy to tensor (CPU) + slot_indices[i] = cumulative_max_length - # Set slice - slot_filtering_indices[ - self.start_slots[idx] : self.start_slots[idx] - + request_input_length - + remaining_tokens - - 1 - ] = True + remaining_tokens = ( + stopping_criteria.max_new_tokens - stopping_criteria.current_tokens + ) - cumulative_max_length += request_input_length + remaining_tokens - 1 + # Set slice + slot_filtering_indices[ + self.slot_indices[idx] : self.slot_indices[idx] + + request_input_length + + remaining_tokens + - 1 + ] = True + + cumulative_max_length += request_input_length + remaining_tokens - 1 max_blocks = max(max_blocks, len(request_block_table)) - # Index into tensors - input_ids = self.input_ids[indices] - position_ids = self.position_ids[indices] - adapter_indices = self.adapter_meta.adapter_indices[indices] all_input_ids_tensor = self.all_input_ids_tensor[indices] block_tables_tensor = 
self.block_tables_tensor[indices] - input_lengths_tensor = self.input_lengths_tensor[indices] - slots = self.slots[slot_filtering_indices] - prefix_lens_tensor = self.prefix_lens_tensor[indices] next_token_chooser = self.next_token_chooser.filter(indices) top_n_tokens_tensor = self.top_n_tokens_tensor[indices] speculative_ids = ( self.speculative_ids[indices] if self.speculative_ids is not None else None ) + prompt_lengths_tensor = self.prompt_lengths_tensor[indices] - start_slots = torch.tensor(start_slots, dtype=torch.int64) + if self.prefilling: + # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` + position_ids = None + slot_indices = None + slots = None + cache_lengths_tensor = None + input_lengths_tensor = None + adapter_meta = None + else: + # Index into tensors + input_ids = self.input_ids[indices] + position_ids = self.position_ids[indices] + adapter_indices = self.adapter_meta.adapter_indices[indices] + input_lengths_tensor = self.input_lengths_tensor[indices] + slots = self.slots[slot_filtering_indices] + cache_lengths_tensor = self.cache_lengths_tensor[indices] - # Move to GPU now that we have the whole tensor - slot_indices = slot_indices.to(device) + # Move to GPU now that we have the whole tensor + slot_indices = slot_indices.to(device) - adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - adapter_segments = torch.tensor( - adapter_segments, dtype=torch.int32, device=device - ) - # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum() + adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + adapter_segments = torch.tensor( + adapter_segments, dtype=torch.int32, device=device + ) + adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ) return type(self)( batch_id=self.batch_id, @@ -657,24 +603,28 @@ class FlashCausalLMBatch(Batch): position_ids=position_ids, cu_seqlen_prefill=None, prefill_cache_indices=None, - start_slots=start_slots, slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, - max_seqlen=max_seqlen, + max_input_length=max_input_length, + max_current_length=max_current_length, + prefilling=self.prefilling, + prefilling_mask=prefilling_mask, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, + prefill_logprob_tokens=prefill_logprob_tokens, + prompt_lengths=prompt_lengths, + prompt_lengths_tensor=prompt_lengths_tensor, input_lengths=input_lengths, input_lengths_tensor=input_lengths_tensor, - prefix_lens=prefix_lens, - prefix_lens_tensor=prefix_lens_tensor, + cache_lengths=cache_lengths, + cache_lengths_tensor=cache_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, - prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -682,12 +632,7 @@ class FlashCausalLMBatch(Batch): num_blocks=num_blocks, max_blocks=max_blocks, speculative_ids=speculative_ids, - adapter_meta=AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_segment_indices, - ), + adapter_meta=adapter_meta, ) @classmethod @@ -697,74 +642,98 @@ class FlashCausalLMBatch(Batch): requests = [] requests_idx_mapping = {} + prefilling = False num_blocks 
= 0 total_batch_size = 0 total_slots = 0 max_blocks = 0 max_length = 0 - max_seqlen = 0 + max_input_length = 0 + max_current_length = 0 for b in batches: total_batch_size += len(b) - total_slots += len(b.slots) + max_blocks = max(max_blocks, b.max_blocks) + # If `b` is prefilling and was just filtered, `b.slots` is None + # `total_slots` is not used if any of the batches is prefilling + total_slots += len(b.slots) if not b.prefilling else 0 num_blocks += b.num_blocks speculative_length = ( b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 ) - max_blocks = max(max_blocks, b.max_blocks) - max_seqlen = max(max_seqlen, b.max_seqlen) + max_input_length = max(max_input_length, b.max_input_length) + max_current_length = max(max_current_length, b.max_current_length) max_length = max( max_length, max( - input_length + prompt_length + stopping_criteria.max_new_tokens + speculative_length - - stopping_criteria.current_tokens - for input_length, stopping_criteria in zip( - b.input_lengths, b.stopping_criterias + for prompt_length, stopping_criteria in zip( + b.prompt_lengths, b.stopping_criterias ) ), ) + prefilling = prefilling or b.prefilling - input_ids = batches[0].input_ids.new_empty(total_batch_size) - position_ids = batches[0].position_ids.new_empty(total_batch_size) - slots = batches[0].slots.new_empty(total_slots) - slot_indices = batches[0].slot_indices.new_empty(total_batch_size) - input_lengths_tensor = batches[0].input_lengths_tensor.new_empty( + if prefilling: + input_ids = [] + # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` + position_ids = None + slots = None + slot_indices = None + cache_lengths_tensor = None + input_lengths_tensor = None + adapter_meta = None + adapter_segment_builder = None + else: + input_ids = batches[0].input_ids.new_empty(total_batch_size) + position_ids = batches[0].position_ids.new_empty(total_batch_size) + slots = batches[0].slots.new_empty(total_slots) + slot_indices = batches[0].slot_indices.new_empty(total_batch_size) + input_lengths_tensor = batches[0].input_lengths_tensor.new_empty( + total_batch_size + ) + cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty( + total_batch_size + ) + total_indices_size = sum( + b.adapter_meta.adapter_indices.shape[0] for b in batches + ) + adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty( + total_indices_size + ) + adapter_segment_builder = SegmentConcatBuilder() + adapter_set = set() + + prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty( total_batch_size ) block_tables_tensor = batches[0].block_tables_tensor.new_zeros( (total_batch_size, max_blocks) ) - prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size) all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( (total_batch_size, max_length) ) top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( total_batch_size, ) - total_indices_size = sum( - b.adapter_meta.adapter_indices.shape[0] for b in batches - ) - adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty( - total_indices_size - ) - adapter_set = set() - adapter_segment_builder = SegmentConcatBuilder() - start_slots = [] block_tables = [] - prefix_lens = [] + cache_lengths = [] all_input_ids = [] - prefix_ids = [] + prompt_lengths = [] input_lengths = [] prefix_offsets = [] read_offsets = [] + prefill_logprob_tokens = [] + next_token_chooser_parameters = [] fsm_grammar_states = [] stopping_criterias = [] top_n_tokens = [] + prefilling_mask = [] # Cumulative length 
cumulative_batch_size = 0 @@ -783,32 +752,9 @@ class FlashCausalLMBatch(Batch): start_index = cumulative_batch_size end_index = cumulative_batch_size + len(batch) - slots_start_index = cumulative_slots - slots_end_index = cumulative_slots + len(batch.slots) # Copy tensors (GPU) - input_ids[start_index:end_index] = batch.input_ids - position_ids[start_index:end_index] = batch.position_ids - slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots - input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor - slots[slots_start_index:slots_end_index] = batch.slots - - # Copy over adapter indices - adapter_start_index = cumulative_adapter_indices_size - adapter_end_index = ( - cumulative_adapter_indices_size - + batch.adapter_meta.adapter_indices.shape[0] - ) - adapter_indices[adapter_start_index:adapter_end_index] = ( - batch.adapter_meta.adapter_indices - ) - cumulative_adapter_indices_size = adapter_end_index - adapter_set.update(batch.adapter_meta.adapter_set) - adapter_segment_builder.concat( - batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices - ) - all_input_ids_tensor[ start_index:end_index, : batch.all_input_ids_tensor.shape[1] ] = batch.all_input_ids_tensor[:, :max_length] @@ -816,20 +762,56 @@ class FlashCausalLMBatch(Batch): block_tables_tensor[ start_index:end_index, : batch.block_tables_tensor.shape[1] ] = batch.block_tables_tensor[:, :max_blocks] + prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor - prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor + if not prefilling: + slots_start_index = cumulative_slots + slots_end_index = cumulative_slots + len(batch.slots) - start_slots.append(batch.start_slots + cumulative_slots) + input_ids[start_index:end_index] = batch.input_ids + position_ids[start_index:end_index] = batch.position_ids + slots[slots_start_index:slots_end_index] = batch.slots + slot_indices[start_index:end_index] = ( + batch.slot_indices + cumulative_slots + ) + input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor + cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor + # Copy over adapter indices + adapter_start_index = cumulative_adapter_indices_size + adapter_end_index = ( + cumulative_adapter_indices_size + + batch.adapter_meta.adapter_indices.shape[0] + ) + adapter_indices[adapter_start_index:adapter_end_index] = ( + batch.adapter_meta.adapter_indices + ) + cumulative_adapter_indices_size = adapter_end_index + adapter_set.update(batch.adapter_meta.adapter_set) + adapter_segment_builder.concat( + batch.adapter_meta.adapter_segments, + batch.adapter_meta.segment_indices, + ) + + # Update + cumulative_slots += len(batch.slots) + else: + if isinstance(batch.input_ids, torch.Tensor): + batch.input_ids = batch.input_ids.view(-1, 1).tolist() + input_ids.extend(batch.input_ids) + + prefilling_mask.extend(batch.prefilling_mask) block_tables.extend(batch.block_tables) - prefix_lens.extend(batch.prefix_lens) + cache_lengths.extend(batch.cache_lengths) all_input_ids.extend(batch.all_input_ids) - prefix_ids.extend(batch.prefix_ids) + prompt_lengths.extend(batch.prompt_lengths) input_lengths.extend(batch.input_lengths) prefix_offsets.extend(batch.prefix_offsets) read_offsets.extend(batch.read_offsets) + prefill_logprob_tokens.extend(batch.prefill_logprob_tokens) + next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) 
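# Editorial sketch, not part of the patch: while any batch in a concatenation
# is still prefilling, `input_ids` stays a plain list of per-request id lists
# and the heavy tensors are rebuilt later by `prepare_for_prefill`; a decoding
# batch being merged in is converted back to that list form, as the
# `view(-1, 1).tolist()` conversion above does. `merge_input_ids` is a
# hypothetical helper; the real code does this inline in `concatenate`.
import torch


def merge_input_ids(batches):
    """`batches` is a list of (input_ids, is_prefilling) pairs."""
    if any(is_prefilling for _, is_prefilling in batches):
        merged = []
        for input_ids, _ in batches:
            if isinstance(input_ids, torch.Tensor):
                # A decoding batch holds exactly one next-token id per request.
                input_ids = input_ids.view(-1, 1).tolist()
            merged.extend(input_ids)
        return merged  # stays a list until prepare_for_prefill materializes it
    # Every batch is decoding: keep a flat tensor, one id per request.
    return torch.cat([input_ids for input_ids, _ in batches])


decoding = (torch.tensor([11, 12]), False)  # two requests, one id each
prefilling = ([[1, 2, 3], [4, 5]], True)  # two requests, mid-prefill
assert merge_input_ids([decoding, prefilling]) == [[11], [12], [1, 2, 3], [4, 5]]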
fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states) stopping_criterias.extend(batch.stopping_criterias) @@ -838,11 +820,6 @@ class FlashCausalLMBatch(Batch): # Update cumulative_batch_size += len(batch) - cumulative_slots += len(batch.slots) - - start_slots = torch.concat(start_slots) - - # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum() next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, @@ -858,7 +835,14 @@ class FlashCausalLMBatch(Batch): else None ) - adapter_segments, adapter_segment_indices = adapter_segment_builder.build() + if adapter_segment_builder is not None: + adapter_segments, adapter_segment_indices = adapter_segment_builder.build() + adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ) return cls( batch_id=batches[0].batch_id, @@ -868,24 +852,28 @@ class FlashCausalLMBatch(Batch): position_ids=position_ids, cu_seqlen_prefill=None, prefill_cache_indices=None, - start_slots=start_slots, slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, - prefix_lens=prefix_lens, - prefix_lens_tensor=prefix_lens_tensor, + cache_lengths=cache_lengths, + cache_lengths_tensor=cache_lengths_tensor, slots=slots, - max_seqlen=max_seqlen, + max_input_length=max_input_length, + max_current_length=max_current_length, + prefilling=prefilling, + prefilling_mask=prefilling_mask, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, + prefill_logprob_tokens=prefill_logprob_tokens, + prompt_lengths=prompt_lengths, + prompt_lengths_tensor=prompt_lengths_tensor, input_lengths=input_lengths, input_lengths_tensor=input_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, - prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -893,12 +881,195 @@ class FlashCausalLMBatch(Batch): num_blocks=num_blocks, max_blocks=max_blocks, speculative_ids=speculative_ids, - adapter_meta=AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_segment_indices, - ), + adapter_meta=adapter_meta, + ) + + def prepare_for_prefill(self): + # Prepare values if we need to continue prefilling + # Speculation must be ignored while we prefill even with chunking + # it simplifies everything + assert self.speculative_ids is None + + sliding_window = get_sliding_windows() + position_ids = [] + cu_seqlen_prefill = [0] + slot_indices = [] + prefill_cache_indices = [] + all_prefill_logprobs = True + no_prefill_logprobs = True + prefill_head_indices = [] + prefill_next_token_indices = [] + prefill_cu_outlens = [0] + + # Cumulative length + cumulative_length = 0 + cumulative_slot_tokens = 0 + prefill_out_cumulative_length = 0 + + slots = [] + adapter_indices_list = [] + adapter_set = set() + + for i, ( + r, + cache_length, + input_length, + prompt_length, + request_prefilling, + blocks, + ) in enumerate( + zip( + self.requests, + self.cache_lengths, + self.input_lengths, + self.prompt_lengths, + self.prefilling_mask, + self.block_tables, + ) + ): + next_chunk_length = input_length + # Position ids + request_position_ids = torch.arange( + cache_length, cache_length + input_length, dtype=torch.int32 + ) + 
position_ids.append(request_position_ids) + + # Add cumulative lengths of all previous inputs + cu_seqlen_prefill.append(cumulative_length + input_length) + + if not r.slots: + request_slots = [ + s + for b in blocks + for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE) + ] + else: + request_slots = r.slots + + request_slots = request_slots[cache_length:] + request_slot_indices = torch.arange( + cumulative_slot_tokens, + cumulative_slot_tokens + input_length, + dtype=torch.int64, + ) + + # Create tensor to slice into the kv tensor in prefill + if sliding_window is not None: + request_prefill_cache_indices = torch.arange( + cumulative_length + max(0, input_length - sliding_window), + cumulative_length + input_length, + dtype=torch.int64, + ) + + # Prefill logprobs is ignored if the request is done prefilling + prefill_logprobs = r.prefill_logprobs and request_prefilling + + all_prefill_logprobs = all_prefill_logprobs and prefill_logprobs + no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs + + if prefill_logprobs: + prefill_head_indices.append( + torch.arange( + cumulative_length, + cumulative_length + input_length, + dtype=torch.int64, + ) + ) + prefill_next_token_indices.append( + prefill_out_cumulative_length + input_length - 1 + ) + prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) + prefill_out_cumulative_length += input_length + else: + prefill_head_indices.append( + torch.tensor( + [cumulative_length + input_length - 1], + dtype=torch.int64, + ) + ) + prefill_next_token_indices.append(prefill_out_cumulative_length) + prefill_cu_outlens.append(prefill_out_cumulative_length + 1) + prefill_out_cumulative_length += 1 + + slots.extend(request_slots) + slot_indices.append(request_slot_indices) + + if sliding_window is not None: + prefill_cache_indices.append(request_prefill_cache_indices) + + ADAPTER_TO_INDEX = get_adapter_to_index() + adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0) + adapter_indices_list.append(torch.full((next_chunk_length,), adapter_index)) + adapter_set.add(adapter_index) + + # Update + cumulative_length += next_chunk_length + cumulative_slot_tokens += len(request_slots) + + device = self.block_tables_tensor.device + + if isinstance(self.input_ids, list): + if len(self) > 1: + input_ids = np.concatenate(self.input_ids, dtype=np.int64) + else: + input_ids = self.input_ids[0] + self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + + if len(self) > 1: + position_ids = torch.cat(position_ids) + slot_indices = torch.cat(slot_indices) + if sliding_window is not None: + prefill_cache_indices = torch.cat(prefill_cache_indices) + else: + position_ids = position_ids[0] + slot_indices = slot_indices[0] + if sliding_window is not None: + prefill_cache_indices = prefill_cache_indices[0] + + self.prefill_cu_outlens = prefill_cu_outlens + cu_seqlen_prefill = torch.tensor( + cu_seqlen_prefill, device=device, dtype=torch.int32 + ) + self.cu_seqlen_prefill = cu_seqlen_prefill + self.position_ids = position_ids.to(device) + self.slot_indices = slot_indices.to(device) + self.prefill_cache_indices = ( + prefill_cache_indices.to(device) if sliding_window is not None else None + ) + self.input_lengths_tensor = torch.tensor( + self.input_lengths, dtype=torch.int32, device=device + ) + + if all_prefill_logprobs: + prefill_head_indices = None + prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 + elif no_prefill_logprobs: + prefill_head_indices = cu_seqlen_prefill[1:] - 1 + prefill_next_token_indices = None + else: + 
prefill_head_indices = torch.cat(prefill_head_indices).to(device) + prefill_next_token_indices = torch.tensor( + prefill_next_token_indices, dtype=torch.int64, device=device + ) + + self.prefill_head_indices = prefill_head_indices + self.prefill_next_token_indices = prefill_next_token_indices + self.slots = torch.tensor(slots, dtype=torch.int64, device=device) + self.cache_lengths_tensor = torch.tensor( + self.cache_lengths, dtype=torch.int32, device=device + ) + adapter_indices = torch.cat(adapter_indices_list).to( + dtype=torch.int64, device=device + ) + adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + adapter_segments = torch.tensor( + adapter_segments, dtype=torch.int32, device=device + ) + self.adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, ) def __len__(self): @@ -938,6 +1109,7 @@ class FlashCausalLM(Model): head_size: Optional[int] = None, skip_special_tokens: bool = True, kv_cache_dtype: Optional[torch.dtype] = None, + support_chunking: bool = True, ): self.quantize = quantize self.process_group, rank, world_size = initialize_torch_distributed() @@ -1065,6 +1237,7 @@ class FlashCausalLM(Model): rank=rank, world_size=world_size, sliding_window=config.sliding_window, + support_chunking=support_chunking, ) @property @@ -1101,11 +1274,11 @@ class FlashCausalLM(Model): position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) slots = torch.arange(bs, dtype=torch.int64, device=self.device) input_lengths = [max_s] * bs - prefix_lengths = [0] * bs + cache_lengths = [0] * bs input_lengths_tensor = ( torch.ones(bs, dtype=torch.int32, device=self.device) * max_s ) - prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device) + cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device) block_tables = torch.arange( max_bt, dtype=torch.int32, device=self.device ).repeat(bs) @@ -1115,7 +1288,7 @@ class FlashCausalLM(Model): block_tables = block_tables_to_ragged( block_tables=block_tables, input_lengths=input_lengths, - prefix_lens=prefix_lengths, + cache_lengths=cache_lengths, ) from text_generation_server.layers.attention.flashinfer import ( create_decode_state_cuda_graphs, @@ -1144,7 +1317,7 @@ class FlashCausalLM(Model): "block_tables": block_tables, "slots": slots, "input_lengths": input_lengths_tensor, - "prefix_lengths": prefix_lengths_tensor, + "cache_lengths": cache_lengths_tensor, "state": state, "graph": graph, } @@ -1156,11 +1329,11 @@ class FlashCausalLM(Model): cu_seqlen_prefill=None, input_lengths_tensor=input_lengths_tensor, state=state, - prefix_lens_tensor=prefix_lengths_tensor, + cache_lengths_tensor=cache_lengths_tensor, ): seqlen = Seqlen( input_lengths=input_lengths_tensor, - prefix_lengths=prefix_lengths_tensor, + cache_lengths=cache_lengths_tensor, cu_seqlen_q=None, max_q=1, max_k=max_s, @@ -1184,7 +1357,7 @@ class FlashCausalLM(Model): with torch.cuda.graph(graph, pool=MEM_POOL): seqlen = Seqlen( input_lengths=input_lengths_tensor, - prefix_lengths=prefix_lengths_tensor, + cache_lengths=cache_lengths_tensor, cu_seqlen_q=None, max_q=1, max_k=max_s, @@ -1207,6 +1380,7 @@ class FlashCausalLM(Model): def warmup(self, batch: FlashCausalLMBatch): # The warmup batch is the biggest batch we could ever receive + self.kv_cache = [] empty_cache() try: @@ -1226,7 +1400,7 @@ class FlashCausalLM(Model): _, batch, _ = self.generate_token(batch) except torch.cuda.OutOfMemoryError as 
e: raise RuntimeError( - f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. " + f"Not enough memory to handle {batch.to_pb().current_tokens} prefill tokens. " f"You need to decrease `--max-batch-prefill-tokens`" ) from e @@ -1341,14 +1515,16 @@ class FlashCausalLM(Model): # Dummy value, some models (starcoder2) don't accept `None`. input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) - prefix_lens_tensor = torch.zeros(seqlen, dtype=torch.int32, device=self.device) + cache_lengths_tensor = torch.zeros( + seqlen, dtype=torch.int32, device=self.device + ) cu_seqlen_prefill = torch.tensor( [0, seqlen], device=self.device, dtype=torch.int32 ) max_s = seqlen seqlen = Seqlen( input_lengths=input_lengths, - prefix_lengths=prefix_lens_tensor, + cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, max_q=1, max_k=seqlen, @@ -1380,7 +1556,7 @@ class FlashCausalLM(Model): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - max_s = batch.max_seqlen + max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices speculative_ids = batch.speculative_ids @@ -1399,8 +1575,8 @@ class FlashCausalLM(Model): input_lengths = ( input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) - prefix_lens_tensor = ( - batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) + cache_lengths_tensor = ( + batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length) ).reshape(-1) # Add Copy the block tables for all members @@ -1422,10 +1598,10 @@ class FlashCausalLM(Model): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - prefix_lens_tensor = batch.prefix_lens_tensor - max_s = batch.max_seqlen + cache_lengths_tensor = batch.cache_lengths_tensor + max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices if cu_seqlen_prefill is None and self.max_past() is not None: # In decode, not prefill, we're actually overwriting the KV-cache # in a circular buffer mode.
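# Editorial sketch, not part of the patch: the speculative expansion performed
# in the decode branch above, with made-up lengths. Each request becomes
# `1 + speculate` rows; the visible input length grows with the position of
# the speculated token while the cache length is simply repeated per row.
import torch

B, speculate = 2, 2
new_length = speculate + 1
input_lengths = torch.tensor([1, 1], dtype=torch.int32)  # decode: one new token each
cache_lengths = torch.tensor([7, 12], dtype=torch.int32)  # tokens already cached
arange_int = torch.arange(new_length, dtype=torch.int32)

expanded_input_lengths = (
    input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
expanded_cache_lengths = (
    cache_lengths.unsqueeze(-1).expand(B, new_length)
).reshape(-1)

assert expanded_input_lengths.tolist() == [1, 2, 3, 1, 2, 3]
assert expanded_cache_lengths.tolist() == [7, 7, 7, 12, 12, 12]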
@@ -1445,21 +1623,20 @@ class FlashCausalLM(Model): block_tables = block_tables_to_ragged( block_tables=block_tables, input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, + cache_lengths=batch.cache_lengths, ) with self._forward_context( block_tables=block_tables, cu_seqlen_prefill=cu_seqlen_prefill, input_lengths_tensor=input_lengths, - prefix_lens_tensor=prefix_lens_tensor, + cache_lengths_tensor=cache_lengths_tensor, ): - max_k = (input_lengths + prefix_lens_tensor).max().item() seqlen = Seqlen( input_lengths=input_lengths, - prefix_lengths=prefix_lens_tensor, + cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, - max_q=max_s, - max_k=max_k, + max_q=batch.max_input_length, + max_k=batch.max_current_length, ) logits, speculative_logits = self.model.forward( input_ids=input_ids, @@ -1486,7 +1663,7 @@ class FlashCausalLM(Model): block_tables = block_tables_to_ragged( block_tables=block_tables, input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, + cache_lengths=batch.cache_lengths, ) # assert block_tables.shape[0] >= slots.shape[0] cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables @@ -1501,14 +1678,16 @@ class FlashCausalLM(Model): cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["input_lengths"].zero_() cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths - cuda_graph["prefix_lengths"].zero_() - cuda_graph["prefix_lengths"][: prefix_lens_tensor.shape[0]] = prefix_lens_tensor + cuda_graph["cache_lengths"].zero_() + cuda_graph["cache_lengths"][ + : cache_lengths_tensor.shape[0] + ] = cache_lengths_tensor with self._forward_context( block_tables=cuda_graph["block_tables"], cu_seqlen_prefill=None, input_lengths_tensor=cuda_graph["input_lengths"], - prefix_lens_tensor=cuda_graph["prefix_lengths"], + cache_lengths_tensor=cuda_graph["cache_lengths"], state=cuda_graph["state"], ): # Replay the graph @@ -1528,7 +1707,10 @@ class FlashCausalLM(Model): self, batch: FlashCausalLMBatch ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]: start = time.time_ns() - prefill = batch.cu_seqlen_prefill is not None + prefill = batch.prefilling + if prefill: + batch.prepare_for_prefill() + prefill_logprobs = batch.prefill_next_token_indices is not None # Update adapter indices for speculative tokens (if present) @@ -1570,14 +1752,62 @@ class FlashCausalLM(Model): if prefill_logprobs else speculative_logits ) - next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty( - len(batch) - ) - + if len(batch) > 1 and prefill_logprobs: + # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs + # When batch == 1, we will just use the batch.input_ids values directly + prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) else: + prefill_logprobs = None next_token_logits = out next_adapter_indices = batch.adapter_meta.adapter_indices + finished_prefilling = True + next_chunk_lengths = [] + current_prefilling_mask = batch.prefilling_mask + if prefill: + if get_support_chunking(): + next_prefilling_mask = [] + # Budget in tokens for the next batch + # We remove (len(batch) - 1) to always have enough space for at least a single decode + # for the remaining requests -1 because the first request does not need to be removed from the budget + # (ex: you have one request in the batch, you want it to take the full budget not budget -1) + batch_budget = get_max_prefill_tokens() - (len(batch) - 1) + # We reverse to prioritize older requests + # zip() is not reversible 
so reverse the underlying lists instead + for cache_length, input_length, prompt_length in zip( + reversed(batch.cache_lengths), + reversed(batch.input_lengths), + reversed(batch.prompt_lengths), + ): + remaining_prefill_tokens = max( + prompt_length - cache_length - input_length, 0 + ) + if remaining_prefill_tokens > 0: + next_chunk_length = max( + min(remaining_prefill_tokens, batch_budget), 1 + ) + batch_budget -= next_chunk_length + finished_prefilling = False + next_prefilling_mask.append(True) + else: + # FIXME: use true number of accepted tokens instead of 1 + # Since speculation will be turned off, this is always true + next_chunk_length = 1 + next_prefilling_mask.append(False) + next_chunk_lengths.append(next_chunk_length) + + # Reverse back the obtained values² + next_chunk_lengths.reverse() + next_prefilling_mask.reverse() + else: + # The model does not support chunking + # We know we only do a single prefill + finished_prefilling = True + next_prefilling_mask = [False] * len(batch) + + batch.prefilling = not finished_prefilling + batch.prefilling_mask = next_prefilling_mask + speculate = get_speculate() ( next_input_ids, @@ -1586,7 +1816,7 @@ class FlashCausalLM(Model): accepted_ids, speculative_ids, ) = batch.next_token_chooser( - batch.all_input_ids_tensor[:, : batch.max_seqlen], + batch.all_input_ids_tensor[:, : batch.max_current_length], next_token_logits, speculate, batch.speculative_ids, @@ -1597,29 +1827,28 @@ class FlashCausalLM(Model): batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids ) - if prefill: - if len(batch) > 1 and prefill_logprobs: - # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs - # When batch == 1, we will just use the batch.input_ids values directly - prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) - + # Since we are done prefilling, all the tensors that were concatenating values for all the requests + # instantly become of shape [BATCH_SIZE] + if prefill and finished_prefilling: next_position_ids = batch.position_ids.new_empty(len(batch)) batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1] - # We do not need cu_seqlen_prefill anymore - batch.cu_seqlen_prefill = None - else: - prefill_logprobs = None + next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty( + len(batch) + ) + elif not prefill: next_position_ids = batch.position_ids - # Cumulative length - cumulative_length = 0 - - # Results - generations: List[Generation] = [] - stopped = True - # Zipped iterator - iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids) + iterator = zip( + batch.requests, + batch.prompt_lengths, + batch.cache_lengths, + batch.input_lengths, + batch.all_input_ids, + accepted_ids, + current_prefilling_mask, + batch.prefilling_mask, + ) # We do two for loops as the first one can run completely asynchronously from the GPU while for the second # one, we need to first do a GPU <-> CPU sync @@ -1627,16 +1856,22 @@ class FlashCausalLM(Model): # For each member of the batch index = 0 - for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator): - # Indexing metadata - start_index = cumulative_length - end_index = cumulative_length + input_length - - if prefill: + # Cumulative length + cumulative_length = 0 + for i, ( + request, + prompt_length, + cache_length, + input_length, + all_input_ids, + n_accepted_ids, + request_was_prefilling, + request_is_prefilling, + ) in enumerate(iterator): + if prefill and finished_prefilling: # Indexing 
metadata - out_start_index = batch.prefill_cu_outlens[i] - out_end_index = batch.prefill_cu_outlens[i + 1] - out_length = out_end_index - out_start_index + _start_index = cumulative_length + end_index = cumulative_length + input_length # Initialize position_ids # In decode, we do not need this as we can just increment position ids @@ -1648,41 +1883,43 @@ class FlashCausalLM(Model): end_index - 1 ] - # Used to gather prefill logprobs - # Copy batch.input_ids to prefill_token_indices - if prefill_logprobs: - if len(batch) > 1: - prefill_tokens_indices[out_start_index : out_end_index - 1] = ( - batch.input_ids[start_index + 1 : start_index + out_length] - ) - else: - # Set prefill_tokens_indices to the correct slice - prefill_tokens_indices = batch.input_ids[ - start_index + 1 : start_index + out_length - ] + # Used to gather prefill logprobs + # Copy batch.all_input_ids_tensor to prefill_token_indices + if request.prefill_logprobs and request_was_prefilling: + # Indexing metadata + out_start_index = batch.prefill_cu_outlens[i] + out_end_index = batch.prefill_cu_outlens[i + 1] - for j in range(n_accepted_ids): - batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index] - index += 1 + # Logprobs generated by the model are for the next token + # So we need to translate the id tensor by 1 + ids = batch.all_input_ids_tensor[ + i, cache_length + 1 : cache_length + input_length + 1 + ] + if len(batch) > 1: + prefill_tokens_indices[out_start_index:out_end_index] = ids + else: + # Set prefill_tokens_indices to the correct slice + prefill_tokens_indices = ids + if not request_is_prefilling: + # Only save tokens if we are done prefilling for this request + for j in range(n_accepted_ids): + batch.all_input_ids_tensor[i, cache_length + input_length + j] = ( + next_input_ids[index + j] + ) + index += n_accepted_ids cumulative_length += input_length # Update values - batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] - batch.speculative_ids = speculative_ids - batch.position_ids = next_position_ids + accepted_ids - batch.input_lengths_tensor += accepted_ids - batch.slot_indices += accepted_ids - batch.adapter_meta.adapter_indices = next_adapter_indices - - if prefill: - # adjust segment lengths to account for all request lengths being 1 during decoding - adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices) - batch.adapter_meta.adapter_segments = torch.tensor( - adapter_segments, - dtype=torch.int32, - device=batch.adapter_meta.adapter_segments.device, - ) + # These values can be updated without a GPU -> CPU sync + if not prefill or (prefill and finished_prefilling): + batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] + batch.speculative_ids = speculative_ids + batch.position_ids = next_position_ids + accepted_ids + batch.cache_lengths_tensor += batch.input_lengths_tensor + batch.input_lengths_tensor = accepted_ids.to(dtype=torch.int32) + batch.slot_indices += accepted_ids + batch.adapter_meta.adapter_indices = next_adapter_indices if prefill and prefill_logprobs: # Get prefill logprobs @@ -1693,183 +1930,292 @@ class FlashCausalLM(Model): # GPU <-> CPU sync prefill_logprobs = prefill_logprobs.view(-1).tolist() + # Does a GPU <-> CPU sync internally + if prefill and finished_prefilling: + # adjust segment lengths to account for all request lengths being 1 during decoding + adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices) + batch.adapter_meta.adapter_segments = torch.tensor( + adapter_segments, + dtype=torch.int32, + 
device=batch.adapter_meta.adapter_segments.device, + ) + # GPU <-> CPU sync next_token_logprobs = next_token_logprobs.tolist() next_token_ids = next_input_ids.tolist() accepted_ids = accepted_ids.tolist() + + # Update values if we need to continue prefilling + # This represents the `else` case of the `Update values` if above + # but since this require the `next_token_ids` to be on CPU, it is better to do it here + if prefill and not finished_prefilling: + # Speculation must be ignored while we prefill even with chunking + # it simplifies everything + assert batch.speculative_ids is None + + all_postfix_ids = [] + for i, ( + request_prefilling, + next_token_id, + all_input_ids, + cache_length, + input_length, + next_chunk_length, + ) in enumerate( + zip( + batch.prefilling_mask, + next_token_ids, + batch.all_input_ids, + batch.cache_lengths, + batch.input_lengths, + next_chunk_lengths, + ) + ): + if request_prefilling: + next_cache_length = cache_length + input_length + # Get new prompt IDs to prefill + postfix_ids = all_input_ids[ + next_cache_length : next_cache_length + next_chunk_length + ] + else: + # This request is done prefilling, the new id is the one selected the sampling method + postfix_ids = [next_token_id] + + all_postfix_ids.append(postfix_ids) + + batch.input_ids = all_postfix_ids + start_decode = time.time_ns() + # Results + generations: List[Generation] = [] + stopped = True + # Zipped iterator iterator = zip( batch.requests, + batch.prompt_lengths, + batch.cache_lengths, batch.input_lengths, batch.prefix_offsets, batch.read_offsets, batch.stopping_criterias, batch.all_input_ids, - batch.prefix_ids, batch.next_token_chooser.do_sample, batch.next_token_chooser.seeds, batch.top_n_tokens, + current_prefilling_mask, + batch.prefilling_mask, accepted_ids, batch_top_token_ids, batch_top_token_logprobs, ) + # Reset max_input_length + batch.max_input_length = 0 # For each member of the batch index = 0 for i, ( request, + prompt_length, + cache_length, input_length, prefix_offset, read_offset, stopping_criteria, all_input_ids, - prefix_ids, do_sample, seed, top_n_tokens, + request_was_prefilling, + request_is_prefilling, n_accepted_ids, top_token_ids, top_token_logprobs, ) in enumerate(iterator): - # Append next token to all tokens - next_token_texts = [] - left = 0 - - if n_accepted_ids > 1: - log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}") - - current_stopped = False - for j in range(index, index + n_accepted_ids): - # Generated token - next_token_id = next_token_ids[j] - all_input_ids.append(next_token_id) - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids, - prefix_offset, - read_offset, - ) - next_token_texts.append(next_token_text) - - stop, reason = stopping_criteria( - next_token_id, - next_token_text, - ) - - if stop: - left = index + n_accepted_ids - j - 1 - current_stopped = True - break - else: - current_stopped = False - stopped = stopped and current_stopped - - _next_token_ids = next_token_ids[index : index + n_accepted_ids - left] - _next_token_logprobs = next_token_logprobs[ - index : index + n_accepted_ids - left - ] - index += n_accepted_ids - - # Shard generations - # All generations will be appended in the rust sharded client - if i % self.world_size == self.rank: - if stop: - # Decode generated tokens - output_text, _, _ = self.decode_token( - all_input_ids, - prefix_offset=len(all_input_ids) - - stopping_criteria.current_tokens - - 1, - read_offset=len(all_input_ids) - - stopping_criteria.current_tokens, - 
skip_special_tokens=True, - ) - generated_text = GeneratedText( - output_text, - stopping_criteria.current_tokens, - reason, - seed if do_sample else None, - ) - else: - generated_text = None - + # Compute logprobs first as, even though we might skip the token, + # it can still be required to compute the logprobs + # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need + # this state to be stable + if request.id % self.world_size == self.rank: # Prefill - if prefill and request.prefill_logprobs: + if request_was_prefilling and request.prefill_logprobs: out_start_index = batch.prefill_cu_outlens[i] out_end_index = batch.prefill_cu_outlens[i + 1] + if not request_is_prefilling: + # The request is dones prefilling, meaning that we started generating new tokens + # The last logprob is a logprob for a generated token that was not part of the prompt + # We need to remove it + out_end_index -= 1 + + request_prefill_logprobs = prefill_logprobs[ + out_start_index:out_end_index + ] + # Logprobs generated by the model are for the next token + # So we need to translate the id tensor by 1 + prefill_token_ids = all_input_ids[ + cache_length + 1 : cache_length + input_length + 1 + ] + + past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i] + + if past_prefill_logprob_tokens is None: + # add nan for cached prompt tokens/first token + request_prefill_logprobs = [float("nan")] * ( + cache_length + 1 + ) + request_prefill_logprobs + prefill_token_ids = ( + all_input_ids[: cache_length + 1] + prefill_token_ids + ) - # Remove generated token to only have prefill and add nan for first prompt token - request_prefill_logprobs = ( - [float("nan")] * (len(prefix_ids) + 1) - ) + prefill_logprobs[out_start_index : out_end_index - 1] - prefill_token_ids = all_input_ids[:-1] prefill_texts = self.tokenizer.batch_decode( - prefix_ids + prefill_token_ids, + prefill_token_ids, clean_up_tokenization_spaces=False, skip_special_tokens=False, ) - prefill_tokens = Tokens( - prefix_ids + prefill_token_ids, + prefill_logprob_tokens = Tokens( + prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special=[], ) - else: - prefill_tokens = None - - if top_n_tokens > 0: - all_top_tokens = [] - for top_token_ids, top_token_logprobs in zip( - top_token_ids, top_token_logprobs - ): - toptoken_texts = self.tokenizer.batch_decode( - top_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, + if past_prefill_logprob_tokens is not None: + prefill_logprob_tokens = ( + past_prefill_logprob_tokens + prefill_logprob_tokens ) - special_toptokens = [ - token_id in self.all_special_ids - for token_id in top_token_ids - ] - top_tokens = Tokens( - top_token_ids, - top_token_logprobs, - toptoken_texts, - special_toptokens, - ) - all_top_tokens.append(top_tokens) - top_tokens = all_top_tokens + + batch.prefill_logprob_tokens[i] = prefill_logprob_tokens else: - top_tokens = None + batch.prefill_logprob_tokens[i] = None - generation = Generation( - request.id, - prefill_tokens, - Tokens( - _next_token_ids, - _next_token_logprobs, - next_token_texts, - [nid in self.all_special_ids for nid in _next_token_ids], - ), - generated_text, - top_tokens, - ) + # If it is, the tokens we decoded should be ignored + if request_is_prefilling: + # Make sure that we do not stop as even though this request did not create a token, it is still + # processing + stopped = False + new_input_length = next_chunk_lengths[i] + else: + new_input_length = n_accepted_ids + # Append next token 
to all tokens + next_token_texts = [] + left = 0 - generations.append(generation) + if n_accepted_ids > 1: + log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}") - # accept each new token for this specific request since we may - # have more than one new token per request with speculative decoding - for next_token_id in _next_token_ids: - batch.next_token_chooser = ( - batch.next_token_chooser.advance_grammar_single(i, next_token_id) - ) + current_stopped = False + for j in range(index, index + n_accepted_ids): + # Generated token + next_token_id = next_token_ids[j] + all_input_ids.append(next_token_id) + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids, + prefix_offset, + read_offset, + ) + next_token_texts.append(next_token_text) + + stop, reason = stopping_criteria( + next_token_id, + next_token_text, + ) + + if stop: + left = index + n_accepted_ids - j - 1 + current_stopped = True + break + else: + current_stopped = False + stopped = stopped and current_stopped + + _next_token_ids = next_token_ids[index : index + n_accepted_ids - left] + _next_token_logprobs = next_token_logprobs[ + index : index + n_accepted_ids - left + ] + + # Shard generations + # All generations will be appended in the rust sharded client + if request.id % self.world_size == self.rank: + if stop: + # Decode generated tokens + output_text, _, _ = self.decode_token( + all_input_ids, + prefix_offset=len(all_input_ids) + - stopping_criteria.current_tokens + - 1, + read_offset=len(all_input_ids) + - stopping_criteria.current_tokens, + skip_special_tokens=True, + ) + generated_text = GeneratedText( + output_text, + stopping_criteria.current_tokens, + reason, + seed if do_sample else None, + ) + else: + generated_text = None + + if top_n_tokens > 0: + all_top_tokens = [] + for top_token_ids, top_token_logprobs in zip( + top_token_ids, top_token_logprobs + ): + toptoken_texts = self.tokenizer.batch_decode( + top_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + special_toptokens = [ + token_id in self.all_special_ids + for token_id in top_token_ids + ] + top_tokens = Tokens( + top_token_ids, + top_token_logprobs, + toptoken_texts, + special_toptokens, + ) + all_top_tokens.append(top_tokens) + top_tokens = all_top_tokens + else: + top_tokens = None + + generation = Generation( + request.id, + batch.prefill_logprob_tokens[i], + Tokens( + _next_token_ids, + _next_token_logprobs, + next_token_texts, + [nid in self.all_special_ids for nid in _next_token_ids], + ), + generated_text, + top_tokens, + ) + + generations.append(generation) + + # accept each new token for this specific request since we may + # have more than one new token per request with speculative decoding + for next_token_id in _next_token_ids: + batch.next_token_chooser = ( + batch.next_token_chooser.advance_grammar_single( + i, next_token_id + ) + ) # Update values - batch.input_lengths[i] = input_length + n_accepted_ids - if batch.input_lengths[i] > batch.max_seqlen: - batch.max_seqlen = batch.input_lengths[i] + index += n_accepted_ids + current_cache_length = cache_length + input_length + batch.cache_lengths[i] = current_cache_length + current_input_length = new_input_length + batch.max_input_length = max(batch.max_input_length, current_input_length) + batch.input_lengths[i] = current_input_length + current_length = current_cache_length + current_input_length + batch.max_current_length = max(batch.max_current_length, current_length) + batch.prefix_offsets[i] = prefix_offset 
batch.read_offsets[i] = read_offset batch.all_input_ids[i] = all_input_ids @@ -1880,9 +2226,13 @@ class FlashCausalLM(Model): decode_ns = time.time_ns() - start_decode return generations, None, (forward_ns, decode_ns) - batch.prefill_cu_outlens = None - batch.prefill_head_indices = None - batch.prefill_next_token_indices = None + if prefill and finished_prefilling: + # We do not need prefill tensors anymore + batch.cu_seqlen_prefill = None + batch.prefill_cache_indices = None + batch.prefill_cu_outlens = None + batch.prefill_head_indices = None + batch.prefill_next_token_indices = None forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode @@ -1894,7 +2244,7 @@ class FlashCausalLM(Model): block_tables: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], input_lengths_tensor: torch.Tensor, - prefix_lens_tensor: torch.Tensor, + cache_lengths_tensor: torch.Tensor, state: Optional[Any] = None, ) -> ContextManager: if ATTENTION != "flashinfer": @@ -1905,8 +2255,6 @@ class FlashCausalLM(Model): use_prefill_with_paged_kv_state, ) - # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens) - if cu_seqlen_prefill is not None: return use_prefill_with_paged_kv_state( state=( @@ -1915,11 +2263,11 @@ class FlashCausalLM(Model): # block_tables=block_tables_to_ragged( # block_tables=block_tables, # input_lengths=input_lengths, - # prefix_lens=prefix_lens, + # cache_lengths=cache_lengths, # ), block_tables=block_tables, cu_seqlens=cu_seqlen_prefill, - input_lengths=input_lengths_tensor + prefix_lens_tensor, + input_lengths=input_lengths_tensor + cache_lengths_tensor, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_size=self.head_size, @@ -1931,7 +2279,7 @@ class FlashCausalLM(Model): assert input_lengths_tensor is not None return use_decode_state( state=state if state is not None else self.decode_state, - input_lengths=input_lengths_tensor + prefix_lens_tensor, + input_lengths=input_lengths_tensor + cache_lengths_tensor, block_tables=block_tables, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, @@ -1943,19 +2291,19 @@ class FlashCausalLM(Model): def block_tables_to_ragged( - *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int] + *, block_tables: torch.Tensor, input_lengths: List[int], cache_lengths: List[int] ) -> torch.Tensor: """Convert block table to ragged format compatible with FlashInfer.""" - assert len(input_lengths) == len(prefix_lens) + assert len(input_lengths) == len(cache_lengths) - total_len = sum(input_lengths) + sum(prefix_lens) + total_len = sum(input_lengths) + sum(cache_lengths) block_tables_ragged = torch.empty( total_len, dtype=torch.int32, device=block_tables.device ) offset = 0 - for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)): - seq_len = prefix_len + input_length + for i, (input_length, cache_length) in enumerate(zip(input_lengths, cache_lengths)): + seq_len = cache_length + input_length block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len] offset += seq_len diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 6c518c2c..4ac6a6b4 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -5,9 +5,14 @@ from typing import Dict, Optional from text_generation_server.utils.log import log_master -PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"} +ATTENTION = os.environ["ATTENTION"] +# 
default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0" +PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in { + "1", + "true", +} +PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "0").lower() in {"1", "true"} log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") -ATTENTION = os.getenv("ATTENTION") _expected = {"paged", "flashdecoding", "flashinfer"} assert ( ATTENTION in _expected @@ -18,7 +23,7 @@ if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}: raise RuntimeError("Prefix caching is only supported with flashinfer") MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None -TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95")) +TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90")) assert TGI_WIGGLE_ROOM > 0 assert TGI_WIGGLE_ROOM < 1 diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index 9a7a6fe1..34b74ba8 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -83,6 +83,7 @@ class IdeficsCausalLMBatch(Batch): request_ids=[r.id for r in self.requests], size=len(self), max_tokens=self.max_tokens, + current_tokens=len(self), ) @classmethod diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index f6dcde68..dfc61fb8 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -116,6 +116,7 @@ class MambaBatch(Batch): request_ids=[r.id for r in self.requests], size=len(self), max_tokens=self.max_tokens, + current_tokens=len(self), ) @classmethod diff --git a/server/text_generation_server/models/mllama_causal_lm.py b/server/text_generation_server/models/mllama_causal_lm.py index 9e19e171..6399f92c 100644 --- a/server/text_generation_server/models/mllama_causal_lm.py +++ b/server/text_generation_server/models/mllama_causal_lm.py @@ -1,14 +1,17 @@ -from io import BytesIO -from PIL import Image import torch + +import numpy as np + from typing import Iterable, Optional, Tuple, List, Dict from text_generation_server.pb.generate_pb2 import Request - +from io import BytesIO +from PIL import Image from dataclasses import dataclass from opentelemetry import trace from transformers import ( PreTrainedTokenizerBase, ) + from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM from text_generation_server.pb import generate_pb2 from text_generation_server.models.flash_causal_lm import ( @@ -167,6 +170,13 @@ class MllamaCausalLMBatch(VlmCausalLMBatch): batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp( max=config.text_config.vocab_size - 1 ) + if isinstance(batch.input_ids, list): + if len(batch) > 1: + input_ids = np.concatenate(batch.input_ids, dtype=np.int64) + else: + input_ids = batch.input_ids[0] + batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1) if image_inputs is not None: @@ -190,7 +200,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch): class MllamaCausalLM(VlmCausalLM): def forward( self, - batch: VlmCausalLMBatch, + batch: MllamaCausalLMBatch, adapter_data: Optional[Dict[str, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Model Forward @@ -202,7 +212,7 @@ class MllamaCausalLM(VlmCausalLM): block_tables = batch.block_tables_tensor slots = 
batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - max_s = batch.max_seqlen + max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices speculative_ids = batch.speculative_ids @@ -221,8 +231,8 @@ class MllamaCausalLM(VlmCausalLM): input_lengths = ( input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) - prefix_lens_tensor = ( - batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) + cache_lengths_tensor = ( + batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length) ).reshape(-1) # Add Copy the block tables for all members @@ -244,8 +254,8 @@ class MllamaCausalLM(VlmCausalLM): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - prefix_lens_tensor = batch.prefix_lens_tensor - max_s = batch.max_seqlen + cache_lengths_tensor = batch.cache_lengths_tensor + max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices if cu_seqlen_prefill is None and self.max_past() is not None: @@ -254,7 +264,6 @@ class MllamaCausalLM(VlmCausalLM): # This makes sure the max_s for the decode pass is correct. max_s = min(self.max_past(), max_s) - bs = input_ids.shape[0] # Try to find an associated cuda graph bs = input_ids.shape[0] sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) @@ -269,26 +278,24 @@ class MllamaCausalLM(VlmCausalLM): # Only run cuda graphs when there's no images. or batch.cross_attention_states is not None ): - input_lengths = input_lengths + prefix_lens_tensor if PREFIX_CACHING: block_tables = block_tables_to_ragged( block_tables=block_tables, input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, + cache_lengths=batch.cache_lengths, ) with self._forward_context( block_tables=block_tables, cu_seqlen_prefill=cu_seqlen_prefill, input_lengths_tensor=input_lengths, - prefix_lens_tensor=prefix_lens_tensor, + cache_lengths_tensor=cache_lengths_tensor, ): - max_k = (input_lengths + prefix_lens_tensor).max().item() seqlen = Seqlen( input_lengths=input_lengths, - prefix_lengths=prefix_lens_tensor, + cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, - max_q=max_s, - max_k=max_k, + max_q=batch.max_input_length, + max_k=batch.max_current_length, ) if batch.pixel_values is not None: @@ -330,22 +337,34 @@ class MllamaCausalLM(VlmCausalLM): block_tables = block_tables_to_ragged( block_tables=block_tables, input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, + cache_lengths=batch.cache_lengths, ) cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables else: cuda_graph["block_tables"][ : block_tables.shape[0], : block_tables.shape[1] ] = block_tables + + # XXX: This is working only because block 0 is reserved for the healthcheck + # so it doesn't matter if we override it with bogus values. 
cuda_graph["slots"].fill_(0) cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( - input_lengths + prefix_lens_tensor - ) + cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + cuda_graph["cache_lengths"].zero_() + cuda_graph["cache_lengths"][ + : cache_lengths_tensor.shape[0] + ] = cache_lengths_tensor - # Replay the graph - cuda_graph["graph"].replay() + with self._forward_context( + block_tables=cuda_graph["block_tables"], + cu_seqlen_prefill=None, + input_lengths_tensor=cuda_graph["input_lengths"], + cache_lengths_tensor=cuda_graph["cache_lengths"], + state=cuda_graph["state"], + ): + # Replay the graph + cuda_graph["graph"].replay() # Slice output to the correct shape speculative_logits = ( diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 20402e07..b3630013 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -5,8 +5,17 @@ from abc import ABC, abstractmethod from typing import List, Tuple, Optional, TypeVar, Type, Dict from collections import defaultdict from transformers import PreTrainedTokenizerBase +from loguru import logger +from text_generation_server.models.globals import ( + ATTENTION, + PREFIX_CACHING, + BLOCK_SIZE, + PREFILL_CHUNKING, +) from text_generation_server.models.types import Batch, Generation +from text_generation_server.utils.log import log_master +from text_generation_server.utils.prefill_chunking import set_support_chunking from text_generation_server.utils.speculate import get_speculate from text_generation_server.pb.generate_pb2 import InfoResponse from text_generation_server.adapters.weights import LayerAdapterWeights @@ -31,6 +40,7 @@ class Model(ABC): sliding_window: Optional[int] = None, speculate: Optional[int] = None, adapter_id: str = BASE_MODEL_ADAPTER_ID, + support_chunking: bool = False, ): self.model_id = model_id self.model = model.eval() @@ -60,6 +70,29 @@ class Model(ABC): speculate = get_speculate() self.speculate = speculate + support_chunking = support_chunking and PREFILL_CHUNKING + + if speculate != 0 and support_chunking: + log_master( + logger.warning, + "Prefill chunking does not support speculation yet. 
" + "Prefill chunking will be turned off", + ) + support_chunking = False + if ATTENTION not in ["flashinfer", "flashdecoding"] and support_chunking: + log_master( + logger.warning, + "Prefill chunking is only supported with `flashinfer` or `flashdecoding` attention types.", + ) + support_chunking = False + + log_master( + logger.info, f"Using experimental prefill chunking = {support_chunking}" + ) + + self.support_chunking = support_chunking + set_support_chunking(support_chunking) + self.has_position_ids = ( inspect.signature(model.forward).parameters.get("position_ids", None) is not None @@ -78,6 +111,10 @@ class Model(ABC): device_type=self.device.type, window_size=self.sliding_window, speculate=self.speculate, + support_chunking=self.support_chunking, + use_prefix_caching=PREFIX_CACHING, + attention_impl=ATTENTION, + block_size=BLOCK_SIZE, ) @property diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 91c99c50..3880a438 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -80,6 +80,7 @@ class Seq2SeqLMBatch(Batch): request_ids=[r.id for r in self.requests], size=len(self), max_tokens=self.max_tokens, + current_tokens=len(self.decoder_input_ids), ) @classmethod diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py index d4e7cca7..ed9ae989 100644 --- a/server/text_generation_server/models/types.py +++ b/server/text_generation_server/models/types.py @@ -74,6 +74,14 @@ class Tokens: def __len__(self): return len(self.token_ids) + def __add__(self, other: "Tokens") -> "Tokens": + return Tokens( + self.token_ids + other.token_ids, + self.logprobs + other.logprobs, + self.texts + other.texts, + self.is_special + other.is_special, + ) + @dataclass class Generation: diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 7f7d2e4d..150cf0d0 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -271,6 +271,8 @@ class VlmCausalLM(FlashCausalLM): model_id=model_id, revision=revision, trust_remote_code=trust_remote_code, + # FIXME: VLM do not work with context chunking yet + support_chunking=False, **kwargs, ) @@ -295,7 +297,7 @@ class VlmCausalLM(FlashCausalLM): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - max_s = batch.max_seqlen + max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices speculative_ids = batch.speculative_ids @@ -314,8 +316,8 @@ class VlmCausalLM(FlashCausalLM): input_lengths = ( input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) - prefix_lens_tensor = ( - batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) + cache_lengths_tensor = ( + batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length) ).reshape(-1) # Add Copy the block tables for all members @@ -337,8 +339,8 @@ class VlmCausalLM(FlashCausalLM): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - prefix_lens_tensor = batch.prefix_lens_tensor - max_s = batch.max_seqlen + cache_lengths_tensor = batch.cache_lengths_tensor + max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices if cu_seqlen_prefill is None and self.max_past() is not None: @@ -347,7 +349,6 
@@ class VlmCausalLM(FlashCausalLM): # This makes sure the max_s for the decode pass is correct. max_s = min(self.max_past(), max_s) - bs = input_ids.shape[0] # Try to find an associated cuda graph bs = input_ids.shape[0] sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) @@ -357,26 +358,24 @@ class VlmCausalLM(FlashCausalLM): else: cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: - input_lengths = input_lengths + prefix_lens_tensor - if PREFIX_CACHING: + if ATTENTION == "flashinfer": block_tables = block_tables_to_ragged( block_tables=block_tables, input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, + cache_lengths=batch.cache_lengths, ) with self._forward_context( block_tables=block_tables, cu_seqlen_prefill=cu_seqlen_prefill, input_lengths_tensor=input_lengths, - prefix_lens_tensor=prefix_lens_tensor, + cache_lengths_tensor=cache_lengths_tensor, ): - max_k = (input_lengths + prefix_lens_tensor).max().item() seqlen = Seqlen( input_lengths=input_lengths, - prefix_lengths=prefix_lens_tensor, + cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, - max_q=max_s, - max_k=max_k, + max_q=batch.max_input_length, + max_k=batch.max_current_length, ) logits, speculative_logits = self.model.forward( input_ids=input_ids, @@ -411,22 +410,34 @@ class VlmCausalLM(FlashCausalLM): block_tables = block_tables_to_ragged( block_tables=block_tables, input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, + cache_lengths=batch.cache_lengths, ) cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables else: cuda_graph["block_tables"][ : block_tables.shape[0], : block_tables.shape[1] ] = block_tables - cuda_graph["slots"].fill_(-1) + + # XXX: This is working only because block 0 is reserved for the healthcheck + # so it doesn't matter if we override it with bogus values. 
+ cuda_graph["slots"].fill_(0) cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( - input_lengths + prefix_lens_tensor - ) + cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + cuda_graph["cache_lengths"].zero_() + cuda_graph["cache_lengths"][ + : cache_lengths_tensor.shape[0] + ] = cache_lengths_tensor - # Replay the graph - cuda_graph["graph"].replay() + with self._forward_context( + block_tables=cuda_graph["block_tables"], + cu_seqlen_prefill=None, + input_lengths_tensor=cuda_graph["input_lengths"], + cache_lengths_tensor=cuda_graph["cache_lengths"], + state=cuda_graph["state"], + ): + # Replay the graph + cuda_graph["graph"].replay() # Slice output to the correct shape speculative_logits = ( diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 46e342a4..aef00fb5 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -15,6 +15,7 @@ from text_generation_server.cache import Cache from text_generation_server.interceptor import ExceptionInterceptor from text_generation_server.models import Model, get_model_with_lora_adapters from text_generation_server.utils.adapter import AdapterInfo +from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens try: from text_generation_server.models.pali_gemma import PaliGemmaBatch @@ -46,9 +47,12 @@ class SignalHandler: signal.signal(signal.SIGINT, self.exit_gracefully) signal.signal(signal.SIGTERM, self.exit_gracefully) + def set_keep_processing(self, value: bool): + self.KEEP_PROCESSING = value + def exit_gracefully(self, signum, frame): print(f"Exiting gracefully: Signal {signum}") - self.KEEP_PROCESSING = False + self.set_keep_processing(False) class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): @@ -96,6 +100,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) async def Warmup(self, request, context): + set_max_prefill_tokens(request.max_prefill_tokens) + if self.quantize in {"exl2", "gptq"}: try: # When using GPTQ, Exllama kernels need some global kernels @@ -150,6 +156,18 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): request.batch, self.model.tokenizer, self.model.dtype, self.model.device ) + concat_ns = None + if self.model.support_chunking: + if request.HasField("cached_batch"): + cached_batch = self.cache.pop(request.cached_batch.id) + if cached_batch is None: + raise ValueError( + f"Batch ID {request.cached_batch.id} not found in cache." 
+ ) + start_concat = time.time_ns() + batch = self.model.batch_type.concatenate([cached_batch, batch]) + concat_ns = time.time_ns() - start_concat + generations, next_batch, timings = self.model.generate_token(batch) self.cache.set(next_batch) @@ -159,6 +177,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): forward_ns=timings[0], decode_ns=timings[1], total_ns=time.time_ns() - start, + concat_ns=concat_ns, ) async def Decode(self, request, context): @@ -252,10 +271,12 @@ def serve( logger.exception("Error when initializing model") raise + signal_handler = SignalHandler() + set_adapter_to_index(adapter_to_index) server = aio.server( interceptors=[ - ExceptionInterceptor(), + ExceptionInterceptor(lambda: signal_handler.set_keep_processing(False)), UDSOpenTelemetryAioServerInterceptor(), ], options=[ @@ -276,7 +297,6 @@ def serve( await server.start() logger.info("Server started at {}".format(local_url)) - signal_handler = SignalHandler() while signal_handler.KEEP_PROCESSING: await asyncio.sleep(0.5) diff --git a/server/text_generation_server/utils/adapter.py b/server/text_generation_server/utils/adapter.py index 2b61f9bb..09254b68 100644 --- a/server/text_generation_server/utils/adapter.py +++ b/server/text_generation_server/utils/adapter.py @@ -120,15 +120,18 @@ def _load_and_merge( if adapter.id == BASE_MODEL_ADAPTER_ID: raise ValueError("Base model adapter cannot be merged.") - module_map, adapter_config, adapter_weight_names, adapter_tokenizer = ( - load_module_map( - model_id, - adapter.revision, - adapter.id, - adapter.path, - weight_names, - trust_remote_code, - ) + ( + module_map, + adapter_config, + adapter_weight_names, + adapter_tokenizer, + ) = load_module_map( + model_id, + adapter.revision, + adapter.id, + adapter.path, + weight_names, + trust_remote_code, ) adapters_to_merge.append((module_map, adapter_config)) diff --git a/server/text_generation_server/utils/prefill_chunking.py b/server/text_generation_server/utils/prefill_chunking.py new file mode 100644 index 00000000..c227d30f --- /dev/null +++ b/server/text_generation_server/utils/prefill_chunking.py @@ -0,0 +1,24 @@ +from typing import Optional + +SUPPORT_CHUNKING: Optional[bool] = None +MAX_PREFILL_TOKENS: Optional[int] = None + + +def set_support_chunking(support_chunking: bool): + global SUPPORT_CHUNKING + SUPPORT_CHUNKING = support_chunking + + +def get_support_chunking() -> bool: + global SUPPORT_CHUNKING + return SUPPORT_CHUNKING + + +def set_max_prefill_tokens(max_prefill_tokens: int): + global MAX_PREFILL_TOKENS + MAX_PREFILL_TOKENS = max_prefill_tokens + + +def get_max_prefill_tokens() -> int: + global MAX_PREFILL_TOKENS + return MAX_PREFILL_TOKENS diff --git a/server/text_generation_server/utils/segments.py b/server/text_generation_server/utils/segments.py index f5961102..b3f92369 100644 --- a/server/text_generation_server/utils/segments.py +++ b/server/text_generation_server/utils/segments.py @@ -7,6 +7,7 @@ from typing import List, Tuple, Union import torch +# FIXME: this should be optimized def find_segments( adapter_indices: Union[torch.Tensor, List[int]] ) -> Tuple[List[int], List[int]]:
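
Reviewer note: the sketch below shows how the new `prefill_chunking` globals are wired together across the patch. It only uses the four helpers added in `server/text_generation_server/utils/prefill_chunking.py`; the `4096` budget is an arbitrary example value, and running it assumes the patched `text_generation_server` package is importable.

```python
# Sketch of the wiring introduced by this patch (example values only).
from text_generation_server.utils.prefill_chunking import (
    get_max_prefill_tokens,
    get_support_chunking,
    set_max_prefill_tokens,
    set_support_chunking,
)

# Warmup (server.py) stores the max_prefill_tokens value received in the warmup
# request before any Prefill call is served.
set_max_prefill_tokens(4096)

# Model.__init__ (model.py) decides support from PREFILL_CHUNKING, the attention
# backend and speculation, then publishes the decision for the batching code.
set_support_chunking(True)

# FlashCausalLM.generate_token later consults both when planning the next chunk.
if get_support_chunking():
    batch_budget = get_max_prefill_tokens()
```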
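The budget split computed in `FlashCausalLM.generate_token` is the heart of the scheduling change. Below is a self-contained sketch of that policy with hypothetical names (`plan_next_prefill_chunks` and the example batch are not part of the patch); it mirrors the reversed walk over the batch, the `max(min(...), 1)` clamp, and the one-token fallback for requests that have finished prefilling.

```python
# Illustrative re-implementation of the chunk planning added to generate_token.
# The helper name and the example batch are hypothetical.
from typing import List, Tuple


def plan_next_prefill_chunks(
    prompt_lengths: List[int],
    cache_lengths: List[int],
    input_lengths: List[int],
    max_prefill_tokens: int,
) -> Tuple[List[int], List[bool]]:
    # Reserve one token per other request so every request can still decode.
    batch_budget = max_prefill_tokens - (len(prompt_lengths) - 1)
    chunk_lengths: List[int] = []
    prefilling_mask: List[bool] = []
    # The patch walks the batch in reverse; requests seen first in this loop
    # claim budget first, later ones fall back to the 1-token minimum.
    for prompt_len, cache_len, input_len in zip(
        reversed(prompt_lengths), reversed(cache_lengths), reversed(input_lengths)
    ):
        remaining = max(prompt_len - cache_len - input_len, 0)
        if remaining > 0:
            # Still prefilling: take as much of the budget as possible, never 0.
            chunk = max(min(remaining, batch_budget), 1)
            batch_budget -= chunk
            prefilling_mask.append(True)
        else:
            # Done prefilling: a single slot is enough for the next decode token.
            chunk = 1
            prefilling_mask.append(False)
        chunk_lengths.append(chunk)
    chunk_lengths.reverse()
    prefilling_mask.reverse()
    return chunk_lengths, prefilling_mask


# Three requests, 512-token prefill budget: only the first one still has prompt
# left (1000 tokens total, 200 processed this step), so it gets the leftover 510.
print(plan_next_prefill_chunks([1000, 300, 50], [0, 0, 0], [200, 300, 50], 512))
# ([510, 1, 1], [True, False, False])
```

Granting at least one token even when the budget is exhausted keeps every request progressing, at the cost of possibly overshooting `max_prefill_tokens` by a few tokens in pathological batches.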
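Because a request's prompt is now prefilled over several `Prefill` calls, its prefill logprobs arrive in pieces; the new `Tokens.__add__` in `models/types.py` is what lets `generate_token` accumulate them in `batch.prefill_logprob_tokens[i]`. A toy illustration with made-up token ids and texts (the dataclass below mirrors the relevant fields of the real `Tokens`):

```python
# Illustrative sketch: per-chunk prefill logprob tokens concatenated across
# successive generate_token calls. Token ids, texts and logprobs are made up.
from dataclasses import dataclass
from typing import List


@dataclass
class Tokens:
    token_ids: List[int]
    logprobs: List[float]
    texts: List[str]
    is_special: List[bool]

    def __add__(self, other: "Tokens") -> "Tokens":
        return Tokens(
            self.token_ids + other.token_ids,
            self.logprobs + other.logprobs,
            self.texts + other.texts,
            self.is_special + other.is_special,
        )


# First prefill chunk: as in the diff, the cached/first prompt tokens get a NaN
# logprob ([float("nan")] * (cache_length + 1)); here cache_length is 0.
chunk_1 = Tokens([1, 15043], [float("nan"), -1.2], ["<s>", "Hello"], [True, False])
# Second chunk of the same request, produced by a later Prefill call.
chunk_2 = Tokens([2787], [-0.7], [" world"], [False])

# generate_token stores the running concatenation in batch.prefill_logprob_tokens[i].
accumulated = chunk_1 + chunk_2
assert accumulated.token_ids == [1, 15043, 2787]
```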