Add support for FP8 on compute capability >=8.0, <8.9 (#2213)
Use FP8 GPTQ-Marlin kernels to enable FP8 support on CUDA GPUs with compute capability >=8.0 and <8.9.

Co-authored-by: Florian Zimmermeister <flozi00.fz@gmail.com>
commit cb150eb295
parent 8511669cb2
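For context, a minimal usage sketch (not part of this commit) of how the new path is picked up at runtime. It assumes the marlin_kernels extension built from this change is installed and that the GPU has compute capability 8.0-8.8; the tensor shapes and device are illustrative assumptions, while get_fp8_linear and GPTQMarlinFP8Linear come from the diff below.

import torch

from text_generation_server.layers.fp8 import get_fp8_linear

# [out_features, in_features] half-precision weight; FP8 quantization and the
# Marlin repack happen inside the layer constructor.
weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

# On compute capability 8.0-8.8 this returns GPTQMarlinFP8Linear; on other
# hardware it falls back to the existing Fp8Linear.
linear = get_fp8_linear()(weight, None)  # no bias in this sketch

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
y = linear(x)  # dispatches to marlin_kernels.fp8_marlin_gemm with num_bits=8
print(y.shape)  # torch.Size([8, 4096])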
@@ -59,3 +59,18 @@ def marlin_gemm(
     Matrix multiplication using Marlin kernels.
     """
     ...
+
+# fp8 marlin
+def fp8_marlin_gemm(
+    a: torch.Tensor,
+    b_q_weight: torch.Tensor,
+    b_scales: torch.Tensor,
+    workspace: torch.Tensor,
+    num_bits: int,
+    size_m: int,
+    size_n: int,
+    size_k: int,
+) -> torch.Tensor:
+    return torch.ops._C.fp8_marlin_gemm(
+        a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
+    )
@@ -9,4 +9,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("gptq_marlin_repack", &gptq_marlin_repack,
         "Repack GPTQ parameters for Marlin");
   m.def("marlin_gemm", &marlin_gemm, "Marlin gemm");
+  // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
+  m.def("fp8_marlin_gemm", &fp8_marlin_gemm);
 }
@@ -27,4 +27,9 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                           torch::Tensor &b_scales, torch::Tensor &workspace,
                           int64_t size_m, int64_t size_n, int64_t size_k);
 
+torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
+                              torch::Tensor& b_scales, torch::Tensor& workspace,
+                              int64_t num_bits, int64_t size_m, int64_t size_n,
+                              int64_t size_k);
+
 #endif
File diff suppressed because it is too large
@@ -9,6 +9,7 @@ setup(
     CUDAExtension(
         name="marlin_kernels",
         sources=[
+            "marlin_kernels/fp8_marlin.cu",
            "marlin_kernels/gptq_marlin.cu",
            "marlin_kernels/gptq_marlin_repack.cu",
            "marlin_kernels/marlin_cuda_kernel.cu",
@@ -1,4 +1,23 @@
+from enum import Enum, auto
+
 import torch
+from text_generation_server.utils.import_utils import SYSTEM
+
+
+def get_fp8_linear() -> torch.nn.Module:
+    """
+    Return an FP8 linear `Module` that is compatible with the current system.
+    """
+
+    if SYSTEM == "cuda":
+        major, minor = torch.cuda.get_device_capability()
+        if major == 8 and minor < 9:
+            from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
+
+            return GPTQMarlinFP8Linear
+
+    # On other systems let Torch decide if the hardware supports FP8.
+    return Fp8Linear
 
 
 def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
@@ -106,9 +106,9 @@ def get_linear(weight, bias, quantize):
                 "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
             )
     elif quantize == "fp8":
-        from text_generation_server.layers.fp8 import Fp8Linear
+        from text_generation_server.layers.fp8 import get_fp8_linear
 
-        linear = Fp8Linear(weight, bias)
+        linear = get_fp8_linear()(weight, bias)
     elif quantize == "bitsandbytes":
         try:
             from text_generation_server.layers.bnb import (
@@ -1,11 +1,13 @@
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
-from text_generation_server.utils.weights import Weights, WeightsLoader
-
 import torch
 import torch.nn as nn
+from loguru import logger
+from text_generation_server.layers.fp8 import fp8_quantize
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.log import log_once
+from text_generation_server.utils.weights import Weights, WeightsLoader
 
 try:
     import marlin_kernels
@@ -455,6 +457,115 @@ class GPTQMarlin24Linear(nn.Module):
         return C
 
 
+class GPTQMarlinFP8Linear(nn.Module):
+    """
+    FP8 GPTQ-Marlin linear layer.
+    """
+
+    def __init__(
+        self,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor],
+    ) -> None:
+        super().__init__()
+
+        _check_marlin_kernels()
+        assert marlin_kernels is not None
+
+        log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
+
+        qweight, scale = fp8_quantize(weight)
+        scale = scale.to(torch.float16)
+        qweight, scales = repack_fp8_for_marlin(qweight, scale)
+
+        in_features = qweight.shape[0] * MARLIN_TILE_SIZE
+        out_features = scales.shape[1]
+        _check_valid_shape(in_features=in_features, out_features=out_features)
+
+        self.qweight = qweight
+        self.scales = scales
+        self.bias = bias if bias is not None else None
+
+        self.workspace = torch.zeros(
+            out_features // 64 * 16, dtype=torch.int, device=qweight.device
+        )
+
+    def forward(self, A: torch.Tensor) -> torch.Tensor:
+        assert marlin_kernels is not None
+
+        A_flat = A.view(-1, A.shape[-1])
+        C = marlin_kernels.fp8_marlin_gemm(
+            A_flat,
+            self.qweight,
+            self.scales,
+            self.workspace,
+            8,
+            A_flat.shape[0],
+            self.scales.shape[1],
+            A_flat.shape[1],
+        )
+        C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
+
+        if self.bias is not None:
+            C += self.bias
+
+        return C
+
+
+def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
+    """
+    Repack FP8 weights to GPTQ format (packed int32 elements).
+    """
+    assert fp8_tensor.dtype == torch.float8_e4m3fn
+
+    if fp8_tensor.shape[0] % 4 != 0:
+        raise ValueError(
+            f"Leading tensor dimension is not divisible by 4: {fp8_tensor.shape[0]}"
+        )
+
+    # Reshape to prepare for packing
+    reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
+
+    # Convert fp8 to uint8 (byte) representation
+    byte_tensor = reshaped.view(torch.uint8)
+
+    # Pack 4 uint8 values into one int32
+    packed = torch.zeros(
+        fp8_tensor.shape[0] // 4,
+        fp8_tensor.shape[1],
+        dtype=torch.int32,
+        device=fp8_tensor.device,
+    )
+
+    for i in range(4):
+        packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8)
+
+    return packed
+
+
+def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor):
+    """
+    Repack an FP8 tensor for GPTQ-Marlin.
+    """
+
+    out_features, in_features = weight.shape
+
+    # Torch linear layer weights have shape [out_features, in_features], while
+    # GPTQ-quantized weights use [in_features/pack_factor, out_features],
+    # so transpose before packing.
+    qweight = pack_fp8_as_int32(weight.t())
+
+    perm = torch.empty(0, dtype=torch.int, device=qweight.device)
+    repacked = marlin_kernels.gptq_marlin_repack(
+        qweight, perm, in_features, out_features, 8
+    )
+
+    scales = scale.reshape(1, 1).repeat(1, out_features)
+    scales = permute_scales(scales)
+
+    return repacked, scales
+
+
 @dataclass
 class MarlinWeight:
     """
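A note on the packed layout used by pack_fp8_as_int32 above: each int32 holds the FP8 bytes of four consecutive rows of the transposed weight, with row 4*r + i stored in bits [8*i, 8*i + 8) of packed[r]. The sketch below shows the inverse transform for illustration only; it is not part of the commit, and unpack_int32_as_fp8 is a hypothetical helper name.

import torch

def unpack_int32_as_fp8(packed: torch.Tensor) -> torch.Tensor:
    """Recover an [N, M] float8_e4m3fn tensor from the [N // 4, M] int32
    tensor produced by pack_fp8_as_int32 (illustration only)."""
    rows = []
    for i in range(4):
        # Row 4*r + i was stored in bits [8*i, 8*i + 8) of packed[r].
        rows.append(((packed >> (i * 8)) & 0xFF).to(torch.uint8))
    stacked = torch.stack(rows, dim=1)  # [N // 4, 4, M]
    return stacked.reshape(-1, packed.shape[1]).view(torch.float8_e4m3fn)

# Round trip with the packing function from the diff (compare byte patterns,
# since elementwise equality on FP8 tensors is limited):
# t = torch.randn(8, 16).to(torch.float8_e4m3fn)
# assert torch.equal(
#     unpack_int32_as_fp8(pack_fp8_as_int32(t)).view(torch.uint8), t.view(torch.uint8)
# )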