From e74bd41e0f279ab569cf6a65ac3e2cea50e80d39 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Fri, 30 Jun 2023 19:09:59 +0200 Subject: [PATCH] feat(server): add paged attention to flash models (#516) Closes #478 --- Dockerfile | 16 +- README.md | 4 +- integration-tests/models/test_flash_neox.py | 2 + launcher/src/main.rs | 44 +- proto/generate.proto | 12 + router/client/src/client.rs | 58 ++ router/client/src/sharded_client.rs | 21 + router/src/infer.rs | 16 +- router/src/main.rs | 42 +- router/src/queue.rs | 47 +- router/src/server.rs | 2 + server/Makefile-vllm | 13 + server/text_generation_server/cache.py | 4 +- .../models/causal_lm.py | 2 +- .../custom_modeling/flash_llama_modeling.py | 216 +++---- .../custom_modeling/flash_neox_modeling.py | 227 +++---- .../custom_modeling/flash_rw_modeling.py | 355 +++++------ .../flash_santacoder_modeling.py | 202 +++---- .../models/custom_modeling/t5_modeling.py | 4 +- .../models/flash_causal_lm.py | 567 +++++++++++------- .../models/flash_llama.py | 6 +- .../models/flash_neox.py | 6 +- .../text_generation_server/models/flash_rw.py | 6 +- .../models/flash_santacoder.py | 13 +- server/text_generation_server/models/model.py | 6 + .../models/seq2seq_lm.py | 2 +- server/text_generation_server/server.py | 7 + server/text_generation_server/utils/tokens.py | 2 + .../text_generation_server/utils/weights.py | 31 +- 29 files changed, 1045 insertions(+), 888 deletions(-) create mode 100644 server/Makefile-vllm diff --git a/Dockerfile b/Dockerfile index 2a313c25..1a969383 100644 --- a/Dockerfile +++ b/Dockerfile @@ -88,7 +88,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins RUN /opt/conda/bin/conda install -c "nvidia/label/cuda-11.8.0" cuda==11.8 && \ /opt/conda/bin/conda clean -ya - # Build Flash Attention CUDA kernels FROM kernel-builder as flash-att-builder @@ -109,6 +108,16 @@ COPY server/custom_kernels/ . # Build specific version of transformers RUN python setup.py build +# Build vllm CUDA kernels +FROM kernel-builder as vllm-builder + +WORKDIR /usr/src + +COPY server/Makefile-vllm Makefile + +# Build specific version of vllm +RUN make build-vllm + # Text Generation Inference base image FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base @@ -137,9 +146,12 @@ COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cp COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages -# Copy build artifacts from transformers builder +# Copy build artifacts from custom kernels builder COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels +# Copy builds artifacts from vllm builder +COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages + # Install flash-attention dependencies RUN pip install einops --no-cache-dir diff --git a/README.md b/README.md index 8c8d9773..b74d2617 100644 --- a/README.md +++ b/README.md @@ -43,8 +43,8 @@ to power LLMs api-inference widgets. - Tensor Parallelism for faster inference on multiple GPUs - Token streaming using Server-Sent Events (SSE) - [Continuous batching of incoming requests](https://github.com/huggingface/text-generation-inference/tree/main/router) for increased total throughput -- Optimized transformers code for inference using [flash-attention](https://github.com/HazyResearch/flash-attention) on the most popular architectures -- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) +- Optimized transformers code for inference using [flash-attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures +- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and [GPT-Q](https://arxiv.org/abs/2210.17323) - [Safetensors](https://github.com/huggingface/safetensors) weight loading - Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) - Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor)) diff --git a/integration-tests/models/test_flash_neox.py b/integration-tests/models/test_flash_neox.py index 1076126b..0289c61d 100644 --- a/integration-tests/models/test_flash_neox.py +++ b/integration-tests/models/test_flash_neox.py @@ -13,6 +13,7 @@ async def flash_neox(flash_neox_handle): return flash_neox_handle.client +@pytest.mark.skip @pytest.mark.asyncio async def test_flash_neox(flash_neox, response_snapshot): response = await flash_neox.generate( @@ -25,6 +26,7 @@ async def test_flash_neox(flash_neox, response_snapshot): assert response == response_snapshot +@pytest.mark.skip @pytest.mark.asyncio async def test_flash_neox_load(flash_neox, generate_load, response_snapshot): responses = await generate_load( diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 2e2bc7a5..8497f807 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -115,12 +115,6 @@ struct Args { #[clap(default_value = "1512", long, env)] max_total_tokens: usize, - /// The maximum allowed batch size during dynamic batching. - /// Using `max_batch_total_tokens` should be favored in general - /// as it's a finer way to control RAM usage. - #[clap(long, env)] - max_batch_size: Option, - /// This represents the ratio of waiting queries vs running queries where /// you want to start considering pausing the running queries to include the waiting /// ones into the same batch. @@ -134,6 +128,12 @@ struct Args { #[clap(default_value = "1.2", long, env)] waiting_served_ratio: f32, + /// Limits the number of tokens for the prefill operation. + /// Since this operation take the most memory and is compute bound, it is interesting + /// to limit the number of requests that can be sent. + #[clap(default_value = "4096", long, env)] + max_batch_prefill_tokens: u32, + /// **IMPORTANT** This is one critical control to allow maximum usage /// of the available hardware. /// @@ -146,19 +146,12 @@ struct Args { /// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100` /// or a single query of `1000` tokens. /// - /// So you don't have to control that finely - /// `max_batch_size` or `max_total_tokens`. In fact you could mostly relax them if you - /// want maximum flexibility. However, for your users if they are asking for the full amount of - /// total tokens, they are likely to wait for a very long time to get a spot - /// in the batch (since they are going to be alone) so setting `max_batch_size` - /// and `max_total_tokens` can still be useful to prevent those long waiting times. - /// /// Overall this number should be the largest possible amount that fits the /// remaining memory (after the model is loaded). Since the actual memory overhead /// depends on other parameters like if you're using quantization, flash attention /// or the model implementation, text-generation-inference cannot infer this number /// automatically. - #[clap(default_value = "32000", long, env)] + #[clap(default_value = "16000", long, env)] max_batch_total_tokens: u32, /// This setting defines how many tokens can be passed before forcing the waiting @@ -180,9 +173,9 @@ struct Args { /// for end users. #[clap(default_value = "20", long, env)] max_waiting_tokens: usize, - #[clap(default_value = "3000", long, short, env)] /// The port to listen on. + #[clap(default_value = "3000", long, short, env)] port: u16, /// The name of the socket for gRPC communication between the webserver @@ -329,6 +322,12 @@ fn shard_manager( // Copy current process env let mut env: Vec<(OsString, OsString)> = env::vars_os().collect(); + // Use cuda allocator. It leads to less memory fragmentation + env.push(( + "PYTORCH_CUDA_ALLOC_CONF".into(), + "backend:cudaMallocAsync".into(), + )); + // Torch Distributed Env vars env.push(("RANK".into(), rank.to_string().into())); env.push(("WORLD_SIZE".into(), world_size.to_string().into())); @@ -446,7 +445,7 @@ fn shard_manager( // We received a shutdown signal if *shutdown.lock().unwrap() { - p.terminate().unwrap(); + p.kill().unwrap(); let _ = p.wait_timeout(Duration::from_secs(90)); tracing::info!("Shard {rank} terminated"); return; @@ -822,6 +821,10 @@ fn spawn_webserver( args.max_input_length.to_string(), "--max-total-tokens".to_string(), args.max_total_tokens.to_string(), + "--max-batch-prefill-tokens".to_string(), + args.max_batch_prefill_tokens.to_string(), + "--max-batch-total-tokens".to_string(), + args.max_batch_total_tokens.to_string(), "--waiting-served-ratio".to_string(), args.waiting_served_ratio.to_string(), "--max-waiting-tokens".to_string(), @@ -834,15 +837,6 @@ fn spawn_webserver( args.model_id, ]; - // Deprecate max_batch_size - if let Some(max_batch_size) = args.max_batch_size { - argv.push("--max-batch-size".to_string()); - argv.push(max_batch_size.to_string()) - } else { - argv.push("--max-batch-total-tokens".to_string()); - argv.push(args.max_batch_total_tokens.to_string()) - } - // Model optional revision if let Some(ref revision) = args.revision { argv.push("--revision".to_string()); diff --git a/proto/generate.proto b/proto/generate.proto index a0f5a75e..5e061941 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -11,6 +11,8 @@ service TextGenerationService { rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse); /// Remove requests from a cached batch rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse); + /// Warmup the model and compute max cache size + rpc Warmup (WarmupRequest) returns (WarmupResponse); /// Prefill batch and decode first token rpc Prefill (PrefillRequest) returns (PrefillResponse); /// Decode token for a list of prefilled batches @@ -192,3 +194,13 @@ message DecodeResponse { /// Next batch (cached) optional CachedBatch batch = 2; } + +message WarmupRequest { + /// Batch to warmup on + Batch batch = 1; + /// Maximum number of tokens that the client will send + uint32 max_total_tokens = 2; +} + +/// Empty response +message WarmupResponse {} diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 81f023ef..b5e0ccc0 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -3,6 +3,7 @@ use crate::pb::generate::v1::text_generation_service_client::TextGenerationServi use crate::pb::generate::v1::*; use crate::Result; use grpc_metadata::InjectTelemetryContext; +use std::cmp::min; use tonic::transport::{Channel, Uri}; use tracing::instrument; @@ -94,6 +95,63 @@ impl Client { Ok(filtered_batch.batch) } + /// Warmup on a max size batch + /// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip(self))] + pub async fn warmup( + &mut self, + max_input_length: u32, + max_prefill_tokens: u32, + max_total_tokens: u32, + ) -> Result<()> { + let mut n_tokens = 0; + let mut requests = Vec::new(); + + // Create requests + while n_tokens < max_prefill_tokens { + requests.push(Request { + id: 0, + // We truncate the input on the server side to be sure that it has the correct size + inputs: "_test ".to_string().repeat(max_input_length as usize), + truncate: min(max_input_length, max_prefill_tokens - n_tokens), + // Set sampling parameters to also take these ops into account in the max memory + parameters: Some(NextTokenChooserParameters { + temperature: 0.9, + top_k: 10, + top_p: 0.9, + typical_p: 0.9, + do_sample: false, + seed: 0, + repetition_penalty: 1.2, + watermark: true, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: 2, + stop_sequences: vec![], + ignore_eos_token: false, + }), + prefill_logprobs: true, + }); + n_tokens += max_input_length; + } + + let batch = Batch { + id: 0, + size: requests.len() as u32, + requests, + max_tokens: 0, + }; + + let request = tonic::Request::new(WarmupRequest { + batch: Some(batch), + max_total_tokens, + }) + .inject_context(); + self.stub.warmup(request).await?.into_inner(); + Ok(()) + } + /// Generate one token for each request in the given batch /// /// Returns Generation for each request in batch diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index b81eed46..9dd173a0 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -87,6 +87,27 @@ impl ShardedClient { join_all(futures).await.pop().unwrap() } + /// Warmup on a max size batch + /// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip(self))] + pub async fn warmup( + &mut self, + max_input_length: u32, + max_prefill_tokens: u32, + max_total_tokens: u32, + ) -> Result<()> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| { + Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens)) + }) + .collect(); + // all shards return the same message + join_all(futures).await.pop().unwrap() + } + /// Generate one token for each request in the given batch /// /// Returns Generation for each request in batch diff --git a/router/src/infer.rs b/router/src/infer.rs index f738f986..d0d22d3b 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -45,6 +45,7 @@ impl Infer { client: ShardedClient, validation: Validation, waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, max_batch_total_tokens: u32, max_waiting_tokens: usize, max_concurrent_requests: usize, @@ -61,6 +62,7 @@ impl Infer { tokio::spawn(batching_task( client, waiting_served_ratio, + max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, queue.clone(), @@ -240,9 +242,11 @@ impl Infer { /// Will be launched in a background Tokio task /// /// Batches requests and sends them to the inference server +#[allow(clippy::too_many_arguments)] async fn batching_task( mut client: ShardedClient, waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, max_batch_total_tokens: u32, max_waiting_tokens: usize, queue: Queue, @@ -257,8 +261,9 @@ async fn batching_task( // Get the next batch from the queue // This batch might be smaller than the maximum batch size if there are not enough requests // waiting in the queue - while let Some((mut entries, batch, span)) = - queue.next_batch(None, max_batch_total_tokens).await + while let Some((mut entries, batch, span)) = queue + .next_batch(None, max_batch_prefill_tokens, max_batch_total_tokens) + .await { let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health) .instrument(span) @@ -284,11 +289,12 @@ async fn batching_task( Some((batch_size as f32 * waiting_served_ratio).floor() as usize) }; - let token_budget = max_batch_total_tokens - batch_max_tokens; + let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); // Try to get a new batch - if let Some((mut new_entries, new_batch, span)) = - queue.next_batch(min_size, token_budget).await + if let Some((mut new_entries, new_batch, span)) = queue + .next_batch(min_size, max_batch_prefill_tokens, token_budget) + .await { // Tracking metrics if min_size.is_some() { diff --git a/router/src/main.rs b/router/src/main.rs index 7bbb6477..47d48e3f 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -32,10 +32,10 @@ struct Args { max_input_length: usize, #[clap(default_value = "1512", long, env)] max_total_tokens: usize, - #[clap(long, env)] - max_batch_size: Option, #[clap(default_value = "1.2", long, env)] waiting_served_ratio: f32, + #[clap(default_value = "4096", long, env)] + max_batch_prefill_tokens: u32, #[clap(default_value = "32000", long, env)] max_batch_total_tokens: u32, #[clap(default_value = "20", long, env)] @@ -78,9 +78,9 @@ fn main() -> Result<(), std::io::Error> { max_stop_sequences, max_input_length, max_total_tokens, - max_batch_size, waiting_served_ratio, - mut max_batch_total_tokens, + max_batch_prefill_tokens, + max_batch_total_tokens, max_waiting_tokens, port, master_shard_uds_path, @@ -141,12 +141,6 @@ fn main() -> Result<(), std::io::Error> { .block_on(async { init_logging(otlp_endpoint, json_output); - if let Some(max_batch_size) = max_batch_size { - tracing::warn!("`max-batch-size` is deprecated. Use `max-batch-total-tokens` instead"); - max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32; - tracing::warn!("Overriding `max-batch-total-tokens` value with `max-batch-size` * `max-total-tokens` = {max_batch_total_tokens}"); - } - if tokenizer.is_none() { tracing::warn!( "Could not find a fast tokenizer implementation for {tokenizer_name}" @@ -161,10 +155,16 @@ fn main() -> Result<(), std::io::Error> { sha: None, pipeline_tag: None, }, - false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or_else(|| { - tracing::warn!("Could not retrieve model info from the Hugging Face hub."); - HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None } - }), + false => get_model_info(&tokenizer_name, &revision, authorization_token) + .await + .unwrap_or_else(|| { + tracing::warn!("Could not retrieve model info from the Hugging Face hub."); + HubModelInfo { + model_id: tokenizer_name.to_string(), + sha: None, + pipeline_tag: None, + } + }), }; // if pipeline-tag == text-generation we default to return_full_text = true @@ -190,6 +190,17 @@ fn main() -> Result<(), std::io::Error> { .info() .await .expect("Unable to get shard info"); + + // Warmup model + tracing::info!("Warming up model"); + sharded_client + .warmup( + max_input_length as u32, + max_batch_prefill_tokens, + max_batch_total_tokens, + ) + .await + .expect("Unable to warmup model"); tracing::info!("Connected"); // Binds on localhost @@ -206,6 +217,7 @@ fn main() -> Result<(), std::io::Error> { max_input_length, max_total_tokens, waiting_served_ratio, + max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, sharded_client, @@ -219,7 +231,7 @@ fn main() -> Result<(), std::io::Error> { ngrok_username, ngrok_password, ) - .await; + .await; Ok(()) }) } diff --git a/router/src/queue.rs b/router/src/queue.rs index 6d1d4d12..48e483a1 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -58,6 +58,7 @@ impl Queue { pub(crate) async fn next_batch( &self, min_size: Option, + prefill_token_budget: u32, token_budget: u32, ) -> Option { // Create response channel @@ -67,6 +68,7 @@ impl Queue { self.queue_sender .send(QueueCommand::NextBatch { min_size, + prefill_token_budget, token_budget, response_sender, span: Span::current(), @@ -90,11 +92,12 @@ async fn queue_task(requires_padding: bool, receiver: flume::Receiver span.in_scope(|| { - let next_batch = state.next_batch(min_size, token_budget); + let next_batch = state.next_batch(min_size, prefill_token_budget, token_budget); response_sender.send(next_batch).unwrap(); metrics::gauge!("tgi_queue_size", state.entries.len() as f64); }), @@ -140,7 +143,12 @@ impl State { } // Get the next batch - fn next_batch(&mut self, min_size: Option, token_budget: u32) -> Option { + fn next_batch( + &mut self, + min_size: Option, + prefill_token_budget: u32, + token_budget: u32, + ) -> Option { if self.entries.is_empty() { return None; } @@ -184,7 +192,9 @@ impl State { decode_tokens += entry.request.stopping_parameters.max_new_tokens; - if (prefill_tokens + decode_tokens) > token_budget { + if prefill_tokens > prefill_token_budget + || (prefill_tokens + decode_tokens) > token_budget + { // Entry is over budget // Add it back to the front self.entries.push_front((id, entry)); @@ -259,6 +269,7 @@ enum QueueCommand { Append(Box, Span), NextBatch { min_size: Option, + prefill_token_budget: u32, token_budget: u32, response_sender: oneshot::Sender>, span: Span, @@ -328,8 +339,8 @@ mod tests { fn test_next_batch_empty() { let mut state = State::new(false); - assert!(state.next_batch(None, 1).is_none()); - assert!(state.next_batch(Some(1), 1).is_none()); + assert!(state.next_batch(None, 1, 1).is_none()); + assert!(state.next_batch(Some(1), 1, 1).is_none()); } #[test] @@ -340,7 +351,7 @@ mod tests { state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, 2).unwrap(); + let (entries, batch, _) = state.next_batch(None, 2, 2).unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&0)); assert!(entries.contains_key(&1)); @@ -356,7 +367,7 @@ mod tests { let (entry3, _guard3) = default_entry(); state.append(entry3); - assert!(state.next_batch(Some(2), 2).is_none()); + assert!(state.next_batch(Some(2), 2, 2).is_none()); assert_eq!(state.next_id, 3); assert_eq!(state.entries.len(), 1); @@ -372,7 +383,7 @@ mod tests { state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, 1).unwrap(); + let (entries, batch, _) = state.next_batch(None, 1, 1).unwrap(); assert_eq!(entries.len(), 1); assert!(entries.contains_key(&0)); assert_eq!(batch.id, 0); @@ -385,7 +396,7 @@ mod tests { let (entry3, _guard3) = default_entry(); state.append(entry3); - let (entries, batch, _) = state.next_batch(None, 3).unwrap(); + let (entries, batch, _) = state.next_batch(None, 3, 3).unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&1)); assert!(entries.contains_key(&2)); @@ -408,8 +419,8 @@ mod tests { async fn test_queue_next_batch_empty() { let queue = Queue::new(false); - assert!(queue.next_batch(None, 1).await.is_none()); - assert!(queue.next_batch(Some(1), 1).await.is_none()); + assert!(queue.next_batch(None, 1, 1).await.is_none()); + assert!(queue.next_batch(Some(1), 1, 1).await.is_none()); } #[tokio::test] @@ -420,7 +431,7 @@ mod tests { queue.append(entry1); queue.append(entry2); - let (entries, batch, _) = queue.next_batch(None, 2).await.unwrap(); + let (entries, batch, _) = queue.next_batch(None, 2, 2).await.unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&0)); assert!(entries.contains_key(&1)); @@ -433,11 +444,11 @@ mod tests { queue.append(entry3); // Not enough requests pending - assert!(queue.next_batch(Some(2), 2).await.is_none()); + assert!(queue.next_batch(Some(2), 2, 2).await.is_none()); // Not enough token budget - assert!(queue.next_batch(Some(1), 0).await.is_none()); + assert!(queue.next_batch(Some(1), 0, 0).await.is_none()); // Ok - let (entries2, batch2, _) = queue.next_batch(Some(1), 2).await.unwrap(); + let (entries2, batch2, _) = queue.next_batch(Some(1), 2, 2).await.unwrap(); assert_eq!(entries2.len(), 1); assert!(entries2.contains_key(&2)); assert!(entries2.get(&2).unwrap().batch_time.is_some()); @@ -453,7 +464,7 @@ mod tests { queue.append(entry1); queue.append(entry2); - let (entries, batch, _) = queue.next_batch(None, 1).await.unwrap(); + let (entries, batch, _) = queue.next_batch(None, 1, 1).await.unwrap(); assert_eq!(entries.len(), 1); assert!(entries.contains_key(&0)); assert_eq!(batch.id, 0); @@ -462,7 +473,7 @@ mod tests { let (entry3, _guard3) = default_entry(); queue.append(entry3); - let (entries, batch, _) = queue.next_batch(None, 3).await.unwrap(); + let (entries, batch, _) = queue.next_batch(None, 3, 3).await.unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&1)); assert!(entries.contains_key(&2)); @@ -476,6 +487,6 @@ mod tests { let (entry, _) = default_entry(); queue.append(entry); - assert!(queue.next_batch(None, 1).await.is_none()); + assert!(queue.next_batch(None, 1, 1).await.is_none()); } } diff --git a/router/src/server.rs b/router/src/server.rs index dd8bc874..ee96ead6 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -514,6 +514,7 @@ pub async fn run( max_input_length: usize, max_total_tokens: usize, waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, max_batch_total_tokens: u32, max_waiting_tokens: usize, client: ShardedClient, @@ -582,6 +583,7 @@ pub async fn run( client, validation, waiting_served_ratio, + max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, max_concurrent_requests, diff --git a/server/Makefile-vllm b/server/Makefile-vllm new file mode 100644 index 00000000..af750733 --- /dev/null +++ b/server/Makefile-vllm @@ -0,0 +1,13 @@ +vllm_commit := d284b831c17f42a8ea63369a06138325f73c4cf9 + +vllm: + # Clone vllm + git clone https://github.com/OlivierDehaene/vllm.git + +build-vllm: vllm + cd vllm && git fetch && git checkout $(vllm_commit) + cd vllm && python setup.py build + +install-vllm: build-vllm + pip uninstall vllm -y || true + cd vllm && python setup.py install \ No newline at end of file diff --git a/server/text_generation_server/cache.py b/server/text_generation_server/cache.py index 5556529c..79fcd3aa 100644 --- a/server/text_generation_server/cache.py +++ b/server/text_generation_server/cache.py @@ -22,7 +22,9 @@ class Cache: del batch def clear(self): - self.cache.clear() + keys = list(self.cache.keys()) + for k in keys: + self.delete(k) def __len__(self): return len(self.cache.keys()) diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index ba0853f5..6d47c6eb 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -122,7 +122,7 @@ class CausalLMBatch(Batch): position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) - max_tokens = len(inputs) * max_input_length + max_decode_tokens + max_tokens = len(inputs) * (max_input_length + max_decode_tokens) return cls( batch_id=pb.id, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 993e1e2a..07765e88 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -23,12 +23,16 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN -from typing import Optional +from typing import Optional, List, Tuple # Flash attention imports import flash_attn_cuda import dropout_layer_norm +# vllm imports +import vllm_cache_ops +import vllm_attention_ops + from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -106,7 +110,7 @@ class FlashLlamaAttention(torch.nn.Module): prefix=f"{prefix}.rotary_emb", weights=weights ) - self.softmax_scale = self.head_size ** (-0.5) + self.softmax_scale = self.head_size**-0.5 self.num_heads = self.num_heads // weights.process_group.size() self.query_key_value = TensorParallelColumnLinear.load_multi( @@ -122,20 +126,22 @@ class FlashLlamaAttention(torch.nn.Module): weights=weights, bias=False, ) + self.kv_head_mapping = torch.arange( + 0, self.num_heads, dtype=torch.int32, device=weights.device + ) def forward( self, hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, 3, self.num_heads, self.head_size) @@ -144,23 +150,25 @@ class FlashLlamaAttention(torch.nn.Module): self.rotary_emb(qkv[:, 0], cos, sin) self.rotary_emb(qkv[:, 1], cos, sin) - # Prefill - if prefill: - # Copy to layer past - layer_past[...] = qkv[:, 1:] + vllm_cache_ops.reshape_and_cache( + qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots + ) - # output - attn_output = torch.empty_like(qkv[:, 0]) + # output tensor + attn_output = torch.empty_like(qkv[:, 0]) + + # Prefill + if start_seq_prefill is not None: # flash attention flash_attn_cuda.fwd( qkv[:, 0], qkv[:, 1], qkv[:, 2], attn_output, - start_seq, - end_seq, - start_seq, - end_seq, + start_seq_prefill, + end_seq_prefill, + start_seq_prefill, + end_seq_prefill, max_s, max_s, 0.0, @@ -173,31 +181,19 @@ class FlashLlamaAttention(torch.nn.Module): ) # Decode else: - query = qkv[:, 0] - # Add present to the layer_past tensor at the correct indices - layer_past[past_present_indices] = qkv[:, 1:] - - # output - attn_output = torch.empty_like(query) - # flash attention - flash_attn_cuda.fwd( - query, - layer_past[:, 0], - layer_past[:, 1], + # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] + block_size = kv_cache[1].shape[3] + vllm_attention_ops.single_query_cached_kv_attention( attn_output, - start_seq_q, - end_seq_q, - start_seq, - end_seq, - 1, - max_s, - 0.0, + qkv[:, 0], + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, self.softmax_scale, - False, - False, - False, - 0, - None, + block_tables, + input_lengths, + block_size, + max_s, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -265,14 +261,13 @@ class FlashLlamaLayer(nn.Module): residual, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -281,14 +276,13 @@ class FlashLlamaLayer(nn.Module): normed_hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ) # faster post attention rms norm @@ -333,40 +327,18 @@ class FlashLlamaModel(torch.nn.Module): def forward( self, - input_ids, - position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, - max_s, - past_present_indices, - past_key_values=None, - pre_allocate_past_size: Optional[int] = None, - ): + input_ids: torch.Tensor, + position_ids: torch.Tensor, + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) - # Prefill - if past_key_values is None: - assert pre_allocate_past_size is not None - - prefill = True - - # Create past tensor - # We create a tensor of the same size as input_ids as we don't want to slice at every layer - past_key_values = hidden_states.new_empty( - ( - len(input_ids), - len(self.layers), - 2, - self.num_heads, - self.head_size, - ) - ) - # Decode - else: - prefill = False - # Get rotary cos and sin for this forward # Avoid to index in each layer cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( @@ -380,34 +352,18 @@ class FlashLlamaModel(torch.nn.Module): residual, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, max_s, - past_key_values[:, i], - past_present_indices, - prefill, ) - if prefill: - present = past_key_values - # Create padded past tensor - past_key_values = hidden_states.new_empty( - ( - pre_allocate_past_size, - len(self.layers), - 2, - self.num_heads, - self.head_size, - ) - ) - # We slice only once instead of at every layer - past_key_values[past_present_indices] = present - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states, past_key_values + return hidden_states class FlashLlamaForCausalLM(torch.nn.Module): @@ -423,31 +379,29 @@ class FlashLlamaForCausalLM(torch.nn.Module): def forward( self, - input_ids, - position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, - max_s, - past_present_indices, - past_key_values: Optional[torch.Tensor] = None, - pre_allocate_past_size: Optional[int] = None, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, lm_head_indices: Optional[torch.Tensor] = None, - ): - hidden_states, present = self.model( + ) -> torch.Tensor: + hidden_states = self.model( input_ids, position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - past_present_indices, - past_key_values, - pre_allocate_past_size, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states) - return logits, present + return logits diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 9c1020a5..9049878a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -25,11 +25,15 @@ from torch import nn from transformers.activations import ACT2FN from transformers.modeling_utils import PreTrainedModel from transformers.models.gpt_neox import GPTNeoXConfig -from typing import Optional +from typing import Optional, List, Tuple # Flash attention imports import flash_attn_cuda +# vllm imports +import vllm_cache_ops +import vllm_attention_ops + from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -110,20 +114,22 @@ class FlashNeoxAttention(torch.nn.Module): self.dense = load_row( config, prefix=f"{prefix}.dense", weights=weights, bias=True ) + self.kv_head_mapping = torch.arange( + 0, self.num_heads, dtype=torch.int32, device=weights.device + ) def forward( self, hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, 3, self.num_heads, self.head_size) @@ -132,23 +138,25 @@ class FlashNeoxAttention(torch.nn.Module): self.rotary_emb(qkv[:, 0], cos, sin) self.rotary_emb(qkv[:, 1], cos, sin) - # Prefill - if prefill: - # Copy to layer past - layer_past[...] = qkv[:, 1:] + vllm_cache_ops.reshape_and_cache( + qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots + ) - # output - attn_output = torch.empty_like(qkv[:, 0]) + # output tensor + attn_output = torch.empty_like(qkv[:, 0]) + + # Prefill + if start_seq_prefill is not None: # flash attention flash_attn_cuda.fwd( qkv[:, 0], qkv[:, 1], qkv[:, 2], attn_output, - start_seq, - end_seq, - start_seq, - end_seq, + start_seq_prefill, + end_seq_prefill, + start_seq_prefill, + end_seq_prefill, max_s, max_s, 0.0, @@ -161,31 +169,19 @@ class FlashNeoxAttention(torch.nn.Module): ) # Decode else: - query = qkv[:, 0] - # Add present to the layer_past tensor at the correct indices - layer_past[past_present_indices] = qkv[:, 1:] - - # output - attn_output = torch.empty_like(query) - # flash attention - flash_attn_cuda.fwd( - query, - layer_past[:, 0], - layer_past[:, 1], + # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] + block_size = kv_cache[1].shape[3] + vllm_attention_ops.single_query_cached_kv_attention( attn_output, - start_seq_q, - end_seq_q, - start_seq, - end_seq, - 1, - max_s, - 0.0, + qkv[:, 0], + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, self.softmax_scale, - False, - False, - False, - 0, - None, + block_tables, + input_lengths, + block_size, + max_s, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -250,14 +246,13 @@ class FlashNeoXLayer(nn.Module): residual, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): if self.use_parallel_residual: ln1_hidden_states, _ = self.input_layernorm(hidden_states) @@ -266,14 +261,13 @@ class FlashNeoXLayer(nn.Module): ln1_hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ) ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states) @@ -292,14 +286,13 @@ class FlashNeoXLayer(nn.Module): hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ) hidden_states, residual = self.post_attention_layernorm( @@ -346,40 +339,18 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): def forward( self, - input_ids, - position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, - max_s, - past_present_indices, - past_key_values=None, - pre_allocate_past_size: Optional[int] = None, - ): + input_ids: torch.Tensor, + position_ids: torch.Tensor, + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: hidden_states = self.embed_in(input_ids) - # Prefill - if past_key_values is None: - assert pre_allocate_past_size is not None - - prefill = True - - # Create past tensor - # We create a tensor of the same size as input_ids as we don't want to slice at every layer - past_key_values = hidden_states.new_empty( - ( - len(input_ids), - len(self.layers), - 2, - self.num_heads, - self.head_size, - ) - ) - # Decode - else: - prefill = False - # Get rotary cos and sin for this forward # Avoid to index in each layer cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin( @@ -393,34 +364,18 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): residual, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, max_s, - past_key_values[:, i], - past_present_indices, - prefill, ) - if prefill: - present = past_key_values - # Create padded past tensor - past_key_values = hidden_states.new_empty( - ( - pre_allocate_past_size, - len(self.layers), - 2, - self.num_heads, - self.head_size, - ) - ) - # We slice only once instead of at every layer - past_key_values[past_present_indices] = present - hidden_states, _ = self.final_layer_norm(hidden_states, residual) - return hidden_states, past_key_values + return hidden_states class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): @@ -434,31 +389,29 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): def forward( self, - input_ids, - position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, - max_s, - past_present_indices, - past_key_values: Optional[torch.Tensor] = None, - pre_allocate_past_size: Optional[int] = None, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, lm_head_indices: Optional[torch.Tensor] = None, - ): - hidden_states, present = self.gpt_neox( + ) -> torch.Tensor: + hidden_states = self.gpt_neox( input_ids, position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - past_present_indices, - past_key_values, - pre_allocate_past_size, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits = self.embed_out(hidden_states) - return logits, present + return logits diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index fa35c359..44aa7cb1 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -4,11 +4,15 @@ import torch.distributed from torch import nn from transformers.modeling_utils import PreTrainedModel from transformers.configuration_utils import PretrainedConfig -from typing import Optional +from typing import Optional, List, Tuple # Flash attention imports import flash_attn_cuda +# vllm imports +import vllm_cache_ops +import vllm_attention_ops + from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -126,19 +130,27 @@ class FlashRWAttention(torch.nn.Module): config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias ) + if self.num_heads_kv == 1: + self.kv_head_mapping = torch.zeros( + self.num_heads, dtype=torch.int32, device=weights.device + ) + else: + self.kv_head_mapping = torch.arange( + 0, self.num_heads, dtype=torch.int32, device=weights.device + ) + def forward( self, hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): qkv = self.query_key_value(hidden_states) @@ -156,25 +168,29 @@ class FlashRWAttention(torch.nn.Module): self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) - # Prefill - if prefill: - # Copy to layer past - layer_past[...] = kv - # Expand to query shape - kv = kv.expand(-1, 2, self.num_heads, self.head_size) + vllm_cache_ops.reshape_and_cache( + kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # output + attn_output = torch.empty_like(query) + + # Prefill + if start_seq_prefill is not None: + if self.num_heads_kv == 1: + # Expand to query shape + kv = kv.expand(-1, 2, self.num_heads, self.head_size) - # output - attn_output = torch.empty_like(query) # flash attention flash_attn_cuda.fwd( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), attn_output, - start_seq, - end_seq, - start_seq, - end_seq, + start_seq_prefill, + end_seq_prefill, + start_seq_prefill, + end_seq_prefill, max_s, max_s, 0.0, @@ -187,32 +203,19 @@ class FlashRWAttention(torch.nn.Module): ) # Decode else: - # Add present to the layer_past tensor at the correct indices - layer_past[past_present_indices] = kv - # Expand to query shape - kv = layer_past.expand(-1, 2, self.num_heads, self.head_size) - - # output - attn_output = torch.empty_like(query) - # flash attention - flash_attn_cuda.fwd( - query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), + # kv_cache[1] => [num_blocks, num_heads_kv, head_size, block_size] + block_size = kv_cache[1].shape[3] + vllm_attention_ops.single_query_cached_kv_attention( attn_output, - start_seq_q, - end_seq_q, - start_seq, - end_seq, - 1, - max_s, - 0.0, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, self.softmax_scale, - False, - False, - False, - 0, - None, + block_tables, + input_lengths, + block_size, + max_s, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -264,19 +267,22 @@ class FlashRWLargeAttention(torch.nn.Module): config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias ) + self.kv_head_mapping = torch.arange( + 0, self.num_groups, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_heads) + def forward( self, hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size) @@ -293,10 +299,19 @@ class FlashRWLargeAttention(torch.nn.Module): self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin) + vllm_cache_ops.reshape_and_cache( + kv[:, :, 0].contiguous(), + kv[:, :, 1].contiguous(), + kv_cache[0], + kv_cache[1], + slots, + ) + + # output + attn_output = torch.empty_like(query) + # Prefill - if prefill: - # Copy to layer past - layer_past[...] = kv + if start_seq_prefill is not None: # Expand to query shape kv = ( kv.unsqueeze(2) @@ -304,18 +319,16 @@ class FlashRWLargeAttention(torch.nn.Module): .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size) ) - # output - attn_output = torch.empty_like(query) # flash attention flash_attn_cuda.fwd( query, torch.select(kv, dim=2, index=0), torch.select(kv, dim=2, index=1), attn_output, - start_seq, - end_seq, - start_seq, - end_seq, + start_seq_prefill, + end_seq_prefill, + start_seq_prefill, + end_seq_prefill, max_s, max_s, 0.0, @@ -328,36 +341,19 @@ class FlashRWLargeAttention(torch.nn.Module): ) # Decode else: - # Add present to the layer_past tensor at the correct indices - layer_past[past_present_indices] = kv - # Expand to query shape - kv = ( - layer_past.unsqueeze(2) - .expand(-1, self.num_groups, self.num_heads, 2, self.head_size) - .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size) - ) - - # output - attn_output = torch.empty_like(query) - # flash attention - flash_attn_cuda.fwd( - query, - torch.select(kv, dim=2, index=0), - torch.select(kv, dim=2, index=1), + # kv_cache[1] => [num_blocks, num_groups, head_size, block_size] + block_size = kv_cache[1].shape[3] + vllm_attention_ops.single_query_cached_kv_attention( attn_output, - start_seq_q, - end_seq_q, - start_seq, - end_seq, - 1, - max_s, - 0.0, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, self.softmax_scale, - False, - False, - False, - 0, - None, + block_tables, + input_lengths, + block_size, + max_s, ) return self.dense( @@ -432,14 +428,13 @@ class FlashRWLayer(nn.Module): residual, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): if self.parallel_attn: ln_hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -448,14 +443,13 @@ class FlashRWLayer(nn.Module): ln_hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ) mlp_output = self.mlp(ln_hidden_states) @@ -472,14 +466,13 @@ class FlashRWLayer(nn.Module): hidden_states, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ) hidden_states, residual = self.post_attention_layernorm( @@ -523,14 +516,13 @@ class FlashRWLargeLayer(nn.Module): residual, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): ln_attn, residual = self.ln_attn(hidden_states, residual) ln_mlp, _ = self.ln_mlp(residual) @@ -540,14 +532,13 @@ class FlashRWLargeLayer(nn.Module): ln_attn, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ) # MLP. @@ -580,11 +571,7 @@ class FlashRWModel(FlashRWPreTrainedModel): for layer_id in range(config.num_hidden_layers) ] ) - self.cache_size = ( - 2, - self.h[0].self_attention.num_heads_kv, - self.h[0].self_attention.head_size, - ) + self.cache_size = self.h[0].self_attention.num_heads_kv elif config.model_type == "RefinedWeb": self.h = nn.ModuleList( [ @@ -592,11 +579,7 @@ class FlashRWModel(FlashRWPreTrainedModel): for layer_id in range(config.num_hidden_layers) ] ) - self.cache_size = ( - self.h[0].self_attention.num_groups, - 2, - self.h[0].self_attention.head_size, - ) + self.cache_size = self.h[0].self_attention.num_groups else: raise NotImplementedError( f"model_type {config.model_type} is not supported." @@ -612,38 +595,18 @@ class FlashRWModel(FlashRWPreTrainedModel): def forward( self, - input_ids, - position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, - max_s, - past_present_indices, - past_key_values=None, - pre_allocate_past_size: Optional[int] = None, - ): + input_ids: torch.Tensor, + position_ids: torch.Tensor, + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: hidden_states = self.word_embeddings(input_ids) - # Prefill - if past_key_values is None: - assert pre_allocate_past_size is not None - - prefill = True - - # Create past tensor - # We create a tensor of the same size as input_ids as we don't want to slice at every layer - past_key_values = hidden_states.new_empty( - ( - len(input_ids), - len(self.h), - *self.cache_size, - ) - ) - # Decode - else: - prefill = False - # Get rotary cos and sin for this forward # Avoid to index in each layer cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin( @@ -657,32 +620,18 @@ class FlashRWModel(FlashRWPreTrainedModel): residual, cos, sin, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, max_s, - torch.select(past_key_values, dim=1, index=i), - past_present_indices, - prefill, ) - if prefill: - present = past_key_values - # Create padded past tensor - past_key_values = hidden_states.new_empty( - ( - pre_allocate_past_size, - len(self.h), - *self.cache_size, - ) - ) - # We slice only once instead of at every layer - past_key_values[past_present_indices] = present - hidden_states, _ = self.ln_f(hidden_states, residual) - return hidden_states, past_key_values + return hidden_states class FlashRWForCausalLM(FlashRWPreTrainedModel): @@ -697,31 +646,29 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): def forward( self, - input_ids, - position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, - max_s, - past_present_indices, - past_key_values: Optional[torch.Tensor] = None, - pre_allocate_past_size: Optional[int] = None, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, lm_head_indices: Optional[torch.Tensor] = None, - ): - hidden_states, present = self.transformer( + ) -> torch.Tensor: + hidden_states = self.transformer( input_ids, position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - past_present_indices, - past_key_values, - pre_allocate_past_size, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states) - return logits, present + return logits diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 4eb0034d..04eedef7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -3,11 +3,15 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN -from typing import Optional +from typing import Optional, List, Tuple # Flash attention imports import flash_attn_cuda +# vllm imports +import vllm_cache_ops +import vllm_attention_ops + from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -221,18 +225,20 @@ class FlashMQAttention(torch.nn.Module): self.c_proj = load_row( config, prefix=f"{prefix}.c_proj", weights=weights, bias=True ) + self.kv_head_mapping = torch.zeros( + self.num_heads, dtype=torch.int32, device=weights.device + ) def forward( self, hidden_states, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): qkv = self.c_attn(hidden_states) @@ -245,25 +251,28 @@ class FlashMQAttention(torch.nn.Module): query = query.view(-1, self.num_heads, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size) + vllm_cache_ops.reshape_and_cache( + key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # output + attn_output = torch.empty_like(query) + # Prefill - if prefill: - # Copy to layer past - layer_past[...] = key_value + if start_seq_prefill is not None: # Expand from 1 to num_heads key_value = key_value.expand(-1, 2, self.num_heads, self.head_size) - # output - attn_output = torch.empty_like(query) # flash attention flash_attn_cuda.fwd( query, torch.select(key_value, dim=1, index=0), torch.select(key_value, dim=1, index=1), attn_output, - start_seq, - end_seq, - start_seq, - end_seq, + start_seq_prefill, + end_seq_prefill, + start_seq_prefill, + end_seq_prefill, max_s, max_s, 0.0, @@ -276,32 +285,19 @@ class FlashMQAttention(torch.nn.Module): ) # Decode else: - # Add present to the layer_past tensor at the correct indices - layer_past[past_present_indices] = key_value - # Expand from 1 to num_heads - key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size) - - # output - attn_output = torch.empty_like(query) - # flash attention - flash_attn_cuda.fwd( - query, - torch.select(key_value, dim=1, index=0), - torch.select(key_value, dim=1, index=1), + # kv_cache[1] => [num_blocks, 1, head_size, block_size] + block_size = kv_cache[1].shape[3] + vllm_attention_ops.single_query_cached_kv_attention( attn_output, - start_seq_q, - end_seq_q, - start_seq, - end_seq, - 1, - max_s, - 0.0, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, self.softmax_scale, - False, - False, - False, - 0, - None, + block_tables, + input_lengths, + block_size, + max_s, ) return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -361,27 +357,25 @@ class Block(nn.Module): self, hidden_states, residual, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ): hidden_states, residual = self.ln_1(hidden_states, residual) hidden_states = self.attn( hidden_states, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - layer_past, - past_present_indices, - prefill, ) hidden_states, residual = self.ln_2(hidden_states, residual) @@ -427,64 +421,38 @@ class FlashSantacoderModel(nn.Module): def forward( self, - input_ids, - position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, - max_s, - past_present_indices, - past_key_values=None, - pre_allocate_past_size: Optional[int] = None, - ): + input_ids: torch.Tensor, + position_ids: torch.Tensor, + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: hidden_states = self.wte(input_ids) + self.wpe(position_ids) if self.process_group.size() > 1: torch.distributed.all_reduce(hidden_states, group=self.process_group) - # Prefill - if past_key_values is None: - assert pre_allocate_past_size is not None - - prefill = True - - # Create past tensor - # We create a tensor of the same size as input_ids as we don't want to slice at every layer - past_key_values = hidden_states.new_zeros( - (len(input_ids), len(self.h), 2, 1, self.head_size) - ) - # Decode - else: - prefill = False - residual = None for i, layer in enumerate(self.h): hidden_states, residual = layer( hidden_states, residual, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, max_s, - torch.select(past_key_values, dim=1, index=i), - past_present_indices, - prefill, ) - if prefill: - present = past_key_values - # Create padded past tensor - past_key_values = hidden_states.new_empty( - (pre_allocate_past_size, len(self.h), 2, 1, self.head_size) - ) - # We slice only once instead of at every layer - past_key_values[past_present_indices] = present - hidden_states, _ = self.ln_f(hidden_states, residual) - return hidden_states, past_key_values + return hidden_states class FlashSantacoderForCausalLM(nn.Module): @@ -497,31 +465,29 @@ class FlashSantacoderForCausalLM(nn.Module): def forward( self, - input_ids, - position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, - max_s, - past_present_indices, - past_key_values: Optional[torch.Tensor] = None, - pre_allocate_past_size: Optional[int] = None, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, lm_head_indices: Optional[torch.Tensor] = None, - ): - hidden_states, present = self.transformer( + ) -> torch.Tensor: + hidden_states = self.transformer( input_ids, position_ids, - start_seq, - end_seq, - start_seq_q, - end_seq_q, + start_seq_prefill, + end_seq_prefill, + kv_cache, + block_tables, + slots, + input_lengths, max_s, - past_present_indices, - past_key_values, - pre_allocate_past_size, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states) - return logits, present + return logits diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py index 12679e9d..19deca86 100644 --- a/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py @@ -1004,7 +1004,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel): try: self.shared = TensorParallelEmbedding(prefix="shared", weights=weights) except RuntimeError: - self.shared = TensorParallelEmbedding(prefix="encoder.embed_tokens", weights=weights) + self.shared = TensorParallelEmbedding( + prefix="encoder.embed_tokens", weights=weights + ) encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index ecea998e..94b14f85 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1,11 +1,14 @@ +import math +import itertools import torch import torch.distributed import numpy as np from dataclasses import dataclass +from loguru import logger from opentelemetry import trace -from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel +from transformers import PreTrainedTokenizerBase from typing import Optional, Tuple, List, Type, Union, Dict from text_generation_server.models import Model @@ -20,6 +23,92 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke tracer = trace.get_tracer(__name__) +BLOCK_SIZE = 16 +# Will be set in warmup +CACHE_MANAGER: Optional["CacheManager"] = None + + +class CacheManager: + def __init__( + self, + num_blocks: int, + num_layers: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + ): + self.block_size = BLOCK_SIZE + + element_size = torch.tensor([], dtype=dtype).element_size() + x = self.block_size // element_size + + self.kv_cache = [ + ( + torch.empty( + (num_blocks, num_heads, head_size // x, self.block_size, x), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, num_heads, head_size, self.block_size), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] + self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu") + self.slots = torch.arange( + 0, num_blocks * self.block_size, dtype=torch.int32 + ).view(num_blocks, self.block_size) + + def allocate(self, batch: "FlashCausalLMBatch"): + # Get free blocks indices by finding values in mask that are not set to 0 + free_block_indices = self.free_block_mask.nonzero() + assert ( + len(free_block_indices) >= batch.blocks + ), f"Out of available cache blocks: asked {batch.blocks}, only {len(free_block_indices)} free blocks" + + # Slice by the number of required blocks + block_indices = free_block_indices[: batch.blocks] + block_indices = block_indices.flatten() + + # Padded block tables + block_tables_tensor = torch.zeros( + (len(batch), batch.max_blocks), dtype=torch.int32 + ) + + # Allocate paged attention blocks + cumulative_blocks = 0 + slots = [] + block_tables = [] + for i, (needed_blocks, needed_slots) in enumerate(batch.needed_blocks_slots): + # Get allocated blocks for this sequence + allocated_blocks = block_indices[ + cumulative_blocks : cumulative_blocks + needed_blocks + ] + # Get slots for the allocated blocks + allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots] + + slots.append(allocated_slots) + block_tables.append(allocated_blocks.tolist()) + block_tables_tensor[i, :needed_blocks] = allocated_blocks + cumulative_blocks += needed_blocks + + batch.needed_blocks_slots = None + batch.block_tables = block_tables + batch.block_tables_tensor = block_tables_tensor.to(batch.input_ids.device) + batch.slots = torch.concat(slots).to(batch.input_ids.device) + + # Allocate the required number of blocks by setting the mask to 0 + self.free_block_mask[block_indices] = 0 + + def free(self, block_indices: Optional[List[int]]): + if block_indices is not None and block_indices: + # Reset mask + self.free_block_mask[block_indices] = 1 + @dataclass class FlashCausalLMBatch(Batch): @@ -32,23 +121,29 @@ class FlashCausalLMBatch(Batch): input_ids: torch.Tensor position_ids: torch.Tensor - # Indices to copy present to the correct indices is the pre-allocated past key values - past_present_indices: torch.Tensor - - # tensor of length b holding starting offset of each sequence - start_seq: torch.Tensor - # tensor of length b holding ending offset of each sequence - end_seq: torch.Tensor # tensor of length b holding starting offset of each sequence, only used in prefill start_seq_prefill: Optional[torch.Tensor] # tensor of length b holding ending offset of each sequence, only used in prefill end_seq_prefill: Optional[torch.Tensor] - # tensor of length b holding starting offset of each query sequence, only used in decode - start_seq_q: Optional[torch.Tensor] - # tensor of length b holding ending offset of each query sequence, only used in decode - end_seq_q: Optional[torch.Tensor] - # past key values, only used in decode - past_key_values: Optional[torch.Tensor] + + # Paged Attention values + + # Set when creating the batch + # CPU tensor of length b indicating the start of each sequence in slots + start_slots: torch.Tensor + # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode + slot_indices: torch.Tensor + # List of tuple of ints representing the number of blocks and slots needed by each sequence + needed_blocks_slots: Optional[List[Tuple[int, int]]] + + # Set in prefill by the CacheManager + # list of length b of list of length s_i // block_size + block_tables: Optional[List[List[int]]] + # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences + block_tables_tensor: Optional[torch.Tensor] + # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences + slots: Optional[torch.Tensor] + max_seqlen: int # Prefill metadata tensors to efficiently compute logprobs @@ -62,6 +157,7 @@ class FlashCausalLMBatch(Batch): # Lengths of all generations present in the batch input_lengths: List[int] + input_lengths_tensor: torch.Tensor prefix_offsets: List[Optional[int]] read_offsets: List[Optional[int]] @@ -69,15 +165,17 @@ class FlashCausalLMBatch(Batch): next_token_chooser: HeterogeneousNextTokenChooser stopping_criterias: List[StoppingCriteria] - # Maximum number of tokens this batch will grow to - max_tokens: int + # Number of blocks in this batch + blocks: int + # Maximum number of blocks + max_blocks: int def to_pb(self) -> generate_pb2.CachedBatch: return generate_pb2.CachedBatch( id=self.batch_id, request_ids=[r.id for r in self.requests], size=len(self), - max_tokens=self.max_tokens, + max_tokens=self.blocks * BLOCK_SIZE, ) @classmethod @@ -99,12 +197,11 @@ class FlashCausalLMBatch(Batch): )["input_ids"] position_ids = [] - past_present_indices = [] - start_seq = [] - end_seq = [] start_seq_prefill = [] end_seq_prefill = [] - max_seqlen = 0 + needed_blocks_slots = [] + start_slots = [] + slot_indices = [] input_lengths = [] prefix_offsets = [] @@ -126,7 +223,10 @@ class FlashCausalLMBatch(Batch): cumulative_max_length = 0 prefill_out_cumulative_length = 0 + blocks = 0 + max_seqlen = 0 max_length = 0 + max_blocks = 0 # Parse batch for i, (r, tokenized_input) in enumerate( @@ -138,7 +238,6 @@ class FlashCausalLMBatch(Batch): tokenized_input = tokenized_input[-r.truncate :] input_length = len(tokenized_input) - max_seqlen = max(max_seqlen, input_length) input_lengths.append(input_length) prefix_offsets.append(input_length - 5) @@ -153,8 +252,6 @@ class FlashCausalLMBatch(Batch): # Add cumulative lengths of all previous inputs start_seq_prefill.append(cumulative_length) end_seq_prefill.append(cumulative_length + input_length) - start_seq.append(cumulative_max_length) - end_seq.append(cumulative_max_length + input_length) next_token_chooser_parameters.append(r.parameters) @@ -164,6 +261,21 @@ class FlashCausalLMBatch(Batch): max_new_tokens = stopping_criteria.max_new_tokens stopping_criterias.append(stopping_criteria) + # Paged attention + # Remove one as the first token des not have a past + total_tokens = input_length + max_new_tokens - 1 + needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) + blocks += needed_blocks + needed_blocks_slots.append((needed_blocks, total_tokens)) + start_slots.append(cumulative_max_length) + + request_slot_indices = torch.arange( + cumulative_max_length, + cumulative_max_length + input_length, + dtype=torch.int64, + ) + slot_indices.append(request_slot_indices) + all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs @@ -184,22 +296,17 @@ class FlashCausalLMBatch(Batch): prefill_cu_outlens.append(prefill_out_cumulative_length + 1) prefill_out_cumulative_length += 1 - request_past_present_indices = torch.arange( - cumulative_max_length, - cumulative_max_length + input_length, - dtype=torch.int64, - ) - past_present_indices.append(request_past_present_indices) - # Update - # Remove one as the first token des not have a past cumulative_length += input_length - cumulative_max_length += input_length + max_new_tokens - 1 + cumulative_max_length += total_tokens + max_seqlen = max(max_seqlen, input_length) + max_blocks = max(max_blocks, needed_blocks) max_length = max(max_length, input_length + max_new_tokens) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, dtype, device ) + start_slots = torch.tensor(start_slots, dtype=torch.int64) # Padded all_input_ids_tensor all_input_ids_tensor = np.zeros( @@ -212,34 +319,28 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor = torch.tensor( all_input_ids_tensor, dtype=torch.int64, device=device ) - start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32) - end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32) if len(pb.requests) > 1: input_ids = np.concatenate(all_input_ids, dtype=np.int64) position_ids = torch.cat(position_ids) - - past_present_indices = np.concatenate(past_present_indices, dtype=np.int64) - - start_seq_prefill = torch.tensor( - start_seq_prefill, device=device, dtype=torch.int32 - ) - end_seq_prefill = torch.tensor( - end_seq_prefill, device=device, dtype=torch.int32 - ) + slot_indices = torch.cat(slot_indices) else: input_ids = all_input_ids[0] position_ids = position_ids[0] + slot_indices = slot_indices[0] - past_present_indices = past_present_indices[0] - - start_seq_prefill = start_seq - end_seq_prefill = end_seq + start_seq_prefill = torch.tensor( + start_seq_prefill, device=device, dtype=torch.int32 + ) + end_seq_prefill = torch.tensor( + end_seq_prefill, device=device, dtype=torch.int32 + ) + position_ids = position_ids.to(device) + slot_indices = slot_indices.to(device) input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) - position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device) - past_present_indices = torch.tensor( - past_present_indices, device=device, dtype=torch.int64 + input_lengths_tensor = torch.tensor( + input_lengths, dtype=torch.int32, device=device ) if all_prefill_logprobs: @@ -262,26 +363,28 @@ class FlashCausalLMBatch(Batch): requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, position_ids=position_ids, - past_present_indices=past_present_indices, - start_seq=start_seq, - end_seq=end_seq, start_seq_prefill=start_seq_prefill, end_seq_prefill=end_seq_prefill, - start_seq_q=None, - end_seq_q=None, + start_slots=start_slots, + slot_indices=slot_indices, + needed_blocks_slots=needed_blocks_slots, + block_tables=None, + block_tables_tensor=None, + slots=None, max_seqlen=max_seqlen, prefill_head_indices=prefill_head_indices, prefill_next_token_indices=prefill_next_token_indices, prefill_cu_outlens=prefill_cu_outlens, - past_key_values=None, input_lengths=input_lengths, + input_lengths_tensor=input_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, - max_tokens=cumulative_max_length, + blocks=blocks, + max_blocks=max_blocks, ) @tracer.start_as_current_span("filter") @@ -294,28 +397,24 @@ class FlashCausalLMBatch(Batch): device = self.input_ids.device - # Cumulative length - cumulative_max_length = 0 - # New values after filtering requests_idx_mapping = {} # Used to index into tensors indices = [] - # past indices to keep - past_indices = torch.zeros( - self.past_key_values.shape[0], dtype=torch.bool, device=device + # slots to keep after filtering + slot_filtering_indices = torch.zeros( + self.slots.shape[0], dtype=torch.bool, device=device ) # Create on CPU to only move to GPU once instead of at every copy - start_seq = torch.empty(len(request_ids), dtype=torch.int32) - end_seq = torch.empty(len(request_ids), dtype=torch.int32) - start_seq_q = self.start_seq_q[: len(request_ids)] - end_seq_q = self.end_seq_q[: len(request_ids)] + slot_indices = torch.empty(len(request_ids), dtype=torch.int64) max_seqlen = 0 requests = [] + start_slots = [] + block_tables = [] all_input_ids = [] input_lengths = [] @@ -324,6 +423,11 @@ class FlashCausalLMBatch(Batch): stopping_criterias = [] + blocks = 0 + max_blocks = 0 + # Cumulative length + cumulative_max_length = 0 + for i, request_id in enumerate(request_ids): idx = self.requests_idx_mapping[request_id] indices.append(idx) @@ -348,28 +452,51 @@ class FlashCausalLMBatch(Batch): stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) + request_block_table = self.block_tables[idx] + blocks += len(request_block_table) + block_tables.append(request_block_table) + start_slots.append(cumulative_max_length) + # Copy to tensor (CPU) - start_seq[i] = cumulative_max_length - end_seq[i] = cumulative_max_length + request_input_length + slot_indices[i] = cumulative_max_length + request_input_length - 1 # Set slice - past_indices[ - self.start_seq[idx] : self.end_seq[idx] + remaining_tokens - 1 + slot_filtering_indices[ + self.start_slots[idx] : self.start_slots[idx] + + request_input_length + + remaining_tokens + - 1 ] = True cumulative_max_length += request_input_length + remaining_tokens - 1 + max_blocks = max(max_blocks, len(request_block_table)) + + global CACHE_MANAGER + block_indices_to_free = [] + # Iterate on all requests + for i, r in enumerate(self.requests): + # Filter requests that are not part of the new batch + if r.id not in requests_idx_mapping.keys(): + block_indices_to_free.extend(self.block_tables[i]) + # Free blocks + CACHE_MANAGER.free(block_indices_to_free) + # Needed to avoid dropping blocks when the batches will go out of scope + self.block_tables = None + # Index into tensors input_ids = self.input_ids[indices] position_ids = self.position_ids[indices] all_input_ids_tensor = self.all_input_ids_tensor[indices] + block_tables_tensor = self.block_tables_tensor[indices] + input_lengths_tensor = self.input_lengths_tensor[indices] + slots = self.slots[slot_filtering_indices] next_token_chooser = self.next_token_chooser.filter(indices) - past_key_values = self.past_key_values[past_indices] + + start_slots = torch.tensor(start_slots, dtype=torch.int64) # Move to GPU now that we have the whole tensor - start_seq = start_seq.to(device) - end_seq = end_seq.to(device) - past_present_indices = end_seq - 1 + slot_indices = slot_indices.to(device) return FlashCausalLMBatch( batch_id=self.batch_id, @@ -377,26 +504,28 @@ class FlashCausalLMBatch(Batch): requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, position_ids=position_ids, - past_present_indices=past_present_indices, - start_seq=start_seq, - end_seq=end_seq, start_seq_prefill=None, end_seq_prefill=None, - start_seq_q=start_seq_q, - end_seq_q=end_seq_q, + start_slots=start_slots, + slot_indices=slot_indices, + needed_blocks_slots=None, + block_tables=block_tables, + block_tables_tensor=block_tables_tensor, + slots=slots, max_seqlen=max_seqlen, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, - past_key_values=past_key_values, input_lengths=input_lengths, + input_lengths_tensor=input_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, - max_tokens=cumulative_max_length, + blocks=blocks, + max_blocks=max_blocks, ) @classmethod @@ -406,22 +535,46 @@ class FlashCausalLMBatch(Batch): requests = [] requests_idx_mapping = {} - total_batch_size = sum([len(b) for b in batches]) - - dtype = batches[0].past_key_values.dtype - device = batches[0].input_ids.device + blocks = 0 + total_batch_size = 0 + total_slots = 0 + max_blocks = 0 + max_length = 0 + max_seqlen = 0 + for b in batches: + total_batch_size += len(b) + total_slots += len(b.slots) + blocks += b.blocks + max_blocks = max(max_blocks, b.max_blocks) + max_seqlen = max(max_seqlen, b.max_seqlen) + max_length = max( + max_length, + max( + input_length + + stopping_criteria.max_new_tokens + - stopping_criteria.current_tokens + for input_length, stopping_criteria in zip( + b.input_lengths, b.stopping_criterias + ) + ), + ) input_ids = batches[0].input_ids.new_empty(total_batch_size) position_ids = batches[0].position_ids.new_empty(total_batch_size) - start_seq = batches[0].start_seq.new_empty(total_batch_size) - end_seq = batches[0].end_seq.new_empty(total_batch_size) - start_seq_q = torch.arange( - 0, total_batch_size, device=device, dtype=torch.int32 + slots = batches[0].slots.new_empty(total_slots) + slot_indices = batches[0].slot_indices.new_empty(total_batch_size) + input_lengths_tensor = batches[0].input_lengths_tensor.new_empty( + total_batch_size + ) + block_tables_tensor = batches[0].block_tables_tensor.new_zeros( + (total_batch_size, max_blocks) + ) + all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( + (total_batch_size, max_length) ) - end_seq_q = start_seq_q + 1 - max_seqlen = 0 - past_key_values = [] + start_slots = [] + block_tables = [] all_input_ids = [] input_lengths = [] @@ -433,8 +586,7 @@ class FlashCausalLMBatch(Batch): # Cumulative length cumulative_batch_size = 0 - max_tokens = 0 - max_length = 0 + cumulative_slots = 0 for i, batch in enumerate(batches): requests.extend(batch.requests) @@ -448,16 +600,27 @@ class FlashCausalLMBatch(Batch): start_index = cumulative_batch_size end_index = cumulative_batch_size + len(batch) + slots_start_index = cumulative_slots + slots_end_index = cumulative_slots + len(batch.slots) # Copy tensors (GPU) input_ids[start_index:end_index] = batch.input_ids position_ids[start_index:end_index] = batch.position_ids + slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots + input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor + slots[slots_start_index:slots_end_index] = batch.slots - start_seq[start_index:end_index] = batch.start_seq + max_tokens - end_seq[start_index:end_index] = batch.end_seq + max_tokens + all_input_ids_tensor[ + start_index:end_index, : batch.all_input_ids_tensor.shape[1] + ] = batch.all_input_ids_tensor[:, :max_length] - max_seqlen = max(max_seqlen, batch.max_seqlen) + block_tables_tensor[ + start_index:end_index, : batch.block_tables_tensor.shape[1] + ] = batch.block_tables_tensor[:, :max_blocks] + start_slots.append(batch.start_slots + cumulative_slots) + + block_tables.extend(batch.block_tables) all_input_ids.extend(batch.all_input_ids) input_lengths.extend(batch.input_lengths) @@ -466,73 +629,59 @@ class FlashCausalLMBatch(Batch): next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) stopping_criterias.extend(batch.stopping_criterias) - past_key_values.append(batch.past_key_values) # Update cumulative_batch_size += len(batch) - max_tokens += batch.max_tokens - max_length = max( - max_length, - max( - input_length - + stopping_criteria.max_new_tokens - - stopping_criteria.current_tokens - for input_length, stopping_criteria in zip( - batch.input_lengths, batch.stopping_criterias - ) - ), - ) + cumulative_slots += len(batch.slots) - past_key_values = torch.cat(past_key_values, dim=0) - past_present_indices = end_seq - 1 - - all_input_ids_tensor = torch.zeros( - (total_batch_size, max_length), dtype=torch.int64, device=device - ) - - cumulative_batch_size = 0 - for i, batch in enumerate(batches): - start_index = cumulative_batch_size - end_index = cumulative_batch_size + len(batch) - - all_input_ids_tensor[ - start_index:end_index, : batch.all_input_ids_tensor.shape[1] - ] = batch.all_input_ids_tensor[:, :max_length] - - cumulative_batch_size += len(batch) + start_slots = torch.concat(start_slots) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - next_token_chooser_parameters, dtype=dtype, device=device + next_token_chooser_parameters, + dtype=batches[0].next_token_chooser.dtype, + device=batches[0].next_token_chooser.device, ) + # Needed to avoid dropping blocks when the batches will go out of scope + for b in batches: + b.block_tables = None + return FlashCausalLMBatch( batch_id=batches[0].batch_id, requests=requests, requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, position_ids=position_ids, - past_present_indices=past_present_indices, - start_seq=start_seq, - end_seq=end_seq, start_seq_prefill=None, end_seq_prefill=None, - start_seq_q=start_seq_q, - end_seq_q=end_seq_q, + start_slots=start_slots, + slot_indices=slot_indices, + needed_blocks_slots=None, + block_tables=block_tables, + block_tables_tensor=block_tables_tensor, + slots=slots, max_seqlen=max_seqlen, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, - past_key_values=past_key_values, input_lengths=input_lengths, + input_lengths_tensor=input_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, - max_tokens=max_tokens, + blocks=blocks, + max_blocks=max_blocks, ) + def __del__(self): + if self.block_tables is not None and self.block_tables: + global CACHE_MANAGER + # Free blocks + CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables))) + def __len__(self): return len(self.requests) @@ -540,32 +689,19 @@ class FlashCausalLMBatch(Batch): class FlashCausalLM(Model): def __init__( self, - model_cls: Type[PreTrainedModel], - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - trust_remote_code: bool = False, + model: torch.nn.Module, + tokenizer: PreTrainedTokenizerBase, + num_layers: int, + num_kv_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + rank: int = 0, + world_size: int = 1, ): - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.float16 - else: - raise NotImplementedError("FlashCausalLM is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - model = model_cls.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - load_in_8bit=quantize == "bitsandbytes", - trust_remote_code=trust_remote_code, - ).to(device) + self.num_layers = num_layers + self.num_kv_heads = num_kv_heads + self.head_size = head_size super(FlashCausalLM, self).__init__( model=model, @@ -573,12 +709,38 @@ class FlashCausalLM(Model): requires_padding=False, dtype=dtype, device=device, + rank=rank, + world_size=world_size, ) @property def batch_type(self) -> Type[FlashCausalLMBatch]: return FlashCausalLMBatch + def warmup(self, batch: FlashCausalLMBatch, max_total_tokens: int): + global CACHE_MANAGER + + torch.cuda.empty_cache() + try: + CACHE_MANAGER = CacheManager( + # Adds some wiggle room + math.ceil(max_total_tokens / BLOCK_SIZE) + 10, + self.num_layers, + self.num_kv_heads, + self.head_size, + self.dtype, + self.device, + ) + _, batch = self.generate_token(batch) + except Exception as e: + logger.exception( + f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} " + f"prefill tokens. " + f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`" + ) + raise e + del batch + def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: return self.tokenizer.decode( generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False @@ -588,28 +750,27 @@ class FlashCausalLM(Model): self, input_ids: torch.Tensor, position_ids: torch.Tensor, - start_seq: torch.Tensor, - end_seq: torch.Tensor, - start_seq_q: Optional[torch.Tensor], - end_seq_q: Optional[torch.Tensor], + start_seq_prefill: Optional[torch.Tensor], + end_seq_prefill: Optional[torch.Tensor], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, max_s: int, - past_present_indices: torch.Tensor, - past_key_values: Optional = None, - pre_allocate_past_size: Optional[int] = None, lm_head_indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: + global CACHE_MANAGER + # Model Forward return self.model.forward( input_ids=input_ids, position_ids=position_ids, - start_seq=start_seq, - end_seq=end_seq, - start_seq_q=start_seq_q, - end_seq_q=end_seq_q, + start_seq_prefill=start_seq_prefill, + end_seq_prefill=end_seq_prefill, + kv_cache=CACHE_MANAGER.kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, max_s=max_s, - past_present_indices=past_present_indices, - past_key_values=past_key_values, - pre_allocate_past_size=pre_allocate_past_size, lm_head_indices=lm_head_indices, ) @@ -617,31 +778,22 @@ class FlashCausalLM(Model): def generate_token( self, batch: FlashCausalLMBatch ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]: - prefill = batch.past_key_values is None + prefill = batch.start_seq_prefill is not None prefill_logprobs = batch.prefill_next_token_indices is not None - if prefill: - # Ask to pre-allocate kv to its max size - # == Sum over batch size (number of tokens + max_new_tokens) - batch size - pre_allocate_past_size = batch.max_tokens - start_seq = batch.start_seq_prefill - end_seq = batch.end_seq_prefill - else: - pre_allocate_past_size = None - start_seq = batch.start_seq - end_seq = batch.end_seq + if batch.needed_blocks_slots: + # Allocate blocks to this batch + CACHE_MANAGER.allocate(batch) - out, present = self.forward( + out = self.forward( batch.input_ids, batch.position_ids, - start_seq, - end_seq, - batch.start_seq_q, - batch.end_seq_q, + batch.start_seq_prefill, + batch.end_seq_prefill, + batch.block_tables_tensor, + batch.slots[batch.slot_indices], + batch.input_lengths_tensor, batch.max_seqlen, - batch.past_present_indices, - batch.past_key_values, - pre_allocate_past_size, batch.prefill_head_indices, ) @@ -662,12 +814,8 @@ class FlashCausalLM(Model): # When batch == 1, we will just use the batch.input_ids values directly prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) - # Create batch.start_seq_q and batch.end_seq_q for decode - batch.start_seq_q = torch.arange( - 0, len(batch), device=self.device, dtype=torch.int32 - ) - batch.end_seq_q = batch.start_seq_q + 1 next_position_ids = batch.position_ids.new_empty(len(batch)) + batch.slot_indices = batch.slot_indices[batch.end_seq_prefill - 1] # We do not need start_seq_prefill and end_seq_prefill anymore batch.start_seq_prefill = None batch.end_seq_prefill = None @@ -731,8 +879,8 @@ class FlashCausalLM(Model): # Set values in batch batch.input_ids = next_input_ids batch.position_ids = next_position_ids + 1 - batch.past_present_indices = batch.end_seq - batch.end_seq = batch.end_seq + 1 + batch.input_lengths_tensor += 1 + batch.slot_indices += 1 if prefill and prefill_logprobs: # Get prefill logprobs @@ -755,7 +903,6 @@ class FlashCausalLM(Model): batch.read_offsets, batch.stopping_criterias, batch.all_input_ids, - batch.all_input_ids_tensor, batch.next_token_chooser.do_sample, batch.next_token_chooser.seeds, next_token_ids, @@ -770,7 +917,6 @@ class FlashCausalLM(Model): read_offset, stopping_criteria, all_input_ids, - all_input_ids_tensor, do_sample, seed, next_token_id, @@ -845,19 +991,20 @@ class FlashCausalLM(Model): generations.append(generation) - new_input_length = input_length + 1 - # Update values - batch.input_lengths[i] = new_input_length + batch.input_lengths[i] = input_length + 1 batch.prefix_offsets[i] = prefix_offset batch.read_offsets[i] = read_offset batch.all_input_ids[i] = all_input_ids + if stopped: + del batch + # No need to return a batch if we know that all requests stopped + return generations, None + batch.prefill_cu_outlens = None batch.prefill_head_indices = None batch.prefill_next_token_indices = None batch.max_seqlen = batch.max_seqlen + 1 - batch.past_key_values = present - # No need to return a batch if we know that all requests stopped - return generations, batch if not stopped else None + return generations, batch diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index a80d58cb..2c59f01e 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -64,10 +64,12 @@ class FlashLlama(FlashCausalLM): model = FlashLlamaForCausalLM(config, weights) torch.distributed.barrier(group=self.process_group) - super(FlashCausalLM, self).__init__( + super(FlashLlama, self).__init__( model=model, tokenizer=tokenizer, - requires_padding=False, + num_layers=len(model.model.layers), + num_kv_heads=model.model.num_heads, + head_size=model.model.head_size, dtype=dtype, device=device, rank=rank, diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index 4847571d..e64af0c6 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -55,10 +55,12 @@ class FlashNeoXSharded(FlashCausalLM): model = FlashGPTNeoXForCausalLM(config, weights) torch.distributed.barrier(group=self.process_group) - super(FlashCausalLM, self).__init__( + super(FlashNeoXSharded, self).__init__( model=model.to(device), tokenizer=tokenizer, - requires_padding=False, + num_layers=len(model.gpt_neox.layers), + num_kv_heads=model.gpt_neox.num_heads, + head_size=model.gpt_neox.head_size, dtype=dtype, device=device, rank=rank, diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index 5f963bfb..a55f9118 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -55,10 +55,12 @@ class FlashRWSharded(FlashCausalLM): model = FlashRWForCausalLM(config, weights) torch.distributed.barrier(group=self.process_group) - super(FlashCausalLM, self).__init__( + super(FlashRWSharded, self).__init__( model=model.to(device), tokenizer=tokenizer, - requires_padding=False, + num_layers=len(model.transformer.h), + num_kv_heads=model.transformer.cache_size, + head_size=model.transformer.head_size, dtype=dtype, device=device, rank=rank, diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index a71c0061..ef202785 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -52,17 +52,22 @@ class FlashSantacoderSharded(FlashCausalLM): torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group, - aliases = {"transformer.wte.weight": ["lm_head.weight"]} + filenames, + device=device, + dtype=dtype, + process_group=self.process_group, + aliases={"transformer.wte.weight": ["lm_head.weight"]}, ) model = FlashSantacoderForCausalLM(config, weights) torch.distributed.barrier(group=self.process_group) - super(FlashCausalLM, self).__init__( + super(FlashSantacoderSharded, self).__init__( model=model.to(device), tokenizer=tokenizer, - requires_padding=False, + num_layers=len(model.transformer.h), + num_kv_heads=1, + head_size=model.transformer.head_size, dtype=dtype, device=device, rank=rank, diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 6b8472a5..f8460fc2 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -22,6 +22,9 @@ class Model(ABC): rank: int = 0, world_size: int = 1, ): + if torch.cuda.is_available(): + torch.cuda.set_per_process_memory_fraction(1.0) + self.model = model.eval() self.tokenizer = tokenizer self.all_special_ids = set(tokenizer.all_special_ids) @@ -55,6 +58,9 @@ class Model(ABC): def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]: raise NotImplementedError + def warmup(self, batch: B, max_total_tokens: int): + self.generate_token(batch) + def decode_token( self, all_input_ids: List[int], diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 3ad5698c..999b6637 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -127,7 +127,7 @@ class Seq2SeqLMBatch(Batch): read_offsets.append(1) all_decoder_input_ids = decoder_input_ids.view(-1).split(1) - max_tokens = len(inputs) * max_input_length + max_decode_tokens + max_tokens = len(inputs) * (max_input_length + max_decode_tokens) return cls( batch_id=pb.id, diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index e1bd8412..6cc5beeb 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -53,6 +53,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) + async def Warmup(self, request, context): + batch = self.model.batch_type.from_pb( + request.batch, self.model.tokenizer, self.model.dtype, self.model.device + ) + self.model.warmup(batch, request.max_total_tokens) + return generate_pb2.WarmupResponse() + async def Prefill(self, request, context): batch = self.model.batch_type.from_pb( request.batch, self.model.tokenizer, self.model.dtype, self.model.device diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index e6e512bc..b83af591 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -216,6 +216,8 @@ class HeterogeneousNextTokenChooser: self.seeds = seeds self.do_sample = do_sample + self.dtype = dtype + self.device = device def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor): if self.watermark_processor is not None: diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 9d371834..83d9df68 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -5,7 +5,14 @@ import torch class Weights: - def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None): + def __init__( + self, + filenames: List[Path], + device, + dtype, + process_group, + aliases: Optional[Dict[str, List[str]]] = None, + ): routing = {} for filename in filenames: with safe_open(filename, framework="pytorch") as f: @@ -43,7 +50,7 @@ class Weights: return str(filename), tensor_name def _get_slice(self, tensor_name: str): - filename, tensor_name= self.get_filename(tensor_name) + filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) return slice_ @@ -94,12 +101,20 @@ class Weights: def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): if quantize == "gptq": try: - qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1) + qweight = torch.cat( + [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 + ) except RuntimeError: - raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`") + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) - qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1) - scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1) + qzeros = torch.cat( + [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 + ) + scales = torch.cat( + [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 + ) w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] for w2 in w[1:]: torch.testing.assert_close(w2, w[0]) @@ -118,7 +133,9 @@ class Weights: try: qweight = self.get_sharded(f"{prefix}.qweight", dim=0) except RuntimeError: - raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`") + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) qzeros = self.get_tensor(f"{prefix}.qzeros") scales = self.get_tensor(f"{prefix}.scales") g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)