diff --git a/Cargo.toml b/Cargo.toml
index 8abb8ad1..bc2da5a1 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -26,7 +26,12 @@ incremental = true
 inherits = "release"
 debug = 1
 incremental = true
+panic = "abort"
+
+[profile.release-opt]
+inherits = "release"
+debug = 0
+incremental = false
 lto = "fat"
 opt-level = 3
 codegen-units = 1
-panic = "abort"
diff --git a/Dockerfile b/Dockerfile
index 422b1374..659e2673 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -25,7 +25,7 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     rm -f $PROTOC_ZIP

 COPY --from=planner /usr/src/recipe.json recipe.json
-RUN cargo chef cook --release --recipe-path recipe.json
+RUN cargo chef cook --profile release-opt --recipe-path recipe.json

 COPY Cargo.toml Cargo.toml
 COPY rust-toolchain.toml rust-toolchain.toml
@@ -33,7 +33,7 @@ COPY proto proto
 COPY benchmark benchmark
 COPY router router
 COPY launcher launcher
-RUN cargo build --release
+RUN cargo build --profile release-opt

 # Python builder
 # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
@@ -226,11 +226,11 @@ RUN cd server && \
     pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir

 # Install benchmarker
-COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
-COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
+COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
 # Install launcher
-COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
+COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher

 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
     build-essential \
diff --git a/Dockerfile_amd b/Dockerfile_amd
index 92dd0ea8..b0d181ea 100644
--- a/Dockerfile_amd
+++ b/Dockerfile_amd
@@ -25,7 +25,7 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     rm -f $PROTOC_ZIP

 COPY --from=planner /usr/src/recipe.json recipe.json
-RUN cargo chef cook --release --recipe-path recipe.json
+RUN cargo chef cook --profile release-opt --recipe-path recipe.json

 COPY Cargo.toml Cargo.toml
 COPY rust-toolchain.toml rust-toolchain.toml
@@ -33,7 +33,7 @@ COPY proto proto
 COPY benchmark benchmark
 COPY router router
 COPY launcher launcher
-RUN cargo build --release
+RUN cargo build --profile release-opt

 # Text Generation Inference base image for RoCm
 FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update as base
@@ -193,11 +193,11 @@ RUN cd server && \
     pip install ".[accelerate, peft, outlines]" --no-cache-dir

 # Install benchmarker
-COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
-COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
+COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
 # Install launcher
-COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
+COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
# AWS Sagemaker compatible image FROM base as sagemaker diff --git a/Dockerfile_intel b/Dockerfile_intel index 9c9b5c16..0a700003 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -24,7 +24,7 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ rm -f $PROTOC_ZIP COPY --from=planner /usr/src/recipe.json recipe.json -RUN cargo chef cook --release --recipe-path recipe.json +RUN cargo chef cook --profile release-opt --recipe-path recipe.json COPY Cargo.toml Cargo.toml COPY rust-toolchain.toml rust-toolchain.toml @@ -32,7 +32,7 @@ COPY proto proto COPY benchmark benchmark COPY router router COPY launcher launcher -RUN cargo build --release +RUN cargo build --profile release-opt # Text Generation Inference base image for Intel @@ -78,11 +78,11 @@ ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mp ENV CCL_ZE_IPC_EXCHANGE=sockets # Install benchmarker -COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark +COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark # Install router -COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router +COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router # Install launcher -COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher +COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher # Final image FROM base diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index 27b74249..b82d23ba 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -155,6 +155,8 @@ async fn prefill( ignore_eos_token: true, // Will not stop even if a eos token is generated }), top_n_tokens: top_n_tokens.unwrap_or(0), + blocks: vec![], + slots: vec![], }) .collect(); @@ -163,6 +165,7 @@ async fn prefill( requests, size: batch_size, max_tokens: batch_size * (sequence_length + decode_length), + max_blocks: 0, }; // Run prefill diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index ca2908c9..01cc43fd 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -130,6 +130,10 @@ message Request { bool prefill_logprobs = 6; /// Return most likely n tokens uint32 top_n_tokens = 7; + /// Paged attention blocks + repeated uint32 blocks = 9; + /// Paged attention slots + repeated uint32 slots = 10; } message Batch { @@ -141,6 +145,8 @@ message Batch { uint32 size = 3; /// Maximum number of tokens this batch will grow to uint32 max_tokens = 4; + /// Maximum number of Paged Attention blocks + uint32 max_blocks = 5; } message CachedBatch { diff --git a/router/client/src/v3/client.rs b/router/client/src/v3/client.rs index 1f3a89a0..9a3892fb 100644 --- a/router/client/src/v3/client.rs +++ b/router/client/src/v3/client.rs @@ -153,6 +153,9 @@ impl Client { }), // We truncate the input on the server side to be sure that it has the correct size truncate, + // Blocks and slots will be set on the server side if we use paged attention + blocks: vec![], + slots: vec![], // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, @@ -187,7 +190,8 @@ impl Client { id: 0, size: requests.len() as u32, requests, - max_tokens: 0, + max_tokens: max_input_length, + max_blocks: 0, }; let request = 
tonic::Request::new(WarmupRequest {
diff --git a/router/client/src/v3/sharded_client.rs b/router/client/src/v3/sharded_client.rs
index 9b4f74d8..94002f55 100644
--- a/router/client/src/v3/sharded_client.rs
+++ b/router/client/src/v3/sharded_client.rs
@@ -241,12 +241,16 @@ impl Health for ShardedClient {
                 ignore_eos_token: false,
             }),
             top_n_tokens: 0,
+            // Block 0 is reserved for health checks
+            blocks: vec![0],
+            slots: (0..16).collect(),
         };
         let batch = Batch {
             id: u64::MAX,
             requests: vec![liveness_request],
             size: 1,
             max_tokens: 2,
+            max_blocks: 1,
         };
         self.clone().prefill(batch).await?;
         Ok(())
diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs
new file mode 100644
index 00000000..7467fd85
--- /dev/null
+++ b/router/src/infer/v3/block_allocator.rs
@@ -0,0 +1,136 @@
+use std::cmp::min;
+use tokio::sync::{mpsc, oneshot};
+
+#[derive(Debug, Clone)]
+pub(crate) struct BlockAllocation {
+    pub blocks: Vec<u32>,
+    pub slots: Vec<u32>,
+    block_allocator: BlockAllocator,
+}
+
+impl Drop for BlockAllocation {
+    fn drop(&mut self) {
+        self.block_allocator.free(self.blocks.clone())
+    }
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct BlockAllocator {
+    /// Channel to communicate with the background task
+    block_allocator: mpsc::UnboundedSender<BlockAllocatorCommand>,
+}
+
+impl BlockAllocator {
+    pub(crate) fn new(
+        max_batch_total_tokens: u32,
+        block_size: u32,
+        window_size: Option<u32>,
+    ) -> Self {
+        // Create channel
+        let (sender, receiver) = mpsc::unbounded_channel();
+
+        // Launch background queue task
+        tokio::spawn(block_allocator_task(
+            max_batch_total_tokens / block_size,
+            block_size,
+            window_size,
+            receiver,
+        ));
+
+        Self {
+            block_allocator: sender,
+        }
+    }
+
+    pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
+        let (response_sender, response_receiver) = oneshot::channel();
+        self.block_allocator
+            .send(BlockAllocatorCommand::Allocate {
+                tokens,
+                response_sender,
+            })
+            .unwrap();
+
+        response_receiver
+            .await
+            .unwrap()
+            .map(|(blocks, slots)| BlockAllocation {
+                blocks,
+                slots,
+                block_allocator: self.clone(),
+            })
+    }
+
+    pub(crate) fn free(&self, blocks: Vec<u32>) {
+        self.block_allocator
+            .send(BlockAllocatorCommand::Free { blocks })
+            .unwrap();
+    }
+}
+
+async fn block_allocator_task(
+    blocks: u32,
+    block_size: u32,
+    window_size: Option<u32>,
+    mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
+) {
+    // Block 0 is reserved for health checks
+    let mut free_blocks: Vec<u32> = (1..blocks).collect();
+    while let Some(cmd) = receiver.recv().await {
+        match cmd {
+            BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
+            BlockAllocatorCommand::Allocate {
+                tokens,
+                response_sender,
+            } => {
+                // Apply window size
+                let (required_blocks, repeats) = {
+                    let (tokens, repeats) = match window_size {
+                        None => (tokens, 1),
+                        Some(window_size) => {
+                            let repeats = (tokens + window_size - 1) / window_size;
+                            let tokens = min(tokens, window_size);
+                            (tokens, repeats as usize)
+                        }
+                    };
+                    // Pad to a multiple of block size
+                    let required_blocks = (tokens + block_size - 1) / block_size;
+                    (required_blocks, repeats)
+                };
+
+                let tokens = tokens as usize;
+                let allocation = if required_blocks > free_blocks.len() as u32 {
+                    None
+                } else {
+                    let blocks =
+                        free_blocks.split_off(free_blocks.len() - required_blocks as usize);
+                    let mut slots = Vec::with_capacity(
+                        (required_blocks * block_size * repeats as u32) as usize,
+                    );
+
+                    'slots: for block_id in blocks.repeat(repeats).iter() {
+                        for s in (block_id * block_size)..((block_id + 1) * block_size) {
+                            slots.push(s);
+                            if slots.len() == tokens {
+                                break 'slots;
+                            }
+                        }
+                    }
+                    Some((blocks, slots))
+                };
+                response_sender.send(allocation).unwrap();
+            }
+        }
+    }
+}
+
+#[derive(Debug)]
+enum BlockAllocatorCommand {
+    Free {
+        blocks: Vec<u32>,
+    },
+    Allocate {
+        tokens: u32,
+        response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
+    },
+}
diff --git a/router/src/infer/v3/mod.rs b/router/src/infer/v3/mod.rs
index 4299baf3..f9effab8 100644
--- a/router/src/infer/v3/mod.rs
+++ b/router/src/infer/v3/mod.rs
@@ -1,3 +1,4 @@
+mod block_allocator;
 mod queue;
 mod scheduler;
diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs
index b926f329..0b66142a 100644
--- a/router/src/infer/v3/queue.rs
+++ b/router/src/infer/v3/queue.rs
@@ -1,17 +1,20 @@
-use crate::infer::{InferError, InferStreamResponse};
+use crate::infer::v3::block_allocator::{BlockAllocation, BlockAllocator};
+use crate::infer::InferError;
+use crate::infer::InferStreamResponse;
 use crate::validation::{
     ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
 };
 use nohash_hasher::{BuildNoHashHasher, IntMap};
-use std::cmp::min;
+use std::cmp::{max, min};
 use std::collections::VecDeque;
 use text_generation_client::v3::{
     Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
 };
-use text_generation_client::{ChunksToString, Input};
+use text_generation_client::ChunksToString;
+use text_generation_client::Input;
 use tokio::sync::{mpsc, oneshot};
 use tokio::time::Instant;
-use tracing::{info_span, instrument, Span};
+use tracing::{info_span, instrument, Instrument, Span};

 /// Queue entry
 #[derive(Debug)]
@@ -28,6 +31,8 @@ pub(crate) struct Entry {
     pub queue_time: Instant,
     /// Instant when this entry was added to a batch
     pub batch_time: Option<Instant>,
+    /// Block Allocation
+    pub block_allocation: Option<BlockAllocation>,
 }

 /// Request Queue
@@ -43,6 +48,7 @@ impl Queue {
         block_size: u32,
         window_size: Option<u32>,
         speculate: u32,
+        max_batch_total_tokens: u32,
     ) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -53,12 +59,14 @@ impl Queue {
             block_size,
             window_size,
             speculate,
+            max_batch_total_tokens,
             queue_receiver,
         ));

         Self { queue_sender }
     }

+    /// Append an entry to the queue
     #[instrument(skip_all)]
     pub(crate) fn append(&self, entry: Entry) {
         // Send append command to the background task managing the state
@@ -103,9 +111,16 @@ async fn queue_task(
     block_size: u32,
     window_size: Option<u32>,
     speculate: u32,
+    max_batch_total_tokens: u32,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
 ) {
-    let mut state = State::new(requires_padding, block_size, window_size, speculate);
+    let mut state = State::new(
+        requires_padding,
+        block_size,
+        window_size,
+        speculate,
+        max_batch_total_tokens,
+    );

     while let Some(cmd) = receiver.recv().await {
         match cmd {
@@ -120,12 +135,14 @@ async fn queue_task(
                 token_budget,
                 response_sender,
                 span,
-            } => span.in_scope(|| {
-                let next_batch =
-                    state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
+            } => {
+                let next_batch = state
+                    .next_batch(min_size, max_size, prefill_token_budget, token_budget)
+                    .instrument(span)
+                    .await;
                 response_sender.send(next_batch).unwrap();
                 metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
-            }),
+            }
         }
     }
 }
@@ -142,9 +159,6 @@ struct State {
     /// Id of the next batch
     next_batch_id: u64,

-    /// Whether the model is using padding
-    requires_padding: bool,
-
     /// Paged Attention block size
     block_size: u32,

@@ -153,6 +167,9 @@ struct State {
     /// Speculation amount
     speculate: u32,
+
+    /// Paged Attention Block Allocation
+
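For reference, the window-size handling in `block_allocator_task` above reduces to a small piece of arithmetic: cap the token count at the window, round the capped count up to whole blocks, and record how many times those blocks will be cycled through. A minimal Python sketch of that calculation (illustrative only, not part of the patch; the function name is made up):

```python
from math import ceil
from typing import Optional, Tuple

def required_blocks_and_repeats(
    tokens: int, block_size: int, window_size: Optional[int] = None
) -> Tuple[int, int]:
    """Blocks needed for one request, and how often they are reused for a sliding window."""
    if window_size is None:
        capped_tokens, repeats = tokens, 1
    else:
        # Only `window_size` positions are kept, so the same blocks are cycled through.
        repeats = ceil(tokens / window_size)
        capped_tokens = min(tokens, window_size)
    # Pad to a multiple of the block size.
    required_blocks = ceil(capped_tokens / block_size)
    return required_blocks, repeats

# 4096 tokens, block size 16, sliding window 1024 -> 64 blocks reused 4 times.
print(required_blocks_and_repeats(4096, 16, window_size=1024))  # (64, 4)
```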
block_allocator: Option, } impl State { @@ -161,15 +178,19 @@ impl State { block_size: u32, window_size: Option, speculate: u32, + max_batch_total_tokens: u32, ) -> Self { + let block_allocator = (!requires_padding) + .then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size)); + Self { entries: VecDeque::with_capacity(128), next_id: 0, next_batch_id: 0, - requires_padding, block_size, window_size, speculate, + block_allocator, } } @@ -185,7 +206,7 @@ impl State { } // Get the next batch - fn next_batch( + async fn next_batch( &mut self, min_size: Option, max_size: Option, @@ -220,9 +241,10 @@ impl State { let mut max_input_length = 0; let mut prefill_tokens: u32 = 0; let mut decode_tokens: u32 = 0; + let mut max_blocks = 0; // Pop entries starting from the front of the queue - while let Some((id, mut entry)) = self.entries.pop_front() { + 'entry_loop: while let Some((id, mut entry)) = self.entries.pop_front() { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) if entry.response_tx.is_closed() { @@ -231,43 +253,67 @@ impl State { continue; } - if self.requires_padding { - // We pad to max input length in the Python shards - // We need to take these padding tokens into the equation - max_input_length = max_input_length.max(entry.request.input_length); - prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length - } else { - // pad to block size - prefill_tokens += ((entry.request.input_length + self.block_size - 1) - / self.block_size) - * self.block_size; - } + let block_allocation = match &self.block_allocator { + None => { + // We pad to max input length in the Python shards + // We need to take these padding tokens into the equation + max_input_length = max_input_length.max(entry.request.input_length); + prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length; - if self.requires_padding { - decode_tokens += entry.request.stopping_parameters.max_new_tokens; - } else { - let max_new_tokens = match self.window_size { - None => entry.request.stopping_parameters.max_new_tokens, - Some(window_size) => min( - window_size.saturating_sub(entry.request.input_length), - entry.request.stopping_parameters.max_new_tokens, - ), - }; + decode_tokens += entry.request.stopping_parameters.max_new_tokens; + let total_tokens = prefill_tokens + decode_tokens + self.speculate; - // pad to block size - decode_tokens += - ((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size; - } + if prefill_tokens > prefill_token_budget || total_tokens > token_budget { + // Entry is over budget + // Add it back to the front + tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate); + self.entries.push_front((id, entry)); + break 'entry_loop; + } + None + } + Some(block_allocator) => { + prefill_tokens += entry.request.input_length; + let max_new_tokens = match self.window_size { + None => entry.request.stopping_parameters.max_new_tokens, + Some(window_size) => min( + window_size.saturating_sub(entry.request.input_length), + entry.request.stopping_parameters.max_new_tokens, + ), + }; + decode_tokens += max_new_tokens; - if prefill_tokens > prefill_token_budget - || (prefill_tokens + decode_tokens + self.speculate) > token_budget - { - // Entry is over budget - // Add it back to the front - tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || 
{prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate); - self.entries.push_front((id, entry)); - break; - } + if prefill_tokens > prefill_token_budget + || (prefill_tokens + decode_tokens + self.speculate) > token_budget + { + // Entry is over budget + // Add it back to the front + tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate); + self.entries.push_front((id, entry)); + break; + } + + let tokens = entry.request.input_length + + entry.request.stopping_parameters.max_new_tokens + + self.speculate + - 1; + + match block_allocator.allocate(tokens).await { + None => { + // Entry is over budget + // Add it back to the front + tracing::debug!("Over budget: not enough free blocks"); + self.entries.push_front((id, entry)); + break 'entry_loop; + } + Some(block_allocation) => { + tracing::debug!("Allocation: {block_allocation:?}"); + max_blocks = max(max_blocks, block_allocation.blocks.len() as u32); + Some(block_allocation) + } + } + } + }; tracing::debug!("Accepting entry"); // Create a new span to link the batch back to this entry @@ -278,13 +324,23 @@ impl State { // Update entry entry.temp_span = Some(entry_batch_span); + let (blocks, slots) = match &block_allocation { + None => (Vec::new(), Vec::new()), + Some(block_allocation) => ( + block_allocation.blocks.clone(), + block_allocation.slots.clone(), + ), + }; + + entry.block_allocation = block_allocation; + batch_requests.push(Request { id, prefill_logprobs: entry.request.decoder_input_details, - inputs: entry.request.inputs.chunks_to_string(), input_chunks: Some(Input { chunks: entry.request.inputs.clone(), }), + inputs: entry.request.inputs.chunks_to_string(), truncate: entry.request.truncate, parameters: Some(NextTokenChooserParameters::from( entry.request.parameters.clone(), @@ -293,6 +349,8 @@ impl State { entry.request.stopping_parameters.clone(), )), top_n_tokens: entry.request.top_n_tokens, + blocks, + slots, }); // Set batch_time entry.batch_time = Some(Instant::now()); @@ -335,6 +393,7 @@ impl State { requests: batch_requests, size, max_tokens: (prefill_tokens + decode_tokens), + max_blocks, }; // Increment batch id self.next_batch_id += 1; @@ -438,13 +497,14 @@ mod tests { temp_span: None, queue_time: Instant::now(), batch_time: None, + block_allocation: None, }; (entry, receiver_tx) } - #[test] - fn test_append() { - let mut state = State::new(false, 1, None, 0); + #[tokio::test] + async fn test_append() { + let mut state = State::new(false, 1, None, 0, 16); let (entry, _guard) = default_entry(); assert_eq!(state.next_id, 0); @@ -458,23 +518,23 @@ mod tests { assert_eq!(id, 0); } - #[test] - fn test_next_batch_empty() { - let mut state = State::new(false, 1, None, 0); + #[tokio::test] + async fn test_next_batch_empty() { + let mut state = State::new(false, 1, None, 0, 16); - assert!(state.next_batch(None, None, 1, 1).is_none()); - assert!(state.next_batch(Some(1), None, 1, 1).is_none()); + assert!(state.next_batch(None, None, 1, 1).await.is_none()); + assert!(state.next_batch(Some(1), None, 1, 1).await.is_none()); } - #[test] - fn test_next_batch_min_size() { - let mut state = State::new(false, 1, None, 0); + #[tokio::test] + async fn test_next_batch_min_size() { + let mut state = State::new(false, 1, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, None, 
2, 2).unwrap(); + let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&0)); assert!(entries.contains_key(&1)); @@ -490,7 +550,7 @@ mod tests { let (entry3, _guard3) = default_entry(); state.append(entry3); - assert!(state.next_batch(Some(2), None, 2, 2).is_none()); + assert!(state.next_batch(Some(2), None, 2, 2).await.is_none()); assert_eq!(state.next_id, 3); assert_eq!(state.entries.len(), 1); @@ -498,15 +558,15 @@ mod tests { assert_eq!(id, 2); } - #[test] - fn test_next_batch_max_size() { - let mut state = State::new(false, 1, None, 0); + #[tokio::test] + async fn test_next_batch_max_size() { + let mut state = State::new(false, 1, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap(); + let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap(); assert_eq!(entries.len(), 1); assert!(entries.contains_key(&0)); assert!(entries.get(&0).unwrap().batch_time.is_some()); @@ -518,15 +578,15 @@ mod tests { assert_eq!(state.next_batch_id, 1); } - #[test] - fn test_next_batch_token_budget() { - let mut state = State::new(false, 1, None, 0); + #[tokio::test] + async fn test_next_batch_token_budget() { + let mut state = State::new(false, 1, None, 0, 2); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap(); + let (entries, batch, _) = state.next_batch(None, None, 1, 1).await.unwrap(); assert_eq!(entries.len(), 1); assert!(entries.contains_key(&0)); assert_eq!(batch.id, 0); @@ -539,7 +599,7 @@ mod tests { let (entry3, _guard3) = default_entry(); state.append(entry3); - let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap(); + let (entries, batch, _) = state.next_batch(None, None, 3, 3).await.unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&1)); assert!(entries.contains_key(&2)); @@ -553,14 +613,14 @@ mod tests { #[tokio::test] async fn test_queue_append() { - let queue = Queue::new(false, 1, None, 0); + let queue = Queue::new(false, 1, None, 0, 16); let (entry, _guard) = default_entry(); queue.append(entry); } #[tokio::test] async fn test_queue_next_batch_empty() { - let queue = Queue::new(false, 1, None, 0); + let queue = Queue::new(false, 1, None, 0, 16); assert!(queue.next_batch(None, None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none()); @@ -568,7 +628,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_min_size() { - let queue = Queue::new(false, 1, None, 0); + let queue = Queue::new(false, 1, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -601,7 +661,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_max_size() { - let queue = Queue::new(false, 1, None, 0); + let queue = Queue::new(false, 1, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -617,7 +677,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_budget() { - let queue = Queue::new(false, 1, None, 0); + let queue = Queue::new(false, 1, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -642,7 +702,7 @@ mod tests 
{ #[tokio::test] async fn test_queue_next_batch_token_speculate() { - let queue = Queue::new(false, 1, None, 2); + let queue = Queue::new(false, 1, None, 2, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -661,7 +721,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_dropped_receiver() { - let queue = Queue::new(false, 1, None, 0); + let queue = Queue::new(false, 1, None, 0, 16); let (entry, _) = default_entry(); queue.append(entry); diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index 257d191f..ad03dd83 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -39,7 +39,13 @@ impl SchedulerV3 { speculate: u32, generation_health: Arc, ) -> Self { - let queue = Queue::new(requires_padding, 16, window_size, speculate); + let queue = Queue::new( + requires_padding, + 16, + window_size, + speculate, + max_batch_total_tokens, + ); let batching_task_notifier = Arc::new(Notify::new()); // Spawn batching background task that contains all the inference logic @@ -81,6 +87,7 @@ impl Scheduler for SchedulerV3 { temp_span: None, queue_time: Instant::now(), batch_time: None, + block_allocation: None, }); // Notify the background task that we have a new entry in the queue that needs diff --git a/server/text_generation_server/models/cache_manager.py b/server/text_generation_server/models/cache_manager.py deleted file mode 100644 index c7705fe8..00000000 --- a/server/text_generation_server/models/cache_manager.py +++ /dev/null @@ -1,140 +0,0 @@ -import math -import torch - -from typing import Optional, List, Tuple -from text_generation_server.utils.import_utils import SYSTEM - -BLOCK_SIZE: int = 16 -# Will be set in warmup -CACHE_MANAGER: Optional["CacheManager"] = None - - -class CacheManager: - def __init__( - self, - num_blocks: int, - num_layers: int, - num_heads: int, - head_size: int, - repeat_slots: bool, - dtype: torch.dtype, - device: torch.device, - ): - self.block_size = BLOCK_SIZE - self.num_blocks = num_blocks - self.repeat_slots = repeat_slots - - element_size = torch.tensor([], dtype=dtype).element_size() - if SYSTEM == "xpu": - x = 1 - else: - x = self.block_size // element_size - - self.kv_cache = [ - ( - torch.empty( - (num_blocks, num_heads, head_size // x, self.block_size, x), - dtype=dtype, - device=device, - ), - torch.empty( - (num_blocks, num_heads, head_size, self.block_size), - dtype=dtype, - device=device, - ), - ) - for _ in range(num_layers) - ] - self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu") - self.slots = torch.arange( - 0, num_blocks * self.block_size, dtype=torch.int64 - ).view(num_blocks, self.block_size) - - def allocate( - self, - needed_blocks_slots: List[Tuple[int, int]], - blocks: int, - max_blocks: int, - device: torch.device, - ): - # Get free blocks indices by finding values in mask that are not set to 0 - free_block_indices = self.free_block_mask.nonzero() - if blocks > len(free_block_indices): - raise RuntimeError( - f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks" - ) - - # Slice by the number of required blocks - block_indices = free_block_indices[:blocks] - block_indices = block_indices.flatten() - - # Padded block tables - block_tables_tensor = torch.zeros( - (len(needed_blocks_slots), max_blocks), dtype=torch.int32 - ) - - # Allocate paged attention blocks - cumulative_blocks = 0 - slots = [] - block_tables = [] - for i, (needed_blocks, 
needed_slots) in enumerate(needed_blocks_slots): - # Get allocated blocks for this sequence - allocated_blocks = block_indices[ - cumulative_blocks : cumulative_blocks + needed_blocks - ] - # Get slots for the allocated blocks - all_slots = self.slots[allocated_blocks].flatten() - - # Repeat slots in the case of context sliding window - if needed_slots > len(all_slots) and self.repeat_slots: - repeats = math.ceil(needed_slots / len(all_slots)) - all_slots = all_slots.repeat(repeats) - - allocated_slots = all_slots[:needed_slots] - - slots.append(allocated_slots) - block_tables.append(allocated_blocks.tolist()) - block_tables_tensor[i, :needed_blocks] = allocated_blocks - cumulative_blocks += needed_blocks - - block_tables = block_tables - block_tables_tensor = block_tables_tensor.to(device) - slots = torch.concat(slots).to(device) - - # Allocate the required number of blocks by setting the mask to 0 - self.free_block_mask[block_indices] = 0 - - return block_tables, block_tables_tensor, slots - - def free(self, block_indices: Optional[List[int]]): - if block_indices is not None and block_indices: - # Reset mask - self.free_block_mask[block_indices] = 1 - - -def set_cache_manager( - num_blocks: int, - num_layers: int, - num_heads: int, - head_size: int, - repeat_slots: bool, - dtype: torch.dtype, - device: torch.device, -) -> CacheManager: - global CACHE_MANAGER - if CACHE_MANAGER is not None: - del CACHE_MANAGER - torch.cuda.empty_cache() - - CACHE_MANAGER = CacheManager( - num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device - ) - return CACHE_MANAGER - - -def get_cache_manager() -> CacheManager: - global CACHE_MANAGER - if CACHE_MANAGER is None: - raise RuntimeError("cache manager was not initialized") - - return CACHE_MANAGER diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 31109bc9..764dc6e2 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -512,6 +512,7 @@ class FlashCohereForCausalLM(torch.nn.Module): slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, + prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.model( diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 7967e420..9c32490e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -834,6 +834,7 @@ class FlashDbrxForCausalLM(torch.nn.Module): slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, + prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.model( diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 89ca8b5b..339198a7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -458,6 +458,7 @@ class FlashGemmaForCausalLM(torch.nn.Module): 
slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, + prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: input_embeds = self.embed_tokens(input_ids) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 59e7bf8b..d399be2f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -388,6 +388,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, + prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.gpt_neox( diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 53d3ea42..0a47b1cc 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -398,6 +398,7 @@ class FlashPhiForCausalLM(torch.nn.Module): slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, + prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model( diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index d489c3ba..7d3c72a7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -670,6 +670,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, + prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.transformer( diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 1f47550e..74eedc51 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -482,6 +482,7 @@ class FlashSantacoderForCausalLM(nn.Module): slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, + prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.transformer( diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 86d9b4c8..d8c8838c 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -25,11 +25,6 @@ from text_generation_server.models.types import ( Generation, GeneratedText, ) -from text_generation_server.models.cache_manager import ( - get_cache_manager, - set_cache_manager, - BLOCK_SIZE, -) from text_generation_server.pb import generate_pb2 from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS import text_generation_server.models.globals as tgi_globals @@ -44,6 +39,21 @@ from 
text_generation_server.utils.import_utils import ( tracer = trace.get_tracer(__name__) +BLOCK_SIZE: int = 16 + +# Will be set in init +SLIDING_WINDOW: Optional[int] = None + + +def set_sliding_window(sliding_window: int): + global SLIDING_WINDOW + SLIDING_WINDOW = sliding_window + + +def get_sliding_windows() -> int: + global SLIDING_WINDOW + return SLIDING_WINDOW + @dataclass class FlashCausalLMBatch(Batch): @@ -55,12 +65,15 @@ class FlashCausalLMBatch(Batch): # Decoder values input_ids: torch.Tensor position_ids: torch.Tensor - speculative_ids: torch.Tensor + speculative_ids: Optional[torch.Tensor] # Flash Attention values # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill cu_seqlen_prefill: Optional[torch.Tensor] + # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers + # as we only keep SLIDING_WINDOW values instead of the whole tensor + prefill_cache_indices: Optional[torch.Tensor] # Paged Attention values @@ -69,16 +82,13 @@ class FlashCausalLMBatch(Batch): start_slots: torch.Tensor # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode slot_indices: torch.Tensor - # List of tuple of ints representing the number of blocks and slots needed by each sequence - needed_blocks_slots: Optional[List[Tuple[int, int]]] - # Set in prefill by the CacheManager # list of length b of list of length s_i // block_size - block_tables: Optional[List[List[int]]] + block_tables: List[List[int]] # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences - block_tables_tensor: Optional[torch.Tensor] + block_tables_tensor: torch.Tensor # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences - slots: Optional[torch.Tensor] + slots: torch.Tensor max_seqlen: int @@ -104,7 +114,7 @@ class FlashCausalLMBatch(Batch): top_n_tokens_tensor: torch.Tensor # Number of blocks in this batch - blocks: int + num_blocks: int # Maximum number of blocks max_blocks: int @@ -113,7 +123,7 @@ class FlashCausalLMBatch(Batch): id=self.batch_id, request_ids=[r.id for r in self.requests], size=len(self), - max_tokens=self.blocks * BLOCK_SIZE, + max_tokens=self.num_blocks * BLOCK_SIZE, ) @classmethod @@ -129,17 +139,6 @@ class FlashCausalLMBatch(Batch): )["input_ids"] return batch_tokenized_inputs - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - ) -> "FlashCausalLMBatch": - batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer) - return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) - @classmethod def from_tokenized( cls, @@ -149,12 +148,12 @@ class FlashCausalLMBatch(Batch): dtype: torch.dtype, device: torch.device, ) -> "FlashCausalLMBatch": + sliding_window = get_sliding_windows() position_ids = [] - speculative_ids = [] cu_seqlen_prefill = [0] - needed_blocks_slots = [] start_slots = [] slot_indices = [] + prefill_cache_indices = [] input_lengths = [] prefix_offsets = [] @@ -177,11 +176,14 @@ class FlashCausalLMBatch(Batch): cumulative_max_length = 0 prefill_out_cumulative_length = 0 - blocks = 0 + num_blocks = 0 max_seqlen = 0 max_length = 0 max_blocks = 0 + block_tables = [] + slots = [] + # Parse batch for i, (r, tokenized_input) in enumerate( zip(pb.requests, batch_tokenized_inputs) @@ -225,9 +227,25 @@ class 
FlashCausalLMBatch(Batch): speculative_length = get_speculate() speculative_length = 0 if speculative_length is None else speculative_length total_tokens = input_length + max_new_tokens - 1 + speculative_length - needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) - blocks += needed_blocks - needed_blocks_slots.append((needed_blocks, total_tokens)) + + # blocks and slots can be empty (for example in warmup) + if not r.blocks: + needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) + request_blocks = [ + b for b in range(num_blocks, num_blocks + needed_blocks) + ] + request_slots = [ + s + for b in request_blocks + for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE) + ] + else: + request_blocks = r.blocks + request_slots = r.slots + + block_tables.append(request_blocks) + slots.extend(request_slots[:total_tokens]) + num_blocks += len(request_blocks) start_slots.append(cumulative_max_length) request_slot_indices = torch.arange( @@ -237,6 +255,15 @@ class FlashCausalLMBatch(Batch): ) slot_indices.append(request_slot_indices) + # Create tensor to slice into the kv tensor in prefill + if sliding_window is not None: + request_prefill_cache_indices = torch.arange( + cumulative_length + max(0, input_length - sliding_window), + cumulative_length + input_length, + dtype=torch.int64, + ) + prefill_cache_indices.append(request_prefill_cache_indices) + all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs @@ -261,7 +288,7 @@ class FlashCausalLMBatch(Batch): cumulative_length += input_length cumulative_max_length += total_tokens max_seqlen = max(max_seqlen, input_length) - max_blocks = max(max_blocks, needed_blocks) + max_blocks = max(max_blocks, len(request_blocks)) max_length = max( max_length, input_length + max_new_tokens + speculative_length ) @@ -287,16 +314,23 @@ class FlashCausalLMBatch(Batch): input_ids = np.concatenate(all_input_ids, dtype=np.int64) position_ids = torch.cat(position_ids) slot_indices = torch.cat(slot_indices) + if sliding_window is not None: + prefill_cache_indices = torch.cat(prefill_cache_indices) else: input_ids = all_input_ids[0] position_ids = position_ids[0] slot_indices = slot_indices[0] + if sliding_window is not None: + prefill_cache_indices = prefill_cache_indices[0] cu_seqlen_prefill = torch.tensor( cu_seqlen_prefill, device=device, dtype=torch.int32 ) position_ids = position_ids.to(device) slot_indices = slot_indices.to(device) + prefill_cache_indices = ( + prefill_cache_indices.to(device) if sliding_window is not None else None + ) input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) input_lengths_tensor = torch.tensor( input_lengths, dtype=torch.int32, device=device @@ -319,6 +353,14 @@ class FlashCausalLMBatch(Batch): top_n_tokens, device=device, dtype=torch.int64 ) + slots = torch.tensor(slots, dtype=torch.int64, device=device) + block_tables_tensor = torch.zeros( + (len(block_tables), max_blocks), dtype=torch.int32, device="cpu" + ) + for i, request_blocks in enumerate(block_tables): + block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) + block_tables_tensor = block_tables_tensor.to(device) + return cls( batch_id=pb.id, requests=pb.requests, @@ -326,12 +368,12 @@ class FlashCausalLMBatch(Batch): input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, + prefill_cache_indices=prefill_cache_indices, start_slots=start_slots, slot_indices=slot_indices, - needed_blocks_slots=needed_blocks_slots, - 
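The fallback above, taken when a request arrives without router-assigned `blocks`/`slots` (as happens during warmup), amounts to fabricating a contiguous block range and enumerating the slots it covers. A small sketch of that arithmetic (illustrative, not taken from the patch):

```python
import math

BLOCK_SIZE = 16

def fallback_blocks_and_slots(first_free_block: int, total_tokens: int):
    """Fabricate contiguous blocks for a request and derive the slots they cover."""
    needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
    request_blocks = list(range(first_free_block, first_free_block + needed_blocks))
    request_slots = [
        s
        for b in request_blocks
        for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
    ]
    # Only the first `total_tokens` slots are used for this request.
    return request_blocks, request_slots[:total_tokens]

blocks, slots = fallback_blocks_and_slots(0, 40)  # blocks [0, 1, 2], 40 slots
```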
block_tables=None, - block_tables_tensor=None, - slots=None, + block_tables=block_tables, + block_tables_tensor=block_tables_tensor, + slots=slots, max_seqlen=max_seqlen, prefill_head_indices=prefill_head_indices, prefill_next_token_indices=prefill_next_token_indices, @@ -346,11 +388,22 @@ class FlashCausalLMBatch(Batch): stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, top_n_tokens_tensor=top_n_tokens_tensor, - blocks=blocks, + num_blocks=num_blocks, max_blocks=max_blocks, speculative_ids=None, ) + @classmethod + def from_pb( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, + device: torch.device, + ) -> "FlashCausalLMBatch": + batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer) + return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) + @tracer.start_as_current_span("filter") def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": if len(request_ids) == 0: @@ -388,7 +441,7 @@ class FlashCausalLMBatch(Batch): stopping_criterias = [] top_n_tokens = [] - blocks = 0 + num_blocks = 0 max_blocks = 0 # Cumulative length cumulative_max_length = 0 @@ -420,7 +473,7 @@ class FlashCausalLMBatch(Batch): ) request_block_table = self.block_tables[idx] - blocks += len(request_block_table) + num_blocks += len(request_block_table) block_tables.append(request_block_table) start_slots.append(cumulative_max_length) @@ -439,17 +492,6 @@ class FlashCausalLMBatch(Batch): max_blocks = max(max_blocks, len(request_block_table)) - block_indices_to_free = [] - # Iterate on all requests - for i, r in enumerate(self.requests): - # Filter requests that are not part of the new batch - if r.id not in requests_idx_mapping.keys(): - block_indices_to_free.extend(self.block_tables[i]) - # Free blocks - get_cache_manager().free(block_indices_to_free) - # Needed to avoid dropping blocks when the batches will go out of scope - self.block_tables = None - # Index into tensors input_ids = self.input_ids[indices] position_ids = self.position_ids[indices] @@ -475,9 +517,9 @@ class FlashCausalLMBatch(Batch): input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=None, + prefill_cache_indices=None, start_slots=start_slots, slot_indices=slot_indices, - needed_blocks_slots=None, block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, @@ -495,7 +537,7 @@ class FlashCausalLMBatch(Batch): stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, top_n_tokens_tensor=top_n_tokens_tensor, - blocks=blocks, + num_blocks=num_blocks, max_blocks=max_blocks, speculative_ids=speculative_ids, ) @@ -507,7 +549,7 @@ class FlashCausalLMBatch(Batch): requests = [] requests_idx_mapping = {} - blocks = 0 + num_blocks = 0 total_batch_size = 0 total_slots = 0 max_blocks = 0 @@ -516,7 +558,7 @@ class FlashCausalLMBatch(Batch): for b in batches: total_batch_size += len(b) total_slots += len(b.slots) - blocks += b.blocks + num_blocks += b.num_blocks speculative_length = ( b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 ) @@ -635,11 +677,6 @@ class FlashCausalLMBatch(Batch): else None ) - # Needed to avoid dropping blocks when the batches will go out of scope - for b in batches: - b.block_tables = None - del b - return cls( batch_id=batches[0].batch_id, requests=requests, @@ -647,9 +684,9 @@ class FlashCausalLMBatch(Batch): input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=None, + prefill_cache_indices=None, start_slots=start_slots, slot_indices=slot_indices, - 
needed_blocks_slots=None, block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, @@ -667,18 +704,11 @@ class FlashCausalLMBatch(Batch): stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, top_n_tokens_tensor=top_n_tokens_tensor, - blocks=blocks, + num_blocks=num_blocks, max_blocks=max_blocks, speculative_ids=speculative_ids, ) - def __del__(self): - if self.block_tables is not None and self.block_tables: - # Free blocks - get_cache_manager().free( - list(itertools.chain.from_iterable(self.block_tables)) - ) - def __len__(self): return len(self.requests) @@ -702,6 +732,7 @@ class FlashCausalLM(Model): self.head_size = head_size self.cuda_graphs = {} + self.kv_cache = [] super(FlashCausalLM, self).__init__( model=model, @@ -718,6 +749,43 @@ class FlashCausalLM(Model): def batch_type(self) -> Type[FlashCausalLMBatch]: return FlashCausalLMBatch + def max_past(self) -> int: + return getattr(self.model, "max_past", None) + + def init_kv_cache( + self, + num_blocks: int, + num_layers: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + ): + self.kv_cache = [] + empty_cache() + + element_size = torch.tensor([], dtype=dtype).element_size() + if SYSTEM == "xpu": + x = 1 + else: + x = BLOCK_SIZE // element_size + + self.kv_cache = [ + ( + torch.empty( + (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, num_heads, head_size, BLOCK_SIZE), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] + def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) @@ -728,12 +796,11 @@ class FlashCausalLM(Model): .repeat(bs) .reshape((bs, max_bt)) ) - kv_cache = get_cache_manager().kv_cache self.cuda_graphs[bs] = { "input_ids": input_ids, "position_ids": position_ids, - "kv_cache": kv_cache, + "kv_cache": self.kv_cache, "block_tables": block_tables, "slots": slots, "input_lengths": input_lengths, @@ -747,11 +814,12 @@ class FlashCausalLM(Model): input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=None, - kv_cache=kv_cache, + kv_cache=self.kv_cache, block_tables=block_tables, slots=slots, input_lengths=input_lengths, max_s=max_s, + prefill_cache_indices=None, lm_head_indices=None, ) torch.cuda.synchronize() @@ -761,11 +829,12 @@ class FlashCausalLM(Model): input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=None, - kv_cache=kv_cache, + kv_cache=self.kv_cache, block_tables=block_tables, slots=slots, input_lengths=input_lengths, max_s=max_s, + prefill_cache_indices=None, lm_head_indices=None, ) self.cuda_graphs[bs]["logits"] = logits @@ -777,17 +846,16 @@ class FlashCausalLM(Model): empty_cache() try: - cache_manager = set_cache_manager( - batch.blocks, + self.init_kv_cache( + batch.num_blocks, self.num_layers, self.num_kv_heads, self.head_size, - self.sliding_window is not None, self.dtype, self.device, ) max_bt = batch.max_blocks - max_s = max_bt * get_cache_manager().block_size + max_s = max_bt * BLOCK_SIZE if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False): torch.cuda.tunable.tuning_enable(False) @@ -811,19 +879,17 @@ class FlashCausalLM(Model): num_blocks = ( # Leave 5% for some wiggle room int((free_memory * 0.95) // total_cache_size) - # Add batch.blocks as we allocated it above, so it is included in the peak memory. 
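As a reading aid for `init_kv_cache` above: each layer holds one key tensor and one value tensor whose shapes depend on the dtype element size through the `x` packing factor used by the attention kernels. A hedged sketch of the resulting shapes (illustrative only, not part of the patch):

```python
import torch

BLOCK_SIZE = 16

def kv_cache_shapes(
    num_blocks: int, num_heads: int, head_size: int, dtype: torch.dtype, system: str = "cuda"
):
    """Per-layer (key, value) tensor shapes for the paged KV cache."""
    element_size = torch.tensor([], dtype=dtype).element_size()
    x = 1 if system == "xpu" else BLOCK_SIZE // element_size
    key_shape = (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x)
    value_shape = (num_blocks, num_heads, head_size, BLOCK_SIZE)
    return key_shape, value_shape

# float16 (2 bytes per element) -> x = 8
print(kv_cache_shapes(1024, 8, 128, torch.float16))
# ((1024, 8, 16, 16, 8), (1024, 8, 128, 16))
```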
- + cache_manager.num_blocks + # Add batch.num_blocks as we allocated it above, so it is included in the peak memory. + + batch.num_blocks ) del batch - del cache_manager - set_cache_manager( + self.init_kv_cache( num_blocks, self.num_layers, self.num_kv_heads, self.head_size, - self.sliding_window is not None, self.dtype, self.device, ) @@ -889,7 +955,6 @@ class FlashCausalLM(Model): input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device) position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device) slots = torch.arange(seqlen, dtype=torch.int64, device=self.device) - kv_cache = get_cache_manager().kv_cache # Dummy value, some models (starcoder2) don't accept `None`. input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) @@ -901,12 +966,13 @@ class FlashCausalLM(Model): cu_seqlen_prefill=torch.tensor( [0, seqlen], device=self.device, dtype=torch.int32 ), - kv_cache=get_cache_manager().kv_cache, + kv_cache=self.kv_cache, block_tables=None, input_lengths=input_lengths, slots=slots, max_s=seqlen, lm_head_indices=None, + prefill_cache_indices=None, ) def forward( @@ -917,7 +983,7 @@ class FlashCausalLM(Model): input_ids = batch.input_ids position_ids = batch.position_ids cu_seqlen_prefill = batch.cu_seqlen_prefill - kv_cache = get_cache_manager().kv_cache + kv_cache = self.kv_cache block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor @@ -956,13 +1022,19 @@ class FlashCausalLM(Model): input_ids = batch.input_ids position_ids = batch.position_ids cu_seqlen_prefill = batch.cu_seqlen_prefill - kv_cache = get_cache_manager().kv_cache + kv_cache = self.kv_cache block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices + if cu_seqlen_prefill is None and self.max_past() is not None: + # In decode, not prefill, we're actually overwriting the KV-cache + # in a circular buffer mode. + # This makes sure the max_s for the decode pass is correct. 
+ max_s = min(self.max_past(), max_s) + bs = input_ids.shape[0] sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) if sorted_padded_bs: @@ -972,7 +1044,7 @@ class FlashCausalLM(Model): cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: - return self.model.forward( + logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, @@ -981,8 +1053,12 @@ class FlashCausalLM(Model): slots=slots, input_lengths=input_lengths, max_s=max_s, + prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=lm_head_indices, ) + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + return logits, speculative_logits # Copy inputs to the static inputs of the cuda graph # Static inputs are potentially padded @@ -1015,24 +1091,7 @@ class FlashCausalLM(Model): prefill = batch.cu_seqlen_prefill is not None prefill_logprobs = batch.prefill_next_token_indices is not None - if batch.needed_blocks_slots: - # Allocate blocks to this batch - block_tables, block_tables_tensor, slots = get_cache_manager().allocate( - batch.needed_blocks_slots, - batch.blocks, - batch.max_blocks, - batch.input_ids.device, - ) - batch.needed_blocks_slots = None - batch.block_tables = block_tables - batch.block_tables_tensor = block_tables_tensor - batch.slots = slots - - try: - out, speculative_logits = self.forward(batch) - except Exception as e: - del batch - raise e + out, speculative_logits = self.forward(batch) if prefill: next_token_logits = ( @@ -1327,7 +1386,6 @@ class FlashCausalLM(Model): batch.all_input_ids[i] = all_input_ids if stopped: - del batch # No need to return a batch if we know that all requests stopped forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index e6125e29..081c2e2c 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -1,308 +1,24 @@ -import math import torch import torch.distributed -import numpy as np - -from dataclasses import dataclass from opentelemetry import trace -from transformers import PreTrainedTokenizerBase, AutoTokenizer, AutoConfig -from typing import Optional, Tuple, Type +from transformers import AutoTokenizer, AutoConfig +from typing import Optional, Tuple -from text_generation_server.pb import generate_pb2 from text_generation_server.models import FlashCausalLM -from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE -from text_generation_server.models.cache_manager import ( - get_cache_manager, -) +from text_generation_server.models.flash_causal_lm import set_sliding_window from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( FlashMistralForCausalLM, MistralConfig, ) -from text_generation_server.utils.speculate import get_speculate from text_generation_server.utils import ( initialize_torch_distributed, weight_files, Weights, - HeterogeneousNextTokenChooser, - StoppingCriteria, ) - -tracer = trace.get_tracer(__name__) - -# Will be set in init -SLIDING_WINDOW: Optional[int] = None -SLIDING_WINDOW_BLOCKS: Optional[int] = None from text_generation_server.utils.import_utils import SYSTEM -MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None - - -def set_sliding_window(sliding_window: int, sliding_window_blocks: int): - global 
SLIDING_WINDOW - global SLIDING_WINDOW_BLOCKS - SLIDING_WINDOW = sliding_window - SLIDING_WINDOW_BLOCKS = sliding_window_blocks - - -def get_sliding_windows() -> Tuple[int, int]: - global SLIDING_WINDOW - global SLIDING_WINDOW_BLOCKS - return SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS - - -# Adds windowing logic to FlashCausalLMBatch -@dataclass -class FlashMistralBatch(FlashCausalLMBatch): - # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers - # as we only keep SLIDING_WINDOW values instead of the whole tensor - prefill_cache_indices: Optional[torch.Tensor] = None - - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - ) -> "FlashCausalLMBatch": - batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer) - return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) - - @classmethod - def from_tokenized( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - batch_tokenized_inputs, - dtype: torch.dtype, - device: torch.device, - ) -> "FlashCausalLMBatch": - sliding_window, sliding_window_blocks = get_sliding_windows() - - position_ids = [] - cu_seqlen_prefill = [0] - needed_blocks_slots = [] - start_slots = [] - slot_indices = [] - prefill_cache_indices = [] - - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - requests_idx_mapping = {} - - all_prefill_logprobs = True - no_prefill_logprobs = True - prefill_head_indices = [] - prefill_next_token_indices = [] - prefill_cu_outlens = [0] - - next_token_chooser_parameters = [] - stopping_criterias = [] - top_n_tokens = [] - - # Cumulative length - cumulative_length = 0 - cumulative_max_length = 0 - prefill_out_cumulative_length = 0 - - blocks = 0 - max_seqlen = 0 - max_length = 0 - max_blocks = 0 - - # Parse batch - for i, (r, tokenized_input) in enumerate( - zip(pb.requests, batch_tokenized_inputs) - ): - # request id -> idx in list mapping - requests_idx_mapping[r.id] = i - - tokenized_input = tokenized_input[-r.truncate :] - if ( - tokenized_input[0] == tokenizer.bos_token_id - and tokenized_input[1] == tokenizer.bos_token_id - ): - tokenized_input = tokenized_input[1:] - - input_length = len(tokenized_input) - input_lengths.append(input_length) - - prefix_offsets.append(input_length - 5) - read_offsets.append(input_length) - - all_input_ids.append(tokenized_input) - - # Position ids - request_position_ids = torch.arange(0, input_length, dtype=torch.int32) - position_ids.append(request_position_ids) - - # Add cumulative lengths of all previous inputs - cu_seqlen_prefill.append(cumulative_length + input_length) - - next_token_chooser_parameters.append(r.parameters) - - stopping_criteria = StoppingCriteria.from_pb( - r.stopping_parameters, tokenizer - ) - max_new_tokens = stopping_criteria.max_new_tokens - stopping_criterias.append(stopping_criteria) - top_n_tokens.append(r.top_n_tokens) - - # Paged attention - # Remove one as the first token des not have a past - speculative_length = get_speculate() - total_tokens = input_length + max_new_tokens - 1 + speculative_length - - # Needed blocks can not go over SLIDING_WINDOW_BLOCKS - needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) - if sliding_window_blocks is not None: - needed_blocks = min(needed_blocks, sliding_window_blocks) - blocks += needed_blocks - - needed_blocks_slots.append((needed_blocks, total_tokens)) - 
start_slots.append(cumulative_max_length) - - request_slot_indices = torch.arange( - cumulative_max_length, - cumulative_max_length + input_length, - dtype=torch.int64, - ) - slot_indices.append(request_slot_indices) - - # Create tensor to slice into the kv tensor in prefill - if sliding_window is not None: - request_prefill_cache_indices = torch.arange( - cumulative_length + max(0, input_length - sliding_window), - cumulative_length + input_length, - dtype=torch.int64, - ) - prefill_cache_indices.append(request_prefill_cache_indices) - - all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs - no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs - - if r.prefill_logprobs: - prefill_head_indices.append(request_position_ids + cumulative_length) - prefill_next_token_indices.append( - prefill_out_cumulative_length + input_length - 1 - ) - prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) - prefill_out_cumulative_length += input_length - else: - prefill_head_indices.append( - torch.tensor( - [cumulative_length + input_length - 1], dtype=torch.int32 - ) - ) - prefill_next_token_indices.append(prefill_out_cumulative_length) - prefill_cu_outlens.append(prefill_out_cumulative_length + 1) - prefill_out_cumulative_length += 1 - - # Update - cumulative_length += input_length - cumulative_max_length += total_tokens - max_seqlen = max(max_seqlen, input_length) - max_blocks = max(max_blocks, needed_blocks) - max_length = max( - max_length, input_length + max_new_tokens + speculative_length - ) - - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - next_token_chooser_parameters, dtype, device, tokenizer - ) - start_slots = torch.tensor(start_slots, dtype=torch.int64) - - # Padded all_input_ids_tensor - all_input_ids_tensor = np.zeros( - (len(all_input_ids), max_length), dtype=np.int64 - ) - for i, input_ids in enumerate(all_input_ids): - all_input_ids_tensor[i, : len(input_ids)] = input_ids - - # Create tensors on device - all_input_ids_tensor = torch.tensor( - all_input_ids_tensor, dtype=torch.int64, device=device - ) - - if len(pb.requests) > 1: - input_ids = np.concatenate(all_input_ids, dtype=np.int64) - position_ids = torch.cat(position_ids) - slot_indices = torch.cat(slot_indices) - if sliding_window is not None: - prefill_cache_indices = torch.cat(prefill_cache_indices) - else: - input_ids = all_input_ids[0] - position_ids = position_ids[0] - slot_indices = slot_indices[0] - if sliding_window is not None: - prefill_cache_indices = prefill_cache_indices[0] - - cu_seqlen_prefill = torch.tensor( - cu_seqlen_prefill, device=device, dtype=torch.int32 - ) - - position_ids = position_ids.to(device) - slot_indices = slot_indices.to(device) - prefill_cache_indices = ( - prefill_cache_indices.to(device) if sliding_window is not None else None - ) - input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) - input_lengths_tensor = torch.tensor( - input_lengths, dtype=torch.int32, device=device - ) - - if all_prefill_logprobs: - prefill_head_indices = None - prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 - elif no_prefill_logprobs: - prefill_head_indices = cu_seqlen_prefill[1:] - 1 - prefill_next_token_indices = None - else: - prefill_head_indices = torch.tensor( - torch.cat(prefill_head_indices), dtype=torch.int64, device=device - ) - prefill_next_token_indices = torch.tensor( - prefill_next_token_indices, dtype=torch.int64, device=device - ) - top_n_tokens_tensor = torch.tensor( - top_n_tokens, device=device, dtype=torch.int64 
- ) - - return cls( - batch_id=pb.id, - requests=pb.requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - start_slots=start_slots, - slot_indices=slot_indices, - needed_blocks_slots=needed_blocks_slots, - block_tables=None, - block_tables_tensor=None, - slots=None, - max_seqlen=max_seqlen, - prefill_head_indices=prefill_head_indices, - prefill_next_token_indices=prefill_next_token_indices, - prefill_cu_outlens=prefill_cu_outlens, - input_lengths=input_lengths, - input_lengths_tensor=input_lengths_tensor, - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - all_input_ids=all_input_ids, - all_input_ids_tensor=all_input_ids_tensor, - next_token_chooser=next_token_chooser, - stopping_criterias=stopping_criterias, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - blocks=blocks, - max_blocks=max_blocks, - prefill_cache_indices=prefill_cache_indices, - speculative_ids=None, - ) +tracer = trace.get_tracer(__name__) class BaseFlashMistral(FlashCausalLM): @@ -344,9 +60,7 @@ class BaseFlashMistral(FlashCausalLM): # Set context windows if getattr(config, "sliding_window", None) is not None: - set_sliding_window( - config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE) - ) + set_sliding_window(config.sliding_window) else: config.sliding_window = None @@ -384,207 +98,6 @@ class BaseFlashMistral(FlashCausalLM): model.model.head_size, ) - def max_past(self) -> int: - return self.model.max_past - - @property - def batch_type(self) -> Type[FlashMistralBatch]: - return FlashMistralBatch - - def tunableop_warmup(self, seqlen: int): - input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device) - position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device) - slots = torch.arange(seqlen, dtype=torch.int64, device=self.device) - kv_cache = get_cache_manager().kv_cache - - # Dummy value, some models (starcoder2) don't accept `None`. - input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) - - # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. 
- self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=torch.tensor( - [0, seqlen], device=self.device, dtype=torch.int32 - ), - kv_cache=get_cache_manager().kv_cache, - block_tables=None, - input_lengths=input_lengths, - slots=slots, - max_s=seqlen, - lm_head_indices=None, - prefill_cache_indices=None, - ) - - def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): - input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) - position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) - slots = torch.arange(bs, dtype=torch.int64, device=self.device) - input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s - block_tables = ( - torch.arange(max_bt, dtype=torch.int32, device=self.device) - .repeat(bs) - .reshape((bs, max_bt)) - ) - kv_cache = get_cache_manager().kv_cache - - self.cuda_graphs[bs] = { - "input_ids": input_ids, - "position_ids": position_ids, - "kv_cache": kv_cache, - "block_tables": block_tables, - "slots": slots, - "input_lengths": input_lengths, - } - graph = torch.cuda.CUDAGraph() - self.cuda_graphs[bs]["graph"] = graph - - torch.cuda.synchronize() - # Run once outside to warmup - self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=None, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - input_lengths=input_lengths, - max_s=max_s, - prefill_cache_indices=None, - lm_head_indices=None, - ) - torch.cuda.synchronize() - - with torch.cuda.graph(graph, pool=MEM_POOL): - logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=None, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - input_lengths=input_lengths, - max_s=max_s, - prefill_cache_indices=None, - lm_head_indices=None, - ) - self.cuda_graphs[bs]["logits"] = logits - self.cuda_graphs[bs]["speculative_logits"] = speculative_logits - torch.cuda.synchronize() - - def forward( - self, batch: FlashMistralBatch - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - # Model Forward - if batch.speculative_ids is not None: - input_ids = batch.input_ids - position_ids = batch.position_ids - cu_seqlen_prefill = batch.cu_seqlen_prefill - kv_cache = get_cache_manager().kv_cache - block_tables = batch.block_tables_tensor - slots = batch.slots[batch.slot_indices] - input_lengths = batch.input_lengths_tensor - max_s = batch.max_seqlen - lm_head_indices = batch.prefill_head_indices - - speculative_ids = batch.speculative_ids - - B, speculative_length = speculative_ids.shape - new_length = speculative_length + 1 - new_input_ids = torch.cat( - [input_ids.unsqueeze(-1), speculative_ids], dim=1 - ).reshape(-1) - arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0) - arange_int = arange.to(dtype=torch.int32) - new_position_ids = ( - position_ids.unsqueeze(-1).expand(B, new_length) + arange - ).view(-1) - slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) - input_lengths = ( - input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int - ).view(-1) - - # Add Copy the block tables for all members - block_tables = ( - block_tables.unsqueeze(1) - .expand(B, new_length, -1) - .reshape(B * new_length, -1) - .contiguous() - ) - max_s = max_s + speculative_length - - input_ids = new_input_ids - position_ids = new_position_ids - else: - input_ids = batch.input_ids - position_ids = batch.position_ids - cu_seqlen_prefill = batch.cu_seqlen_prefill - kv_cache = get_cache_manager().kv_cache 
- block_tables = batch.block_tables_tensor - slots = batch.slots[batch.slot_indices] - input_lengths = batch.input_lengths_tensor - max_s = batch.max_seqlen - lm_head_indices = batch.prefill_head_indices - - if cu_seqlen_prefill is None and self.max_past() is not None: - # In decode, not prefill, we're actually overwriting the KV-cache - # in a circular buffer mode. - # This makes sure the max_s for the decode pass is correct. - max_s = min(self.max_past(), max_s) - - bs = input_ids.shape[0] - padded_bs = bs - if bs == 3: - padded_bs = 4 - elif 3 < bs <= 8: - padded_bs = 8 - elif bs > 8: - padded_bs = (bs + 7) // 8 * 8 - - # Try to find an associated cuda graph - cuda_graph = self.cuda_graphs.get(padded_bs, None) - - if cu_seqlen_prefill is not None or cuda_graph is None: - logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - input_lengths=input_lengths, - max_s=max_s, - prefill_cache_indices=batch.prefill_cache_indices, - lm_head_indices=lm_head_indices, - ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None - return logits, speculative_logits - - # Copy inputs to the static inputs of the cuda graph - # Static inputs are potentially padded - cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids - cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids - cuda_graph["block_tables"][ - : block_tables.shape[0], : block_tables.shape[1] - ] = block_tables - cuda_graph["slots"].fill_(-1) - cuda_graph["slots"][: slots.shape[0]] = slots - cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths - - # Replay the graph - cuda_graph["graph"].replay() - - # Slice output to the correct shape - speculative_logits = ( - cuda_graph["speculative_logits"][:bs] - if cuda_graph["speculative_logits"] is not None - else None - ) - logits = cuda_graph["logits"][:bs] - return logits, speculative_logits - class FlashMistral(BaseFlashMistral): def __init__( diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py index 59064b30..75285863 100644 --- a/server/text_generation_server/models/flash_qwen2.py +++ b/server/text_generation_server/models/flash_qwen2.py @@ -7,7 +7,6 @@ from opentelemetry import trace from transformers import AutoTokenizer, AutoConfig from typing import Optional -from text_generation_server.models.cache_manager import BLOCK_SIZE from text_generation_server.models.flash_mistral import ( BaseFlashMistral, set_sliding_window, @@ -57,9 +56,7 @@ class FlashQwen2(BaseFlashMistral): # Set context windows if config.sliding_window is not None: - set_sliding_window( - config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE) - ) + set_sliding_window(config.sliding_window) torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/flash_starcoder2.py b/server/text_generation_server/models/flash_starcoder2.py index dc5d49be..5533c9d9 100644 --- a/server/text_generation_server/models/flash_starcoder2.py +++ b/server/text_generation_server/models/flash_starcoder2.py @@ -6,7 +6,6 @@ from typing import Optional from transformers.models.gpt2 import GPT2TokenizerFast -from text_generation_server.models.cache_manager import BLOCK_SIZE from text_generation_server.models.flash_mistral import ( BaseFlashMistral, set_sliding_window, @@ -56,9 +55,7 @@ class 
FlashStarcoder2(BaseFlashMistral): # Set context windows if config.sliding_window is not None: - set_sliding_window( - config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE) - ) + set_sliding_window(config.sliding_window) torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index f0db89b2..92d79070 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -11,13 +11,9 @@ from typing import Optional, Tuple, List, Type, Dict from transformers import PreTrainedTokenizerBase from transformers.image_processing_utils import select_best_resolution from text_generation_server.pb import generate_pb2 +from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch from text_generation_server.models.flash_mistral import ( BaseFlashMistral, - FlashMistralBatch, -) -from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch -from text_generation_server.models.cache_manager import ( - get_cache_manager, ) tracer = trace.get_tracer(__name__) @@ -140,7 +136,7 @@ def load_data_uri(image_uri: str) -> Image.Image: return image -class VlmCausalLMBatch(FlashMistralBatch): +class VlmCausalLMBatch(FlashCausalLMBatch): pixel_values: Optional[List[torch.Tensor]] pixel_attention_mask: Optional[List[torch.Tensor]] image_sizes: Optional[List[Tuple[int, int]]] @@ -268,7 +264,7 @@ class VlmCausalLM(BaseFlashMistral): input_ids = batch.input_ids position_ids = batch.position_ids cu_seqlen_prefill = batch.cu_seqlen_prefill - kv_cache = get_cache_manager().kv_cache + kv_cache = self.kv_cache block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor @@ -307,7 +303,7 @@ class VlmCausalLM(BaseFlashMistral): input_ids = batch.input_ids position_ids = batch.position_ids cu_seqlen_prefill = batch.cu_seqlen_prefill - kv_cache = get_cache_manager().kv_cache + kv_cache = self.kv_cache block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor
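The net effect of the hunks above is that the sliding-window bookkeeping previously implemented in FlashMistralBatch (building prefill_cache_indices so that only the last sliding_window positions of each prefill sequence are written into the paged KV cache, and clamping max_s to max_past() during decode) now lives in the shared FlashCausalLM path, with set_sliding_window reduced to a single argument and the per-model self.kv_cache replacing the global cache manager. The sketch below is a minimal, self-contained illustration of that indexing, assuming a flattened prefill batch where request tokens are concatenated along one dimension; build_prefill_cache_indices and clamp_decode_max_s are hypothetical helper names written for this note, not functions in the repository.

import torch
from typing import List, Optional


def build_prefill_cache_indices(
    input_lengths: List[int], sliding_window: Optional[int]
) -> Optional[torch.Tensor]:
    # Requests are flattened into one token dimension during prefill. For each
    # request, only the last `sliding_window` positions need their key/value
    # pairs written into the paged cache, so we collect those flat indices.
    if sliding_window is None:
        return None
    indices = []
    cumulative_length = 0
    for input_length in input_lengths:
        start = cumulative_length + max(0, input_length - sliding_window)
        end = cumulative_length + input_length
        indices.append(torch.arange(start, end, dtype=torch.int64))
        cumulative_length += input_length
    return torch.cat(indices)


def clamp_decode_max_s(max_s: int, max_past: Optional[int]) -> int:
    # During decode the sliding-window cache behaves like a circular buffer of
    # at most `max_past` positions, so the effective max sequence length is
    # capped (mirrors `max_s = min(self.max_past(), max_s)` in the diff).
    return max_s if max_past is None else min(max_past, max_s)


# Two prefill requests of lengths 6 and 3 with a window of 4: only the last
# four tokens of the first request (flat positions 2..5) and all three tokens
# of the second request (flat positions 6..8) are written to the cache.
print(build_prefill_cache_indices([6, 3], sliding_window=4))
# tensor([2, 3, 4, 5, 6, 7, 8])
print(clamp_decode_max_s(max_s=8192, max_past=4096))
# 4096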