feat: experimental support for cuda graphs (#1428)
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
parent 532146338b
commit 0d794af6a5
@@ -1,5 +1,5 @@
 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.71 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.75 AS chef
 WORKDIR /usr/src
 
 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -166,7 +166,7 @@ FROM kernel-builder as megablocks-builder
 RUN pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e
 
 # Text Generation Inference base image
-FROM nvidia/cuda:12.1.0-base-ubuntu20.04 as base
+FROM nvidia/cuda:12.1.0-base-ubuntu22.04 as base
 
 # Conda env
 ENV PATH=/opt/conda/bin:$PATH \
@@ -1,5 +1,5 @@
 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.71 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.75 AS chef
 WORKDIR /usr/src
 
 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -205,6 +205,14 @@ Options:
 
           [env: MAX_BATCH_SIZE=]
 
+```
+## ENABLE_CUDA_GRAPHS
+```shell
+      --enable-cuda-graphs
+          Enable experimental support for cuda graphs
+
+          [env: ENABLE_CUDA_GRAPHS=]
+
 ```
 ## HOSTNAME
 ```shell
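The new option can be turned on either from the command line or through the environment variable it maps to. Below is a minimal, hypothetical launch sketch; the launcher binary name is the real one, but the model id is only a placeholder and the env-var spelling follows what the integration tests set.

```python
import os
import subprocess

# Hypothetical launch: passing --enable-cuda-graphs is equivalent to setting
# ENABLE_CUDA_GRAPHS (the test suite uses the value "true").
env = dict(os.environ, ENABLE_CUDA_GRAPHS="true")
subprocess.run(
    ["text-generation-launcher", "--model-id", "some/model"],  # placeholder model id
    env=env,
    check=True,
)
```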
@@ -317,7 +317,10 @@ def launcher(event_loop):
 
         gpu_count = num_shard if num_shard is not None else 1
 
-        env = {"LOG_LEVEL": "info,text_generation_router=debug"}
+        env = {
+            "LOG_LEVEL": "info,text_generation_router=debug",
+            "ENABLE_CUDA_GRAPHS": "true",
+        }
         if not use_flash_attention:
             env["USE_FLASH_ATTENTION"] = "false"
 
@@ -284,6 +284,10 @@ struct Args {
     #[clap(long, env)]
     max_batch_size: Option<usize>,
 
+    /// Enable experimental support for cuda graphs
+    #[clap(long, env)]
+    enable_cuda_graphs: bool,
+
     /// The IP address to listen on
     #[clap(default_value = "0.0.0.0", long, env)]
     hostname: String,
@@ -407,6 +411,7 @@ fn shard_manager(
     disable_custom_kernels: bool,
     watermark_gamma: Option<f32>,
     watermark_delta: Option<f32>,
+    enable_cuda_graphs: bool,
     cuda_memory_fraction: f32,
     rope_scaling: Option<RopeScaling>,
     rope_factor: Option<f32>,
@@ -488,7 +493,7 @@ fn shard_manager(
     envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
     envs.push(("MASTER_ADDR".into(), master_addr.into()));
     envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
-    envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
+    envs.push(("TORCH_NCCL_AVOID_RECORD_STREAMS".into(), "1".into()));
 
     // CUDA memory fraction
     envs.push((
@@ -538,6 +543,11 @@ fn shard_manager(
     ));
     };
 
+    // Enable experimental support for cuda graphs
+    if enable_cuda_graphs {
+        envs.push(("ENABLE_CUDA_GRAPHS".into(), "True".into()))
+    }
+
     // If disable_custom_kernels is true, pass it to the shard as an env var
     if disable_custom_kernels {
         envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
@@ -926,6 +936,7 @@ fn spawn_shards(
     let disable_custom_kernels = args.disable_custom_kernels;
     let watermark_gamma = args.watermark_gamma;
     let watermark_delta = args.watermark_delta;
+    let enable_cuda_graphs = args.enable_cuda_graphs;
     let cuda_memory_fraction = args.cuda_memory_fraction;
     let rope_scaling = args.rope_scaling;
     let rope_factor = args.rope_factor;
@@ -947,6 +958,7 @@ fn spawn_shards(
             disable_custom_kernels,
             watermark_gamma,
             watermark_delta,
+            enable_cuda_graphs,
             cuda_memory_fraction,
             rope_scaling,
             rope_factor,
@@ -1,8 +1,10 @@
-awq_commit := f084f40bd996f3cf3a0633c1ad7d9d476c318aaa
+# Fork that adds only the correct stream to this kernel in order
+# to make cuda graphs work.
+awq_commit := bd1dc2d5254345cc76ab71894651fb821275bdd4
 
 awq:
 	rm -rf llm-awq
-	git clone https://github.com/mit-han-lab/llm-awq
+	git clone https://github.com/huggingface/llm-awq
 
 build-awq: awq
 	cd llm-awq/ && git fetch && git checkout $(awq_commit)
@@ -1,5 +1,6 @@
 #include "q4_matmul.cuh"
 #include "column_remap.cuh"
+#include <ATen/cuda/CUDAContext.h>
 #include "../util.cuh"
 #include "../matrix.cuh"
 #include "../cu_compat.cuh"
@@ -224,8 +225,8 @@ void q4_matmul_recons_cuda
     const int x_height,
     Q4Matrix* w,
     half* out,
-    const cublasHandle_t handle,
-    bool no_zero
+    bool no_zero,
+    const cublasHandle_t handle
 )
 {
     int height = x_height;
@@ -19,8 +19,8 @@ void q4_matmul_cuda
     const int x_height,
     const Q4Matrix* w,
     half* out,
-    bool no_zero = false,
-    cudaStream_t alt_stream = NULL
+    bool no_zero,
+    cudaStream_t alt_stream
 );
 
 void q4_matmul_recons_cuda
@@ -30,8 +30,8 @@ void q4_matmul_recons_cuda
     const int x_height,
     Q4Matrix* w,
     half* out,
-    const cublasHandle_t handle,
-    bool no_zero = false
+    bool no_zero,
+    const cublasHandle_t handle
 );
 
 #endif
@@ -1,5 +1,6 @@
 // Adapted from turboderp exllama: https://github.com/turboderp/exllama
 
+#include <ATen/cuda/CUDAContext.h>
 #include "q4_matrix.cuh"
 #include <vector>
 #include "../util.cuh"
@@ -146,7 +147,8 @@ void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
     dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
     dim3 blocks(width / UNSHUF_BLOCKSIZE_X / 2, height / 8, 1);
 
-    make_sequential_kernel<<<blocks, threads>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    make_sequential_kernel<<<blocks, threads, 0, stream>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
 
     // Replace qweights
 
@@ -213,5 +215,6 @@ void Q4Matrix::reconstruct(half* out)
         1
     );
 
-    reconstruct_kernel<<<blocks, threads>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    reconstruct_kernel<<<blocks, threads, 0, stream>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
 }
@@ -183,6 +183,7 @@ void q4_matmul
 
     int x_height = x.size(0);
 
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
     {
         q4_matmul_cuda
@@ -191,7 +192,9 @@ void q4_matmul
             (half*) x.data_ptr(),
             x_height,
             wm,
-            (half*) out.data_ptr()
+            (half*) out.data_ptr(),
+            false,
+            stream
         );
     }
     else
@@ -203,6 +206,7 @@ void q4_matmul
             x_height,
             wm,
             (half*) out.data_ptr(),
+            false,
             at::cuda::getCurrentCUDABlasHandle()
         );
     }
@@ -38,6 +38,7 @@ void gemm_half_q_half_cuda_part
     bool mul_r_weights
 )
 {
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     if (!b->is_gptq)
     {
         dim3 blockDim, gridDim;
@@ -50,7 +51,7 @@ void gemm_half_q_half_cuda_part
 
         fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights);
 
-        kernel<<<gridDim, blockDim>>>
+        kernel<<<gridDim, blockDim, 0, stream>>>
         (
             a,
             b->cuda_q_weight,
@@ -91,7 +92,7 @@ void gemm_half_q_half_cuda_part
         // print_global_mem(r_weights, 1, 1, 1);
         // DBGI(r_weights_stride);
 
-        kernel<<<gridDim, blockDim>>>
+        kernel<<<gridDim, blockDim, 0, stream>>>
         (
             a,
             b->cuda_q_weight,
@@ -168,8 +168,9 @@ QMatrix::QMatrix
     blockDim.y = 1;
     gridDim.x = DIVIDE(width, THREADS_X);
     gridDim.y = 1;
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-    shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
+    shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
 }
 
 QMatrix::~QMatrix()
@@ -475,11 +476,12 @@ void QMatrix::reconstruct(half* out)
     blockDim.x = BLOCK_KN_SIZE;
     blockDim.y = 1;
     gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
     if (!is_gptq)
     {
         gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
-        reconstruct_kernel<<<gridDim, blockDim>>>
+        reconstruct_kernel<<<gridDim, blockDim, 0, stream>>>
         (
             cuda_q_weight,
             cuda_q_perm,
@@ -502,7 +504,7 @@ void QMatrix::reconstruct(half* out)
     else
     {
         gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4);
-        reconstruct_gptq_kernel<<<gridDim, blockDim>>>
+        reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
         (
             cuda_q_weight,
             cuda_q_perm,
@@ -563,6 +565,7 @@ __global__ void make_sequential_kernel
 
 bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
 {
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     uint32_t* cuda_new_qweight = NULL;
     cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
     if (err != cudaSuccess) {
@@ -621,7 +624,7 @@ bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
     gridDim.x = DIVIDE(width, THREADS_X);
     gridDim.y = height / 8;
 
-    make_sequential_kernel<<<gridDim, blockDim>>>
+    make_sequential_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         cuda_q_weight,
         cuda_new_qweight,
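All of the kernel-launch changes above serve the same goal: CUDA graph capture only records work submitted on the capture stream, and `torch.cuda.graph` captures on a non-default stream, so extension kernels hard-wired to the default stream would escape capture or abort it. Passing `at::cuda::getCurrentCUDAStream()` keeps them on whatever stream PyTorch is currently using. The following is a minimal PyTorch-level sketch of the capture/replay mechanics, not TGI code, and it assumes a CUDA device is available.

```python
import torch

assert torch.cuda.is_available()

static_x = torch.randn(8, 8, device="cuda")
static_out = torch.empty_like(static_x)

# Warm up on a side stream first, as the PyTorch docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out.copy_(static_x @ static_x)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # Everything here runs on the capture stream; a kernel launched on the
    # default stream instead would not be recorded into the graph.
    static_out.copy_(static_x @ static_x)

# Replay re-executes the recorded kernels on whatever is now in static_x.
static_x.copy_(torch.randn(8, 8, device="cuda"))
g.replay()
```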
@@ -425,6 +425,11 @@ class FlashMistralForCausalLM(torch.nn.Module):
             weights=weights,
         )
         self.max_past = config.sliding_window
+        self.max_past_tensor = (
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
+            else None
+        )
 
     def forward(
         self,
@@ -446,8 +451,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            max_s = min(self.max_past, max_s)
-            input_lengths = torch.clamp(input_lengths, max=self.max_past)
+            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,
@@ -816,6 +816,11 @@ class FlashMixtralForCausalLM(torch.nn.Module):
             weights=weights,
         )
         self.max_past = config.sliding_window
+        self.max_past_tensor = (
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
+            else None
+        )
 
     def forward(
         self,
@@ -837,8 +842,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            max_s = min(self.max_past, max_s)
-            input_lengths = torch.clamp(input_lengths, max=self.max_past)
+            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,
@@ -1,4 +1,5 @@
 import math
+import os
 import time
 import itertools
 import torch
@@ -6,6 +7,7 @@ import torch.distributed
 
 import numpy as np
 
+from loguru import logger
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
@@ -31,6 +33,8 @@ from text_generation_server.utils.dist import MEMORY_FRACTION
 
 tracer = trace.get_tracer(__name__)
 
+MEM_POOL = torch.cuda.graph_pool_handle()
+
 
 @dataclass
 class FlashCausalLMBatch(Batch):
@@ -62,7 +66,7 @@ class FlashCausalLMBatch(Batch):
     # Set in prefill by the CacheManager
     # list of length b of list of length s_i // block_size
     block_tables: Optional[List[List[int]]]
-    # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
+    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
     block_tables_tensor: Optional[torch.Tensor]
     # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
     slots: Optional[torch.Tensor]
@@ -663,6 +667,8 @@ class FlashCausalLM(Model):
         self.num_kv_heads = num_kv_heads
         self.head_size = head_size
 
+        self.cuda_graphs = {}
+
         super(FlashCausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
@@ -678,7 +684,60 @@ class FlashCausalLM(Model):
     def batch_type(self) -> Type[FlashCausalLMBatch]:
         return FlashCausalLMBatch
 
+    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
+        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
+        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
+        slots = torch.arange(bs, dtype=torch.int32, device=self.device)
+        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
+        block_tables = (
+            torch.arange(max_bt, dtype=torch.int32, device=self.device)
+            .repeat(bs)
+            .reshape((bs, max_bt))
+        )
+        kv_cache = get_cache_manager().kv_cache
+
+        self.cuda_graphs[bs] = {
+            "input_ids": input_ids,
+            "position_ids": position_ids,
+            "kv_cache": kv_cache,
+            "block_tables": block_tables,
+            "slots": slots,
+            "input_lengths": input_lengths,
+        }
+        graph = torch.cuda.CUDAGraph()
+        self.cuda_graphs[bs]["graph"] = graph
+
+        torch.cuda.synchronize()
+        # Run once outside to warmup
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=None,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
+            lm_head_indices=None,
+        )
+        torch.cuda.synchronize()
+
+        with torch.cuda.graph(graph, pool=MEM_POOL):
+            self.cuda_graphs[bs]["logits"] = self.model.forward(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                cu_seqlen_prefill=None,
+                kv_cache=kv_cache,
+                block_tables=block_tables,
+                slots=slots,
+                input_lengths=input_lengths,
+                max_s=max_s,
+                lm_head_indices=None,
+            )
+        torch.cuda.synchronize()
+
     def warmup(self, batch: FlashCausalLMBatch):
+        # The warmup batch is the biggest batch we could ever receive
         torch.cuda.empty_cache()
         try:
             cache_manager = set_cache_manager(
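The warmup above captures one graph per batch-size bucket against fixed ("static") input tensors, and all graphs share the module-level `MEM_POOL` so the buckets draw from one allocation pool instead of each holding private cached memory. A stripped-down sketch of that pattern with a toy computation standing in for `self.model.forward` (assumes a CUDA device):

```python
import torch

pool = torch.cuda.graph_pool_handle()  # one pool shared by every captured graph
graphs = {}

for bs in (1, 2, 4):
    static_in = torch.zeros(bs, 16, device="cuda")
    static_out = torch.empty(bs, 16, device="cuda")

    # Eager warmup run, then capture the same work into a graph.
    static_out.copy_(static_in * 2)
    torch.cuda.synchronize()

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, pool=pool):
        static_out.copy_(static_in * 2)

    graphs[bs] = {"graph": g, "in": static_in, "out": static_out}
```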
@@ -690,6 +749,8 @@ class FlashCausalLM(Model):
                 self.dtype,
                 self.device,
             )
+            max_bt = batch.max_blocks
+            max_s = max_bt * get_cache_manager().block_size
             _, batch, _ = self.generate_token(batch)
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
@@ -713,7 +774,8 @@ class FlashCausalLM(Model):
         )
 
         num_blocks = (
-            int(free_memory // total_cache_size)
+            # Leave 5% for some wiggle room
+            int((free_memory * 0.95) // total_cache_size)
             # Add batch.blocks as we allocated it above, so it is included in the peak memory.
             + cache_manager.num_blocks
         )
@@ -731,9 +793,19 @@ class FlashCausalLM(Model):
             self.device,
         )
 
+        if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
+            try:
+                logger.info("Experimental support for Cuda Graphs is enabled")
+                # Warmup cuda graphs
+                for bs in [1, 2, 4] + [8 * i for i in range(8)]:
+                    if self.speculate is None or self.speculate + 1 <= bs:
+                        self.cuda_graph_warmup(bs, max_s, max_bt)
+            except Exception:
+                logger.exception(f"Decode cuda graph warmup failed")
+
         return int(num_blocks * BLOCK_SIZE)
 
-    def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, batch: FlashCausalLMBatch) -> torch.Tensor:
         # Model Forward
         if batch.speculative_ids is not None:
             input_ids = batch.input_ids
@@ -785,6 +857,19 @@ class FlashCausalLM(Model):
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
 
+        bs = input_ids.shape[0]
+        padded_bs = bs
+        if bs == 3:
+            padded_bs = 4
+        elif 3 < bs <= 8:
+            padded_bs = 8
+        elif bs > 8:
+            padded_bs = (bs + 7) // 8 * 8
+
+        # Try to find an associated cuda graph
+        cuda_graph = self.cuda_graphs.get(padded_bs, None)
+
+        if cu_seqlen_prefill is not None or cuda_graph is None or batch.speculative_ids is not None:
             return self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
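Decode batches are padded up to the nearest captured bucket so a handful of graphs can serve any batch size; whatever the padding adds is sliced off the logits afterwards. A small standalone sketch of just the rounding rule used above, with a few example values:

```python
def pad_batch_size(bs: int) -> int:
    # Graphs are captured roughly for bs in {1, 2, 4, 8, 16, ..., 56},
    # so other sizes are rounded up to the next bucket.
    if bs == 3:
        return 4
    if 3 < bs <= 8:
        return 8
    if bs > 8:
        return (bs + 7) // 8 * 8
    return bs


assert [pad_batch_size(b) for b in (1, 2, 3, 5, 9, 17)] == [1, 2, 4, 8, 16, 24]
```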
@@ -797,6 +882,24 @@ class FlashCausalLM(Model):
                 lm_head_indices=lm_head_indices,
             )
 
+        # Copy inputs to the static inputs of the cuda graph
+        # Static inputs are potentially padded
+        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
+        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        cuda_graph["block_tables"][
+            : block_tables.shape[0], : block_tables.shape[1]
+        ] = block_tables
+        cuda_graph["slots"].fill_(-1)
+        cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"].zero_()
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+
+        # Replay the graph
+        cuda_graph["graph"].replay()
+
+        # Slice output to the correct shape
+        return cuda_graph["logits"][:bs]
+
     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: FlashCausalLMBatch
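At decode time the real (smaller) batch is written into the padded static buffers, the graph is replayed, and the padding is dropped from the output. The helper below is a condensed, hypothetical illustration of that copy/replay/slice sequence for the buffer dict built in `cuda_graph_warmup`; only `input_ids` is copied here for brevity.

```python
import torch


def replay_decode(cuda_graph: dict, input_ids: torch.Tensor) -> torch.Tensor:
    bs = input_ids.shape[0]
    # Write the live batch into the padded static input buffer...
    cuda_graph["input_ids"][:bs] = input_ids
    # ...re-execute the captured kernels...
    cuda_graph["graph"].replay()
    # ...and drop the rows that correspond to padding.
    return cuda_graph["logits"][:bs]
```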
@@ -35,6 +35,8 @@ tracer = trace.get_tracer(__name__)
 SLIDING_WINDOW: Optional[int] = None
 SLIDING_WINDOW_BLOCKS: Optional[int] = None
 
+MEM_POOL = torch.cuda.graph_pool_handle()
+
 
 # Adds windowing logic to FlashCausalLMBatch
 @dataclass
@@ -332,6 +334,8 @@ class BaseFlashMistral(FlashCausalLM):
 
         model = model_cls(config, weights)
 
+        self.cuda_graphs = {}
+
         torch.distributed.barrier(group=self.process_group)
         super(BaseFlashMistral, self).__init__(
             model=model,
@@ -350,6 +354,60 @@ class BaseFlashMistral(FlashCausalLM):
     def batch_type(self) -> Type[FlashMistralBatch]:
         return FlashMistralBatch
 
+    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
+        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
+        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
+        slots = torch.arange(bs, dtype=torch.int32, device=self.device)
+        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
+        block_tables = (
+            torch.arange(max_bt, dtype=torch.int32, device=self.device)
+            .repeat(bs)
+            .reshape((bs, max_bt))
+        )
+        kv_cache = get_cache_manager().kv_cache
+
+        self.cuda_graphs[bs] = {
+            "input_ids": input_ids,
+            "position_ids": position_ids,
+            "kv_cache": kv_cache,
+            "block_tables": block_tables,
+            "slots": slots,
+            "input_lengths": input_lengths,
+        }
+        graph = torch.cuda.CUDAGraph()
+        self.cuda_graphs[bs]["graph"] = graph
+
+        torch.cuda.synchronize()
+        # Run once outside to warmup
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=None,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
+            prefill_cache_indices=None,
+            lm_head_indices=None,
+        )
+        torch.cuda.synchronize()
+
+        with torch.cuda.graph(graph, pool=MEM_POOL):
+            self.cuda_graphs[bs]["logits"] = self.model.forward(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                cu_seqlen_prefill=None,
+                kv_cache=kv_cache,
+                block_tables=block_tables,
+                slots=slots,
+                input_lengths=input_lengths,
+                max_s=max_s,
+                prefill_cache_indices=None,
+                lm_head_indices=None,
+            )
+        torch.cuda.synchronize()
+
     def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
         if batch.speculative_ids is not None:
@@ -401,6 +459,23 @@ class BaseFlashMistral(FlashCausalLM):
             input_lengths = batch.input_lengths_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
 
+        if self.model.max_past is not None:
+            max_s = min(self.model.max_past, max_s)
+
+        bs = input_ids.shape[0]
+        padded_bs = bs
+        if bs == 3:
+            padded_bs = 4
+        elif 3 < bs <= 8:
+            padded_bs = 8
+        elif bs > 8:
+            padded_bs = (bs + 7) // 8 * 8
+
+        # Try to find an associated cuda graph
+        cuda_graph = self.cuda_graphs.get(padded_bs, None)
+
+        if cu_seqlen_prefill is not None or cuda_graph is None:
             logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
@@ -417,6 +492,24 @@ class BaseFlashMistral(FlashCausalLM):
             batch.prefill_cache_indices = None
             return logits
 
+        # Copy inputs to the static inputs of the cuda graph
+        # Static inputs are potentially padded
+        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
+        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        cuda_graph["block_tables"][
+            : block_tables.shape[0], : block_tables.shape[1]
+        ] = block_tables
+        cuda_graph["slots"].fill_(-1)
+        cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"].zero_()
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+
+        # Replay the graph
+        cuda_graph["graph"].replay()
+
+        # Slice output to the correct shape
+        return cuda_graph["logits"][:bs]
+
 
 class FlashMistral(BaseFlashMistral):
     def __init__(
@@ -407,8 +407,9 @@ class Weights:
                 data = json.load(f)
                 self.gptq_bits = data["quantization_config"]["bits"]
                 self.gptq_groupsize = data["quantization_config"]["group_size"]
-                self.gptq_desc_act = data["quantization_config"]["desc_act"]
+                # Order is important here, desc_act is missing on some real models
                 self.quant_method = data["quantization_config"]["quant_method"]
+                self.gptq_desc_act = data["quantization_config"]["desc_act"]
         except Exception:
             filename = "quantize_config.json"
             try:
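The reordering presumably matters because the whole block sits in a `try`/`except` and `desc_act` is absent from some real quantization configs: reading `quant_method` before `desc_act` means any KeyError fires last, after the method has already been recorded. A toy illustration of that behaviour (the config dict below is made up):

```python
cfg = {"bits": 4, "group_size": 128, "quant_method": "gptq"}  # no "desc_act"
state = {}
try:
    state["bits"] = cfg["bits"]
    state["group_size"] = cfg["group_size"]
    state["quant_method"] = cfg["quant_method"]
    state["desc_act"] = cfg["desc_act"]  # raises KeyError, but only now
except Exception:
    pass

assert state["quant_method"] == "gptq"  # still captured despite the missing key
```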