feat: experimental support for cuda graphs (#1428)

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Authored by OlivierDehaene on 2024-02-12 10:09:29 +01:00; committed via GitHub
parent 532146338b
commit 0d794af6a5
17 changed files with 300 additions and 58 deletions

View File

@@ -1,5 +1,5 @@
 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.71 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.75 AS chef
 WORKDIR /usr/src
 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -166,7 +166,7 @@ FROM kernel-builder as megablocks-builder
 RUN pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e
 
 # Text Generation Inference base image
-FROM nvidia/cuda:12.1.0-base-ubuntu20.04 as base
+FROM nvidia/cuda:12.1.0-base-ubuntu22.04 as base
 
 # Conda env
 ENV PATH=/opt/conda/bin:$PATH \

View File

@@ -1,5 +1,5 @@
 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.71 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.75 AS chef
 WORKDIR /usr/src
 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

View File

@@ -205,6 +205,14 @@ Options:
           [env: MAX_BATCH_SIZE=]
 
+```
+## ENABLE_CUDA_GRAPHS
+```shell
+      --enable-cuda-graphs
+          Enable experimental support for cuda graphs
+
+          [env: ENABLE_CUDA_GRAPHS=]
+
 ```
 ## HOSTNAME
 ```shell
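Editor's note, not part of the commit: the new option can be enabled either with the `--enable-cuda-graphs` flag or through the `ENABLE_CUDA_GRAPHS` environment variable documented above. A minimal sketch of the environment-variable route, mirroring the integration-test change further down; the model id is a placeholder:

```python
import os
import subprocess

# Minimal sketch (assumed model id): run the launcher with the experimental
# CUDA graphs support enabled via the environment, as the integration tests
# in this commit do.
env = dict(os.environ, ENABLE_CUDA_GRAPHS="true")
subprocess.run(
    ["text-generation-launcher", "--model-id", "some-org/some-model"],
    env=env,
    check=True,
)
```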

View File

@@ -317,7 +317,10 @@ def launcher(event_loop):
         gpu_count = num_shard if num_shard is not None else 1
 
-        env = {"LOG_LEVEL": "info,text_generation_router=debug"}
+        env = {
+            "LOG_LEVEL": "info,text_generation_router=debug",
+            "ENABLE_CUDA_GRAPHS": "true",
+        }
 
         if not use_flash_attention:
             env["USE_FLASH_ATTENTION"] = "false"

View File

@@ -284,6 +284,10 @@ struct Args {
     #[clap(long, env)]
     max_batch_size: Option<usize>,
 
+    /// Enable experimental support for cuda graphs
+    #[clap(long, env)]
+    enable_cuda_graphs: bool,
+
     /// The IP address to listen on
     #[clap(default_value = "0.0.0.0", long, env)]
     hostname: String,
@@ -407,6 +411,7 @@ fn shard_manager(
     disable_custom_kernels: bool,
     watermark_gamma: Option<f32>,
     watermark_delta: Option<f32>,
+    enable_cuda_graphs: bool,
     cuda_memory_fraction: f32,
     rope_scaling: Option<RopeScaling>,
     rope_factor: Option<f32>,
@@ -488,7 +493,7 @@ fn shard_manager(
     envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
     envs.push(("MASTER_ADDR".into(), master_addr.into()));
     envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
-    envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
+    envs.push(("TORCH_NCCL_AVOID_RECORD_STREAMS".into(), "1".into()));
 
     // CUDA memory fraction
     envs.push((
@@ -538,6 +543,11 @@ fn shard_manager(
         ));
     };
 
+    // Enable experimental support for cuda graphs
+    if enable_cuda_graphs {
+        envs.push(("ENABLE_CUDA_GRAPHS".into(), "True".into()))
+    }
+
     // If disable_custom_kernels is true, pass it to the shard as an env var
     if disable_custom_kernels {
         envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
@@ -926,6 +936,7 @@ fn spawn_shards(
     let disable_custom_kernels = args.disable_custom_kernels;
     let watermark_gamma = args.watermark_gamma;
    let watermark_delta = args.watermark_delta;
+    let enable_cuda_graphs = args.enable_cuda_graphs;
    let cuda_memory_fraction = args.cuda_memory_fraction;
    let rope_scaling = args.rope_scaling;
    let rope_factor = args.rope_factor;
@@ -947,6 +958,7 @@ fn spawn_shards(
                 disable_custom_kernels,
                 watermark_gamma,
                 watermark_delta,
+                enable_cuda_graphs,
                 cuda_memory_fraction,
                 rope_scaling,
                 rope_factor,

View File

@@ -1,8 +1,10 @@
-awq_commit := f084f40bd996f3cf3a0633c1ad7d9d476c318aaa
+# Fork that adds only the correct stream to this kernel in order
+# to make cuda graphs work.
+awq_commit := bd1dc2d5254345cc76ab71894651fb821275bdd4
 
 awq:
 	rm -rf llm-awq
-	git clone https://github.com/mit-han-lab/llm-awq
+	git clone https://github.com/huggingface/llm-awq
 
 build-awq: awq
 	cd llm-awq/ && git fetch && git checkout $(awq_commit)

View File

@@ -1,5 +1,6 @@
 #include "q4_matmul.cuh"
 #include "column_remap.cuh"
+#include <ATen/cuda/CUDAContext.h>
 #include "../util.cuh"
 #include "../matrix.cuh"
 #include "../cu_compat.cuh"
@@ -224,8 +225,8 @@ void q4_matmul_recons_cuda
     const int x_height,
     Q4Matrix* w,
     half* out,
-    const cublasHandle_t handle,
-    bool no_zero
+    bool no_zero,
+    const cublasHandle_t handle
 )
 {
     int height = x_height;

View File

@@ -19,8 +19,8 @@ void q4_matmul_cuda
     const int x_height,
     const Q4Matrix* w,
     half* out,
-    bool no_zero = false,
-    cudaStream_t alt_stream = NULL
+    bool no_zero,
+    cudaStream_t alt_stream
 );
 
 void q4_matmul_recons_cuda
@@ -30,8 +30,8 @@ void q4_matmul_recons_cuda
     const int x_height,
     Q4Matrix* w,
     half* out,
-    const cublasHandle_t handle,
-    bool no_zero = false
+    bool no_zero,
+    const cublasHandle_t handle
 );
 
 #endif

View File

@@ -1,5 +1,6 @@
 // Adapted from turboderp exllama: https://github.com/turboderp/exllama
+#include <ATen/cuda/CUDAContext.h>
 #include "q4_matrix.cuh"
 #include <vector>
 #include "../util.cuh"
@@ -146,7 +147,8 @@ void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
     dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
     dim3 blocks(width / UNSHUF_BLOCKSIZE_X / 2, height / 8, 1);
 
-    make_sequential_kernel<<<blocks, threads>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    make_sequential_kernel<<<blocks, threads, 0, stream>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
 
     // Replace qweights
@@ -213,5 +215,6 @@ void Q4Matrix::reconstruct(half* out)
         1
     );
 
-    reconstruct_kernel<<<blocks, threads>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    reconstruct_kernel<<<blocks, threads, 0, stream>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
 }

View File

@@ -183,6 +183,7 @@ void q4_matmul
     int x_height = x.size(0);
 
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
     {
         q4_matmul_cuda
@@ -191,7 +192,9 @@ void q4_matmul
             (half*) x.data_ptr(),
             x_height,
             wm,
-            (half*) out.data_ptr()
+            (half*) out.data_ptr(),
+            false,
+            stream
         );
     }
     else
@@ -203,6 +206,7 @@ void q4_matmul
             x_height,
             wm,
             (half*) out.data_ptr(),
+            false,
             at::cuda::getCurrentCUDABlasHandle()
         );
     }

View File

@@ -38,6 +38,7 @@ void gemm_half_q_half_cuda_part
     bool mul_r_weights
 )
 {
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     if (!b->is_gptq)
     {
         dim3 blockDim, gridDim;
@@ -50,7 +51,7 @@ void gemm_half_q_half_cuda_part
         fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights);
 
-        kernel<<<gridDim, blockDim>>>
+        kernel<<<gridDim, blockDim, 0, stream>>>
         (
             a,
             b->cuda_q_weight,
@@ -91,7 +92,7 @@ void gemm_half_q_half_cuda_part
         //     print_global_mem(r_weights, 1, 1, 1);
         // DBGI(r_weights_stride);
 
-        kernel<<<gridDim, blockDim>>>
+        kernel<<<gridDim, blockDim, 0, stream>>>
         (
             a,
             b->cuda_q_weight,

View File

@@ -168,8 +168,9 @@ QMatrix::QMatrix
     blockDim.y = 1;
     gridDim.x = DIVIDE(width, THREADS_X);
     gridDim.y = 1;
 
-    shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
 }
 
 QMatrix::~QMatrix()
@@ -475,11 +476,12 @@ void QMatrix::reconstruct(half* out)
     blockDim.x = BLOCK_KN_SIZE;
     blockDim.y = 1;
     gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
     if (!is_gptq)
     {
         gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
-        reconstruct_kernel<<<gridDim, blockDim>>>
+        reconstruct_kernel<<<gridDim, blockDim, 0, stream>>>
         (
             cuda_q_weight,
             cuda_q_perm,
@@ -502,7 +504,7 @@ void QMatrix::reconstruct(half* out)
     else
     {
         gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4);
-        reconstruct_gptq_kernel<<<gridDim, blockDim>>>
+        reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
         (
             cuda_q_weight,
             cuda_q_perm,
@@ -563,6 +565,7 @@ __global__ void make_sequential_kernel
 
 bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
 {
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     uint32_t* cuda_new_qweight = NULL;
     cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
     if (err != cudaSuccess) {
@@ -621,7 +624,7 @@ bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
     gridDim.x = DIVIDE(width, THREADS_X);
     gridDim.y = height / 8;
 
-    make_sequential_kernel<<<gridDim, blockDim>>>
+    make_sequential_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         cuda_q_weight,
        cuda_new_qweight,
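Editor's note, not part of the commit: the recurring change in these kernel files, launching on `at::cuda::getCurrentCUDAStream()` instead of the legacy default stream, is what makes the quantized kernels capturable (the Makefile comment about the llm-awq fork says the same). CUDA graph capture records work submitted to its capture stream, not work sent to the legacy default stream. A minimal PyTorch-side sketch of that behaviour, assuming a CUDA device is available:

```python
import torch

# Minimal sketch (assumes a CUDA device): torch.cuda.graph records work on a
# dedicated capture stream, which is why kernels hard-coded to the legacy
# default stream (<<<grid, block>>> with no stream argument) break capture.
x = torch.zeros(4, device="cuda")
default_stream = torch.cuda.default_stream()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    capture_stream = torch.cuda.current_stream()
    x += 1  # recorded into the graph, not executed yet

assert capture_stream != default_stream
graph.replay()  # now the captured `x += 1` actually runs
```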

View File

@@ -425,6 +425,11 @@ class FlashMistralForCausalLM(torch.nn.Module):
             weights=weights,
         )
         self.max_past = config.sliding_window
+        self.max_past_tensor = (
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
+            else None
+        )
 
     def forward(
         self,
@@ -446,8 +451,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            max_s = min(self.max_past, max_s)
-            input_lengths = torch.clamp(input_lengths, max=self.max_past)
+            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,
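Editor's note, not part of the commit: clamping against the pre-allocated `max_past_tensor` instead of a Python integer keeps the sliding-window clamp entirely on the GPU, with no host-side value baked in, so the op can be captured once and replayed with different `input_lengths`. A small illustration with made-up values:

```python
import torch

# Small illustration (made-up values): 4096 stands in for config.sliding_window,
# and the device falls back to CPU so the snippet runs anywhere.
device = "cuda" if torch.cuda.is_available() else "cpu"
sliding_window = 4096

max_past_tensor = torch.tensor(sliding_window, device=device)
input_lengths = torch.tensor([100, 5000, 8192], device=device)

# Device-side clamp: no Python int is read back, so the op is graph-friendly.
print(torch.clamp(input_lengths, max=max_past_tensor))
# tensor([ 100, 4096, 4096])
```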

View File

@@ -816,6 +816,11 @@ class FlashMixtralForCausalLM(torch.nn.Module):
             weights=weights,
         )
         self.max_past = config.sliding_window
+        self.max_past_tensor = (
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
+            else None
+        )
 
     def forward(
         self,
@@ -837,8 +842,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            max_s = min(self.max_past, max_s)
-            input_lengths = torch.clamp(input_lengths, max=self.max_past)
+            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,

View File

@@ -1,4 +1,5 @@
 import math
+import os
 import time
 import itertools
 import torch
@@ -6,6 +7,7 @@ import torch.distributed
 import numpy as np
 
+from loguru import logger
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
@@ -31,6 +33,8 @@ from text_generation_server.utils.dist import MEMORY_FRACTION
 tracer = trace.get_tracer(__name__)
 
+MEM_POOL = torch.cuda.graph_pool_handle()
+
 
 @dataclass
 class FlashCausalLMBatch(Batch):
@@ -62,7 +66,7 @@ class FlashCausalLMBatch(Batch):
     # Set in prefill by the CacheManager
     # list of length b of list of length s_i // block_size
     block_tables: Optional[List[List[int]]]
-    # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
+    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
     block_tables_tensor: Optional[torch.Tensor]
     # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
     slots: Optional[torch.Tensor]
@@ -663,6 +667,8 @@
         self.num_kv_heads = num_kv_heads
         self.head_size = head_size
 
+        self.cuda_graphs = {}
+
         super(FlashCausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
@@ -678,7 +684,60 @@
     def batch_type(self) -> Type[FlashCausalLMBatch]:
         return FlashCausalLMBatch
 
+    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
+        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
+        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
+        slots = torch.arange(bs, dtype=torch.int32, device=self.device)
+        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
+        block_tables = (
+            torch.arange(max_bt, dtype=torch.int32, device=self.device)
+            .repeat(bs)
+            .reshape((bs, max_bt))
+        )
+        kv_cache = get_cache_manager().kv_cache
+
+        self.cuda_graphs[bs] = {
+            "input_ids": input_ids,
+            "position_ids": position_ids,
+            "kv_cache": kv_cache,
+            "block_tables": block_tables,
+            "slots": slots,
+            "input_lengths": input_lengths,
+        }
+        graph = torch.cuda.CUDAGraph()
+        self.cuda_graphs[bs]["graph"] = graph
+
+        torch.cuda.synchronize()
+        # Run once outside to warmup
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=None,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
+            lm_head_indices=None,
+        )
+        torch.cuda.synchronize()
+
+        with torch.cuda.graph(graph, pool=MEM_POOL):
+            self.cuda_graphs[bs]["logits"] = self.model.forward(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                cu_seqlen_prefill=None,
+                kv_cache=kv_cache,
+                block_tables=block_tables,
+                slots=slots,
+                input_lengths=input_lengths,
+                max_s=max_s,
+                lm_head_indices=None,
+            )
+        torch.cuda.synchronize()
+
     def warmup(self, batch: FlashCausalLMBatch):
+        # The warmup batch is the biggest batch we could ever receive
         torch.cuda.empty_cache()
         try:
             cache_manager = set_cache_manager(
@@ -690,6 +749,8 @@
                 self.dtype,
                 self.device,
             )
+            max_bt = batch.max_blocks
+            max_s = max_bt * get_cache_manager().block_size
             _, batch, _ = self.generate_token(batch)
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
@@ -713,7 +774,8 @@
         )
 
         num_blocks = (
-            int(free_memory // total_cache_size)
+            # Leave 5% for some wiggle room
+            int((free_memory * 0.95) // total_cache_size)
             # Add batch.blocks as we allocated it above, so it is included in the peak memory.
             + cache_manager.num_blocks
         )
@@ -731,9 +793,19 @@
             self.device,
         )
 
+        if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
+            try:
+                logger.info("Experimental support for Cuda Graphs is enabled")
+                # Warmup cuda graphs
+                for bs in [1, 2, 4] + [8 * i for i in range(8)]:
+                    if self.speculate is None or self.speculate + 1 <= bs:
+                        self.cuda_graph_warmup(bs, max_s, max_bt)
+            except Exception:
+                logger.exception(f"Decode cuda graph warmup failed")
+
         return int(num_blocks * BLOCK_SIZE)
 
-    def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, batch: FlashCausalLMBatch) -> torch.Tensor:
         # Model Forward
         if batch.speculative_ids is not None:
             input_ids = batch.input_ids
@@ -785,6 +857,19 @@
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
 
+        bs = input_ids.shape[0]
+        padded_bs = bs
+        if bs == 3:
+            padded_bs = 4
+        elif 3 < bs <= 8:
+            padded_bs = 8
+        elif bs > 8:
+            padded_bs = (bs + 7) // 8 * 8
+
+        # Try to find an associated cuda graph
+        cuda_graph = self.cuda_graphs.get(padded_bs, None)
+
+        if cu_seqlen_prefill is not None or cuda_graph is None or batch.speculative_ids is not None:
             return self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
@@ -797,6 +882,24 @@
                 lm_head_indices=lm_head_indices,
             )
 
+        # Copy inputs to the static inputs of the cuda graph
+        # Static inputs are potentially padded
+        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
+        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        cuda_graph["block_tables"][
+            : block_tables.shape[0], : block_tables.shape[1]
+        ] = block_tables
+        cuda_graph["slots"].fill_(-1)
+        cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"].zero_()
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+
+        # Replay the graph
+        cuda_graph["graph"].replay()
+
+        # Slice output to the correct shape
+        return cuda_graph["logits"][:bs]
+
     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: FlashCausalLMBatch
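Editor's note, not part of the commit: the warmup/capture/replay structure added to `FlashCausalLM` follows the standard PyTorch CUDA graphs recipe, with one graph per padded decode batch size sharing a memory pool. A stripped-down, self-contained sketch of the same pattern with a toy model standing in for the real decode forward pass (requires a CUDA device):

```python
import torch

# Stripped-down sketch of the capture/replay pattern above, with a toy
# linear layer standing in for the real decode forward pass.
model = torch.nn.Linear(16, 16).cuda()
mem_pool = torch.cuda.graph_pool_handle()  # shared pool, like MEM_POOL

captured_bs = 8
static_input = torch.zeros(captured_bs, 16, device="cuda")

# Eager warmup run so kernels and workspaces are initialized before capture.
model(static_input)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=mem_pool):
    static_output = model(static_input)

# Decode time: copy the live batch into the static buffer (padding up to the
# captured batch size), replay the recorded kernels, then slice the output.
live_input = torch.randn(5, 16, device="cuda")
static_input.zero_()
static_input[: live_input.shape[0]] = live_input
graph.replay()
logits = static_output[: live_input.shape[0]]
```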

View File

@@ -35,6 +35,8 @@ tracer = trace.get_tracer(__name__)
 SLIDING_WINDOW: Optional[int] = None
 SLIDING_WINDOW_BLOCKS: Optional[int] = None
 
+MEM_POOL = torch.cuda.graph_pool_handle()
+
 
 # Adds windowing logic to FlashCausalLMBatch
 @dataclass
@@ -332,6 +334,8 @@ class BaseFlashMistral(FlashCausalLM):
         model = model_cls(config, weights)
 
+        self.cuda_graphs = {}
+
         torch.distributed.barrier(group=self.process_group)
         super(BaseFlashMistral, self).__init__(
             model=model,
@@ -350,6 +354,60 @@
     def batch_type(self) -> Type[FlashMistralBatch]:
         return FlashMistralBatch
 
+    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
+        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
+        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
+        slots = torch.arange(bs, dtype=torch.int32, device=self.device)
+        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
+        block_tables = (
+            torch.arange(max_bt, dtype=torch.int32, device=self.device)
+            .repeat(bs)
+            .reshape((bs, max_bt))
+        )
+        kv_cache = get_cache_manager().kv_cache
+
+        self.cuda_graphs[bs] = {
+            "input_ids": input_ids,
+            "position_ids": position_ids,
+            "kv_cache": kv_cache,
+            "block_tables": block_tables,
+            "slots": slots,
+            "input_lengths": input_lengths,
+        }
+        graph = torch.cuda.CUDAGraph()
+        self.cuda_graphs[bs]["graph"] = graph
+
+        torch.cuda.synchronize()
+        # Run once outside to warmup
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=None,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
+            prefill_cache_indices=None,
+            lm_head_indices=None,
+        )
+        torch.cuda.synchronize()
+
+        with torch.cuda.graph(graph, pool=MEM_POOL):
+            self.cuda_graphs[bs]["logits"] = self.model.forward(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                cu_seqlen_prefill=None,
+                kv_cache=kv_cache,
+                block_tables=block_tables,
+                slots=slots,
+                input_lengths=input_lengths,
+                max_s=max_s,
+                prefill_cache_indices=None,
+                lm_head_indices=None,
+            )
+        torch.cuda.synchronize()
+
     def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
         if batch.speculative_ids is not None:
@@ -401,6 +459,23 @@
             input_lengths = batch.input_lengths_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
 
+        if self.model.max_past is not None:
+            max_s = min(self.model.max_past, max_s)
+
+        bs = input_ids.shape[0]
+        padded_bs = bs
+        if bs == 3:
+            padded_bs = 4
+        elif 3 < bs <= 8:
+            padded_bs = 8
+        elif bs > 8:
+            padded_bs = (bs + 7) // 8 * 8
+
+        # Try to find an associated cuda graph
+        cuda_graph = self.cuda_graphs.get(padded_bs, None)
+
+        if cu_seqlen_prefill is not None or cuda_graph is None:
             logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
@@ -417,6 +492,24 @@
                 batch.prefill_cache_indices = None
             return logits
 
+        # Copy inputs to the static inputs of the cuda graph
+        # Static inputs are potentially padded
+        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
+        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        cuda_graph["block_tables"][
+            : block_tables.shape[0], : block_tables.shape[1]
+        ] = block_tables
+        cuda_graph["slots"].fill_(-1)
+        cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"].zero_()
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+
+        # Replay the graph
+        cuda_graph["graph"].replay()
+
+        # Slice output to the correct shape
+        return cuda_graph["logits"][:bs]
+
 
 class FlashMistral(BaseFlashMistral):
     def __init__(
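Editor's note, not part of the commit: both forward paths use the same batch-size bucketing; graphs are only captured for a fixed set of sizes during warmup, so a live batch is padded up to the nearest captured size and the logits are sliced back down afterwards. A tiny sketch of the rounding rule with a few worked values:

```python
# Tiny sketch of the batch-size bucketing used in both forward paths above:
# round the live batch size up to the nearest size a graph was captured for.
def padded_batch_size(bs: int) -> int:
    if bs == 3:
        return 4
    if 3 < bs <= 8:
        return 8
    if bs > 8:
        return (bs + 7) // 8 * 8  # next multiple of 8
    return bs  # 1, 2 and 4 are captured directly


# Worked values: 1 -> 1, 3 -> 4, 5 -> 8, 17 -> 24
assert [padded_batch_size(b) for b in (1, 3, 5, 17)] == [1, 4, 8, 24]
```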

View File

@@ -407,8 +407,9 @@ class Weights:
                 data = json.load(f)
             self.gptq_bits = data["quantization_config"]["bits"]
             self.gptq_groupsize = data["quantization_config"]["group_size"]
-            self.gptq_desc_act = data["quantization_config"]["desc_act"]
+            # Order is important here, desc_act is missing on some real models
             self.quant_method = data["quantization_config"]["quant_method"]
+            self.gptq_desc_act = data["quantization_config"]["desc_act"]
         except Exception:
             filename = "quantize_config.json"
             try: