Add support for FP8 on compute capability >=8.0, <8.9 (#2213)

Use FP8 GPTQ-Marlin kernels to enable FP8 support on CUDA GPUs
with compute capability >=8.0 and <8.9.

Co-authored-by: Florian Zimmermeister <flozi00.fz@gmail.com>
Daniël de Kok 2024-07-11 16:03:26 +02:00 committed by GitHub
parent 8511669cb2
commit cb150eb295
8 changed files with 1465 additions and 4 deletions


@@ -59,3 +59,18 @@ def marlin_gemm(
    Matrix multiplication using Marlin kernels.
    """
    ...


# fp8 marlin
def fp8_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    return torch.ops._C.fp8_marlin_gemm(
        a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
    )
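
For orientation, here is a rough usage sketch for the new binding. The shape conventions (a as [size_m, size_k] fp16 activations, output of shape [size_m, size_n]) and the requirement that the weight and scales are already Marlin-repacked are read off the GPTQMarlinFP8Linear layer added later in this commit; the sizes below are illustrative and assume a CUDA GPU that this path targets.

# Illustrative only: wires together helpers added elsewhere in this commit.
import torch
import marlin_kernels

from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.layers.marlin import repack_fp8_for_marlin

m, k, n = 16, 4096, 4096  # rows of A, in_features, out_features

weight = torch.randn(n, k, dtype=torch.float16, device="cuda")
qweight, scale = fp8_quantize(weight)  # FP8 weight + per-tensor scale
qweight, scales = repack_fp8_for_marlin(qweight, scale.to(torch.float16))

a = torch.randn(m, k, dtype=torch.float16, device="cuda")
workspace = torch.zeros(n // 64 * 16, dtype=torch.int, device="cuda")

out = marlin_kernels.fp8_marlin_gemm(a, qweight, scales, workspace, 8, m, n, k)
assert out.shape == (m, n)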


@@ -9,4 +9,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("gptq_marlin_repack", &gptq_marlin_repack,
        "Repack GPTQ parameters for Marlin");
  m.def("marlin_gemm", &marlin_gemm, "Marlin gemm");
  // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
  m.def("fp8_marlin_gemm", &fp8_marlin_gemm);
}


@@ -27,4 +27,9 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                          torch::Tensor &b_scales, torch::Tensor &workspace,
                          int64_t size_m, int64_t size_n, int64_t size_k);

torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                              torch::Tensor& b_scales, torch::Tensor& workspace,
                              int64_t num_bits, int64_t size_m, int64_t size_n,
                              int64_t size_k);

#endif

File diff suppressed because it is too large.


@@ -9,6 +9,7 @@ setup(
        CUDAExtension(
            name="marlin_kernels",
            sources=[
                "marlin_kernels/fp8_marlin.cu",
                "marlin_kernels/gptq_marlin.cu",
                "marlin_kernels/gptq_marlin_repack.cu",
                "marlin_kernels/marlin_cuda_kernel.cu",


@@ -1,4 +1,23 @@
from enum import Enum, auto

import torch

from text_generation_server.utils.import_utils import SYSTEM


def get_fp8_linear() -> torch.nn.Module:
    """
    Return an FP8 linear `Module` that is compatible with the current system.
    """

    if SYSTEM == "cuda":
        major, minor = torch.cuda.get_device_capability()
        if major == 8 and minor < 9:
            from text_generation_server.layers.marlin import GPTQMarlinFP8Linear

            return GPTQMarlinFP8Linear

    # On other systems let Torch decide if the hardware supports FP8.
    return Fp8Linear


def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
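
The body of fp8_quantize is outside this hunk. As a rough sketch (an assumption about what a per-tensor FP8 quantizer of this shape typically does, not necessarily the exact code in fp8.py), it computes one scale from the tensor's absolute maximum and the finite range of float8_e4m3fn, then scales, clamps, and casts:

import torch

def fp8_quantize_sketch(weight: torch.Tensor, qdtype=torch.float8_e4m3fn):
    # Hypothetical per-tensor quantization; illustrative only.
    finfo = torch.finfo(qdtype)
    # Pick one scale so the largest magnitude maps onto the FP8 range (±448 for e4m3fn).
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    # Return the quantized weight and the dequantization scale (reciprocal).
    return qweight, scale.float().reciprocal()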


@@ -106,9 +106,9 @@ def get_linear(weight, bias, quantize):
                "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
            )
    elif quantize == "fp8":
-        from text_generation_server.layers.fp8 import Fp8Linear
+        from text_generation_server.layers.fp8 import get_fp8_linear

-        linear = Fp8Linear(weight, bias)
+        linear = get_fp8_linear()(weight, bias)
    elif quantize == "bitsandbytes":
        try:
            from text_generation_server.layers.bnb import (


@@ -1,11 +1,13 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

-from text_generation_server.utils.weights import Weights, WeightsLoader
import torch
import torch.nn as nn
+from loguru import logger

+from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.log import log_once
+from text_generation_server.utils.weights import Weights, WeightsLoader

try:
    import marlin_kernels

@@ -455,6 +457,115 @@ class GPTQMarlin24Linear(nn.Module):

        return C

class GPTQMarlinFP8Linear(nn.Module):
    """
    FP8 GPTQ-Marlin linear layer.
    """

    def __init__(
        self,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
    ) -> None:
        super().__init__()

        _check_marlin_kernels()
        assert marlin_kernels is not None

        log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")

        qweight, scale = fp8_quantize(weight)
        scale = scale.to(torch.float16)
        qweight, scales = repack_fp8_for_marlin(qweight, scale)

        in_features = qweight.shape[0] * MARLIN_TILE_SIZE
        out_features = scales.shape[1]
        _check_valid_shape(in_features=in_features, out_features=out_features)

        self.qweight = qweight
        self.scales = scales
        self.bias = bias if bias is not None else None

        self.workspace = torch.zeros(
            out_features // 64 * 16, dtype=torch.int, device=qweight.device
        )

    def forward(self, A: torch.Tensor) -> torch.Tensor:
        assert marlin_kernels is not None

        A_flat = A.view(-1, A.shape[-1])
        C = marlin_kernels.fp8_marlin_gemm(
            A_flat,
            self.qweight,
            self.scales,
            self.workspace,
            8,
            A_flat.shape[0],
            self.scales.shape[1],
            A_flat.shape[1],
        )
        C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))

        if self.bias is not None:
            C += self.bias

        return C
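
A small usage sketch for the layer above. The sizes are illustrative; it assumes a CUDA GPU in the >=8.0, <8.9 capability range and the marlin_kernels extension from this commit installed.

import torch
from text_generation_server.layers.marlin import GPTQMarlinFP8Linear

# An unquantized fp16 weight as it would come from a checkpoint.
weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
linear = GPTQMarlinFP8Linear(weight, bias=None)

x = torch.randn(2, 4096, dtype=torch.float16, device="cuda")
y = linear(x)  # FP8 weight-only GEMM through fp8_marlin_gemm
assert y.shape == (2, 4096)
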
def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
    """
    Repack FP8 weights to GPTQ format (packed int32 elements).
    """
    assert fp8_tensor.dtype == torch.float8_e4m3fn

    if fp8_tensor.shape[0] % 4 != 0:
        raise ValueError(
            f"Leading tensor dimension is not divisible by 4: {fp8_tensor.shape[0]}"
        )

    # Reshape to prepare for packing
    reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])

    # Convert fp8 to uint8 (byte) representation
    byte_tensor = reshaped.view(torch.uint8)

    # Pack 4 uint8 values into one int32
    packed = torch.zeros(
        fp8_tensor.shape[0] // 4,
        fp8_tensor.shape[1],
        dtype=torch.int32,
        device=fp8_tensor.device,
    )

    for i in range(4):
        packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8)

    return packed
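
A quick sanity check of the byte layout produced above: each int32 holds four consecutive rows of the FP8 matrix, with row 4*r in the least significant byte. The unpack helper below is hypothetical, written only to verify the round trip.

import torch

def unpack_int32_to_fp8_bytes(packed: torch.Tensor) -> torch.Tensor:
    # Hypothetical inverse of pack_fp8_as_int32, for verification only.
    rows = [((packed >> (8 * i)) & 0xFF).to(torch.uint8) for i in range(4)]
    # Interleave the four byte planes back into consecutive rows.
    return torch.stack(rows, dim=1).reshape(packed.shape[0] * 4, packed.shape[1])

fp8 = torch.randn(8, 16).to(torch.float8_e4m3fn)
packed = pack_fp8_as_int32(fp8)  # shape [2, 16], dtype int32
assert torch.equal(unpack_int32_to_fp8_bytes(packed), fp8.view(torch.uint8))
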
def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor):
    """
    Repack FP8 tensor for GPTQ-Marlin.
    """

    out_features, in_features = weight.shape

    # Torch linear layer weights have shape [out_features, in_features],
    # GPTQ-quantized weights use [in_features / pack_factor, out_features],
    # so transpose before packing.
    qweight = pack_fp8_as_int32(weight.t())

    perm = torch.empty(0, dtype=torch.int, device=qweight.device)
    repacked = marlin_kernels.gptq_marlin_repack(
        qweight, perm, in_features, out_features, 8
    )

    scales = scale.reshape(1, 1).repeat(1, out_features)
    scales = permute_scales(scales)

    return repacked, scales
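
Two details worth noting about the helper above (as inferred from the code, not from separate documentation): passing an empty perm tensor tells gptq_marlin_repack that no activation-order permutation is needed, and the single per-tensor FP8 scale is broadcast into the [1, out_features] layout Marlin expects, i.e. one quantization group spanning all of in_features. A tiny shape check of that broadcast:

import torch

out_features = 8
scale = torch.tensor(0.02, dtype=torch.float16)       # per-tensor FP8 scale
scales = scale.reshape(1, 1).repeat(1, out_features)  # -> shape [1, out_features]
assert scales.shape == (1, out_features)
# A single scale row means one group covering the whole in_features dimension;
# permute_scales then reorders it into Marlin's tile layout.
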
@dataclass
class MarlinWeight:
    """