From 0d794af6a5730efcfb23aa898232c6df3afec347 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Mon, 12 Feb 2024 10:09:29 +0100 Subject: [PATCH] feat: experimental support for cuda graphs (#1428) Co-authored-by: Nicolas Patry --- Dockerfile | 4 +- Dockerfile_amd | 2 +- docs/source/basic_tutorials/launcher.md | 8 ++ integration-tests/conftest.py | 5 +- launcher/src/main.rs | 14 +- server/Makefile-awq | 6 +- .../exllama_kernels/cuda_func/q4_matmul.cu | 5 +- .../exllama_kernels/cuda_func/q4_matmul.cuh | 8 +- .../exllama_kernels/cuda_func/q4_matrix.cu | 11 +- .../exllama_kernels/exllama_ext.cpp | 6 +- .../exllamav2_kernels/cuda/q_gemm.cu | 5 +- .../exllamav2_kernels/cuda/q_matrix.cu | 11 +- .../custom_modeling/flash_mistral_modeling.py | 8 +- .../custom_modeling/flash_mixtral_modeling.py | 8 +- .../models/flash_causal_lm.py | 131 ++++++++++++++++-- .../models/flash_mistral.py | 123 ++++++++++++++-- .../text_generation_server/utils/weights.py | 3 +- 17 files changed, 300 insertions(+), 58 deletions(-) diff --git a/Dockerfile b/Dockerfile index 6818005f..252c5885 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Rust builder -FROM lukemathwalker/cargo-chef:latest-rust-1.71 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.75 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse @@ -166,7 +166,7 @@ FROM kernel-builder as megablocks-builder RUN pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e # Text Generation Inference base image -FROM nvidia/cuda:12.1.0-base-ubuntu20.04 as base +FROM nvidia/cuda:12.1.0-base-ubuntu22.04 as base # Conda env ENV PATH=/opt/conda/bin:$PATH \ diff --git a/Dockerfile_amd b/Dockerfile_amd index d2b6f897..c2ec4a6d 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -1,5 +1,5 @@ # Rust builder -FROM lukemathwalker/cargo-chef:latest-rust-1.71 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.75 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index ba54f058..be31a7a4 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -205,6 +205,14 @@ Options: [env: MAX_BATCH_SIZE=] +``` +## ENABLE_CUDA_GRAPHS +```shell + --enable-cuda-graphs + Enable experimental support for cuda graphs + + [env: ENABLE_CUDA_GRAPHS=] + ``` ## HOSTNAME ```shell diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 4cb4ca59..efeda08d 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -317,7 +317,10 @@ def launcher(event_loop): gpu_count = num_shard if num_shard is not None else 1 - env = {"LOG_LEVEL": "info,text_generation_router=debug"} + env = { + "LOG_LEVEL": "info,text_generation_router=debug", + "ENABLE_CUDA_GRAPHS": "true", + } if not use_flash_attention: env["USE_FLASH_ATTENTION"] = "false" diff --git a/launcher/src/main.rs b/launcher/src/main.rs index a51742e6..8367ef81 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -284,6 +284,10 @@ struct Args { #[clap(long, env)] max_batch_size: Option, + /// Enable experimental support for cuda graphs + #[clap(long, env)] + enable_cuda_graphs: bool, + /// The IP address to listen on #[clap(default_value = "0.0.0.0", long, env)] hostname: String, @@ -407,6 +411,7 @@ fn shard_manager( disable_custom_kernels: bool, watermark_gamma: Option, watermark_delta: Option, + enable_cuda_graphs: bool, cuda_memory_fraction: f32, rope_scaling: 
Option, rope_factor: Option, @@ -488,7 +493,7 @@ fn shard_manager( envs.push(("WORLD_SIZE".into(), world_size.to_string().into())); envs.push(("MASTER_ADDR".into(), master_addr.into())); envs.push(("MASTER_PORT".into(), master_port.to_string().into())); - envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into())); + envs.push(("TORCH_NCCL_AVOID_RECORD_STREAMS".into(), "1".into())); // CUDA memory fraction envs.push(( @@ -538,6 +543,11 @@ fn shard_manager( )); }; + // Enable experimental support for cuda graphs + if enable_cuda_graphs { + envs.push(("ENABLE_CUDA_GRAPHS".into(), "True".into())) + } + // If disable_custom_kernels is true, pass it to the shard as an env var if disable_custom_kernels { envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into())) @@ -926,6 +936,7 @@ fn spawn_shards( let disable_custom_kernels = args.disable_custom_kernels; let watermark_gamma = args.watermark_gamma; let watermark_delta = args.watermark_delta; + let enable_cuda_graphs = args.enable_cuda_graphs; let cuda_memory_fraction = args.cuda_memory_fraction; let rope_scaling = args.rope_scaling; let rope_factor = args.rope_factor; @@ -947,6 +958,7 @@ fn spawn_shards( disable_custom_kernels, watermark_gamma, watermark_delta, + enable_cuda_graphs, cuda_memory_fraction, rope_scaling, rope_factor, diff --git a/server/Makefile-awq b/server/Makefile-awq index 80e78c08..5dd9dbaa 100644 --- a/server/Makefile-awq +++ b/server/Makefile-awq @@ -1,8 +1,10 @@ -awq_commit := f084f40bd996f3cf3a0633c1ad7d9d476c318aaa +# Fork that adds only the correct stream to this kernel in order +# to make cuda graphs work. +awq_commit := bd1dc2d5254345cc76ab71894651fb821275bdd4 awq: rm -rf llm-awq - git clone https://github.com/mit-han-lab/llm-awq + git clone https://github.com/huggingface/llm-awq build-awq: awq cd llm-awq/ && git fetch && git checkout $(awq_commit) diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu index 09126efe..1b0f7956 100644 --- a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu +++ b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu @@ -1,5 +1,6 @@ #include "q4_matmul.cuh" #include "column_remap.cuh" +#include #include "../util.cuh" #include "../matrix.cuh" #include "../cu_compat.cuh" @@ -224,8 +225,8 @@ void q4_matmul_recons_cuda const int x_height, Q4Matrix* w, half* out, - const cublasHandle_t handle, - bool no_zero + bool no_zero, + const cublasHandle_t handle ) { int height = x_height; diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh index 63611790..4c7a6669 100644 --- a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh +++ b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh @@ -19,8 +19,8 @@ void q4_matmul_cuda const int x_height, const Q4Matrix* w, half* out, - bool no_zero = false, - cudaStream_t alt_stream = NULL + bool no_zero, + cudaStream_t alt_stream ); void q4_matmul_recons_cuda @@ -30,8 +30,8 @@ void q4_matmul_recons_cuda const int x_height, Q4Matrix* w, half* out, - const cublasHandle_t handle, - bool no_zero = false + bool no_zero, + const cublasHandle_t handle ); #endif diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu index 2867a8d0..1f32e6b8 100644 --- a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu +++ 
b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu @@ -1,5 +1,6 @@ // Adapted from turboderp exllama: https://github.com/turboderp/exllama +#include #include "q4_matrix.cuh" #include #include "../util.cuh" @@ -90,7 +91,7 @@ __global__ void make_sequential_kernel int w2_row_shift = w2_subrow << 2; int wnew2_row_shift = i << 2; - uint64_t src = w2[w2_row * w2_stride + w2_column]; + uint64_t src = w2[w2_row * w2_stride + w2_column]; src >>= w2_row_shift; src &= 0x0000000f0000000f; src <<= wnew2_row_shift; @@ -146,7 +147,8 @@ void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx) dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1); dim3 blocks(width / UNSHUF_BLOCKSIZE_X / 2, height / 8, 1); - make_sequential_kernel<<>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + make_sequential_kernel<<>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width); // Replace qweights @@ -213,5 +215,6 @@ void Q4Matrix::reconstruct(half* out) 1 ); - reconstruct_kernel<<>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize); -} \ No newline at end of file + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + reconstruct_kernel<<>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize); +} diff --git a/server/exllama_kernels/exllama_kernels/exllama_ext.cpp b/server/exllama_kernels/exllama_kernels/exllama_ext.cpp index b786988b..f2df80e8 100644 --- a/server/exllama_kernels/exllama_kernels/exllama_ext.cpp +++ b/server/exllama_kernels/exllama_kernels/exllama_ext.cpp @@ -183,6 +183,7 @@ void q4_matmul int x_height = x.size(0); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd) { q4_matmul_cuda @@ -191,7 +192,9 @@ void q4_matmul (half*) x.data_ptr(), x_height, wm, - (half*) out.data_ptr() + (half*) out.data_ptr(), + false, + stream ); } else @@ -203,6 +206,7 @@ void q4_matmul x_height, wm, (half*) out.data_ptr(), + false, at::cuda::getCurrentCUDABlasHandle() ); } diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu index b4e4cf22..5b99f1ba 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu @@ -38,6 +38,7 @@ void gemm_half_q_half_cuda_part bool mul_r_weights ) { + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if (!b->is_gptq) { dim3 blockDim, gridDim; @@ -50,7 +51,7 @@ void gemm_half_q_half_cuda_part fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights); - kernel<<>> + kernel<<>> ( a, b->cuda_q_weight, @@ -91,7 +92,7 @@ void gemm_half_q_half_cuda_part // print_global_mem(r_weights, 1, 1, 1); // DBGI(r_weights_stride); - kernel<<>> + kernel<<>> ( a, b->cuda_q_weight, diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu index 7a0038b4..f7a91e29 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu @@ -168,8 +168,9 @@ QMatrix::QMatrix blockDim.y = 1; gridDim.x = DIVIDE(width, THREADS_X); gridDim.y = 1; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - shuffle_kernel<<>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2); + 
shuffle_kernel<<>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2); } QMatrix::~QMatrix() @@ -475,11 +476,12 @@ void QMatrix::reconstruct(half* out) blockDim.x = BLOCK_KN_SIZE; blockDim.y = 1; gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if (!is_gptq) { gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); - reconstruct_kernel<<>> + reconstruct_kernel<<>> ( cuda_q_weight, cuda_q_perm, @@ -502,7 +504,7 @@ void QMatrix::reconstruct(half* out) else { gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4); - reconstruct_gptq_kernel<<>> + reconstruct_gptq_kernel<<>> ( cuda_q_weight, cuda_q_perm, @@ -563,6 +565,7 @@ __global__ void make_sequential_kernel bool QMatrix::make_sequential(const uint32_t* cpu_g_idx) { + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); uint32_t* cuda_new_qweight = NULL; cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t)); if (err != cudaSuccess) { @@ -621,7 +624,7 @@ bool QMatrix::make_sequential(const uint32_t* cpu_g_idx) gridDim.x = DIVIDE(width, THREADS_X); gridDim.y = height / 8; - make_sequential_kernel<<>> + make_sequential_kernel<<>> ( cuda_q_weight, cuda_new_qweight, diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 0fc4e1b3..7b45be57 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -425,6 +425,11 @@ class FlashMistralForCausalLM(torch.nn.Module): weights=weights, ) self.max_past = config.sliding_window + self.max_past_tensor = ( + torch.tensor(config.sliding_window, device=weights.device) + if self.max_past is not None + else None + ) def forward( self, @@ -446,8 +451,7 @@ class FlashMistralForCausalLM(torch.nn.Module): elif self.max_past is not None: # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values - max_s = min(self.max_past, max_s) - input_lengths = torch.clamp(input_lengths, max=self.max_past) + input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor) hidden_states = self.model( input_ids, diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 61488ec4..c91b2224 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -816,6 +816,11 @@ class FlashMixtralForCausalLM(torch.nn.Module): weights=weights, ) self.max_past = config.sliding_window + self.max_past_tensor = ( + torch.tensor(config.sliding_window, device=weights.device) + if self.max_past is not None + else None + ) def forward( self, @@ -837,8 +842,7 @@ class FlashMixtralForCausalLM(torch.nn.Module): elif self.max_past is not None: # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values - max_s = min(self.max_past, max_s) - input_lengths = torch.clamp(input_lengths, max=self.max_past) + input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor) hidden_states = self.model( input_ids, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 
90776654..c7fda516 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1,4 +1,5 @@ import math +import os import time import itertools import torch @@ -6,6 +7,7 @@ import torch.distributed import numpy as np +from loguru import logger from dataclasses import dataclass from opentelemetry import trace from transformers import PreTrainedTokenizerBase @@ -31,6 +33,8 @@ from text_generation_server.utils.dist import MEMORY_FRACTION tracer = trace.get_tracer(__name__) +MEM_POOL = torch.cuda.graph_pool_handle() + @dataclass class FlashCausalLMBatch(Batch): @@ -62,7 +66,7 @@ class FlashCausalLMBatch(Batch): # Set in prefill by the CacheManager # list of length b of list of length s_i // block_size block_tables: Optional[List[List[int]]] - # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences + # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences block_tables_tensor: Optional[torch.Tensor] # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences slots: Optional[torch.Tensor] @@ -663,6 +667,8 @@ class FlashCausalLM(Model): self.num_kv_heads = num_kv_heads self.head_size = head_size + self.cuda_graphs = {} + super(FlashCausalLM, self).__init__( model=model, tokenizer=tokenizer, @@ -678,7 +684,60 @@ class FlashCausalLM(Model): def batch_type(self) -> Type[FlashCausalLMBatch]: return FlashCausalLMBatch + def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): + input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) + position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) + slots = torch.arange(bs, dtype=torch.int32, device=self.device) + input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s + block_tables = ( + torch.arange(max_bt, dtype=torch.int32, device=self.device) + .repeat(bs) + .reshape((bs, max_bt)) + ) + kv_cache = get_cache_manager().kv_cache + + self.cuda_graphs[bs] = { + "input_ids": input_ids, + "position_ids": position_ids, + "kv_cache": kv_cache, + "block_tables": block_tables, + "slots": slots, + "input_lengths": input_lengths, + } + graph = torch.cuda.CUDAGraph() + self.cuda_graphs[bs]["graph"] = graph + + torch.cuda.synchronize() + # Run once outside to warmup + self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + lm_head_indices=None, + ) + torch.cuda.synchronize() + + with torch.cuda.graph(graph, pool=MEM_POOL): + self.cuda_graphs[bs]["logits"] = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + lm_head_indices=None, + ) + torch.cuda.synchronize() + def warmup(self, batch: FlashCausalLMBatch): + # The warmup batch is the biggest batch we could ever receive torch.cuda.empty_cache() try: cache_manager = set_cache_manager( @@ -690,6 +749,8 @@ class FlashCausalLM(Model): self.dtype, self.device, ) + max_bt = batch.max_blocks + max_s = max_bt * get_cache_manager().block_size _, batch, _ = self.generate_token(batch) except torch.cuda.OutOfMemoryError as e: raise RuntimeError( @@ -713,7 +774,8 @@ class FlashCausalLM(Model): ) num_blocks = ( - int(free_memory // total_cache_size) + # Leave 5% 
for some wiggle room + int((free_memory * 0.95) // total_cache_size) # Add batch.blocks as we allocated it above, so it is included in the peak memory. + cache_manager.num_blocks ) @@ -731,9 +793,19 @@ class FlashCausalLM(Model): self.device, ) + if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True": + try: + logger.info("Experimental support for Cuda Graphs is enabled") + # Warmup cuda graphs + for bs in [1, 2, 4] + [8 * i for i in range(8)]: + if self.speculate is None or self.speculate + 1 <= bs: + self.cuda_graph_warmup(bs, max_s, max_bt) + except Exception: + logger.exception(f"Decode cuda graph warmup failed") + return int(num_blocks * BLOCK_SIZE) - def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, batch: FlashCausalLMBatch) -> torch.Tensor: # Model Forward if batch.speculative_ids is not None: input_ids = batch.input_ids @@ -785,17 +857,48 @@ class FlashCausalLM(Model): max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices - return self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - input_lengths=input_lengths, - max_s=max_s, - lm_head_indices=lm_head_indices, - ) + bs = input_ids.shape[0] + padded_bs = bs + if bs == 3: + padded_bs = 4 + elif 3 < bs <= 8: + padded_bs = 8 + elif bs > 8: + padded_bs = (bs + 7) // 8 * 8 + + # Try to find an associated cuda graph + cuda_graph = self.cuda_graphs.get(padded_bs, None) + + if cu_seqlen_prefill is not None or cuda_graph is None or batch.speculative_ids is not None: + return self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + lm_head_indices=lm_head_indices, + ) + + # Copy inputs to the static inputs of the cuda graph + # Static inputs are potentially padded + cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids + cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids + cuda_graph["block_tables"][ + : block_tables.shape[0], : block_tables.shape[1] + ] = block_tables + cuda_graph["slots"].fill_(-1) + cuda_graph["slots"][: slots.shape[0]] = slots + cuda_graph["input_lengths"].zero_() + cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + + # Replay the graph + cuda_graph["graph"].replay() + + # Slice output to the correct shape + return cuda_graph["logits"][:bs] @tracer.start_as_current_span("generate_token") def generate_token( diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 8c6cb025..34a50194 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -35,6 +35,8 @@ tracer = trace.get_tracer(__name__) SLIDING_WINDOW: Optional[int] = None SLIDING_WINDOW_BLOCKS: Optional[int] = None +MEM_POOL = torch.cuda.graph_pool_handle() + # Adds windowing logic to FlashCausalLMBatch @dataclass @@ -332,6 +334,8 @@ class BaseFlashMistral(FlashCausalLM): model = model_cls(config, weights) + self.cuda_graphs = {} + torch.distributed.barrier(group=self.process_group) super(BaseFlashMistral, self).__init__( model=model, @@ -350,6 +354,60 @@ class BaseFlashMistral(FlashCausalLM): def batch_type(self) -> Type[FlashMistralBatch]: return FlashMistralBatch + def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): + 
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) + position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) + slots = torch.arange(bs, dtype=torch.int32, device=self.device) + input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s + block_tables = ( + torch.arange(max_bt, dtype=torch.int32, device=self.device) + .repeat(bs) + .reshape((bs, max_bt)) + ) + kv_cache = get_cache_manager().kv_cache + + self.cuda_graphs[bs] = { + "input_ids": input_ids, + "position_ids": position_ids, + "kv_cache": kv_cache, + "block_tables": block_tables, + "slots": slots, + "input_lengths": input_lengths, + } + graph = torch.cuda.CUDAGraph() + self.cuda_graphs[bs]["graph"] = graph + + torch.cuda.synchronize() + # Run once outside to warmup + self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + prefill_cache_indices=None, + lm_head_indices=None, + ) + torch.cuda.synchronize() + + with torch.cuda.graph(graph, pool=MEM_POOL): + self.cuda_graphs[bs]["logits"] = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + prefill_cache_indices=None, + lm_head_indices=None, + ) + torch.cuda.synchronize() + def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]: # Model Forward if batch.speculative_ids is not None: @@ -401,21 +459,56 @@ class BaseFlashMistral(FlashCausalLM): input_lengths = batch.input_lengths_tensor max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices - logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - input_lengths=input_lengths, - max_s=max_s, - prefill_cache_indices=batch.prefill_cache_indices, - lm_head_indices=lm_head_indices, - ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None - return logits + + if self.model.max_past is not None: + max_s = min(self.model.max_past, max_s) + + bs = input_ids.shape[0] + padded_bs = bs + if bs == 3: + padded_bs = 4 + elif 3 < bs <= 8: + padded_bs = 8 + elif bs > 8: + padded_bs = (bs + 7) // 8 * 8 + + # Try to find an associated cuda graph + cuda_graph = self.cuda_graphs.get(padded_bs, None) + + if cu_seqlen_prefill is not None or cuda_graph is None: + logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + prefill_cache_indices=batch.prefill_cache_indices, + lm_head_indices=lm_head_indices, + ) + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + return logits + + # Copy inputs to the static inputs of the cuda graph + # Static inputs are potentially padded + cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids + cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids + cuda_graph["block_tables"][ + : block_tables.shape[0], : block_tables.shape[1] + ] = block_tables + cuda_graph["slots"].fill_(-1) + cuda_graph["slots"][: slots.shape[0]] = slots + cuda_graph["input_lengths"].zero_() + cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + + # Replay the graph + 
cuda_graph["graph"].replay() + + # Slice output to the correct shape + return cuda_graph["logits"][:bs] class FlashMistral(BaseFlashMistral): diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 8f7e1f10..d0614346 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -407,8 +407,9 @@ class Weights: data = json.load(f) self.gptq_bits = data["quantization_config"]["bits"] self.gptq_groupsize = data["quantization_config"]["group_size"] - self.gptq_desc_act = data["quantization_config"]["desc_act"] + # Order is important here, desc_act is missing on some real models self.quant_method = data["quantization_config"]["quant_method"] + self.gptq_desc_act = data["quantization_config"]["desc_act"] except Exception: filename = "quantize_config.json" try: