449 lines
15 KiB
Python
449 lines
15 KiB
Python
from dataclasses import dataclass
|
|
import os
|
|
from typing import Optional, Tuple, Type, Union, List
|
|
|
|
import torch
|
|
from loguru import logger
|
|
|
|
from text_generation_server.utils.import_utils import SYSTEM
|
|
from text_generation_server.utils.weights import (
|
|
Weight,
|
|
WeightsLoader,
|
|
UnquantizedWeight,
|
|
Weights,
|
|
)
|
|
from text_generation_server.utils.log import log_once
|
|
|
|
try:
|
|
import marlin_kernels
|
|
except ImportError:
|
|
marlin_kernels = None
|
|
|
|
|
|
# Whether the cutlass W8A8 FP8 GEMM from marlin_kernels can be used. This is
# only probed on CUDA with marlin_kernels installed; the compute capability is
# encoded as major * 10 + minor (e.g. 89 for capability 8.9).
if marlin_kernels is None or SYSTEM != "cuda":
    CUTLASS_FP8_AVAILABLE = False
else:
    major, minor = torch.cuda.get_device_capability()
    CUTLASS_FP8_AVAILABLE = marlin_kernels.cutlass_scaled_mm_supports_fp8(
        major * 10 + minor
    )
|
|
|
|
|
|
def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
    """
    Return an FP8 linear `Module` class that is compatible with the current system.

    On CUDA, the FP8-Marlin (W8A16) layer is chosen for compute capability
    8.x, or 9.x when ``force_w8a16`` is set, unless the environment variable
    ``USE_CUTLASS_W8A8`` equals ``"1"``. Every other case returns ``Fp8Linear``.

    Args:
        force_w8a16: prefer the W8A16 Marlin kernel on capability 9.x devices.

    Returns:
        The module class (not an instance) used to build FP8 linear layers.
    """
    if SYSTEM != "cuda":
        # On other systems let Torch decide if the hardware supports FP8.
        return Fp8Linear

    major, _ = torch.cuda.get_device_capability()

    # Marlin is W8A16; it is preferred when:
    # - capability 8.x with x < 8: W8A8 FP8 GEMM is not supported;
    # - capability 8.9: W8A8 FP8 GEMM is supported, but FP8-Marlin gives
    #   better decoding throughput (e.g. on L4 and L40);
    # - capability 9.x with force_w8a16: cutlass kernels do not support W8A16.
    prefers_marlin = major == 8 or (major == 9 and force_w8a16)
    cutlass_forced = os.getenv("USE_CUTLASS_W8A8", "0") == "1"

    if prefers_marlin and not cutlass_forced:
        from text_generation_server.layers.marlin import GPTQMarlinFP8Linear

        return GPTQMarlinFP8Linear

    return Fp8Linear
|
|
|
|
|
|
def normalize_e4m3fn_to_e4m3fnuz(
|
|
weight: torch.Tensor,
|
|
weight_scale: torch.Tensor,
|
|
input_scale: Optional[torch.Tensor] = None,
|
|
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
|
assert weight.dtype == torch.float8_e4m3fn
|
|
# The bits pattern 10000000(-128) represents zero in e4m3fn
|
|
# but NaN in e4m3fnuz. So here we set it to 0.
|
|
# https://onnx.ai/onnx/technical/float8.html
|
|
weight_as_int8 = weight.view(torch.int8)
|
|
ROCM_FP8_NAN_AS_INT = -128
|
|
weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
|
|
weight = weight_as_int8.view(torch.float8_e4m3fnuz)
|
|
|
|
# For the same bits representation, e4m3fnuz value is half of
|
|
# the e4m3fn value, so we should double the scaling factor to
|
|
# get the same dequantized value.
|
|
# https://onnx.ai/onnx/technical/float8.html
|
|
weight_scale = weight_scale * 2.0
|
|
if input_scale is not None:
|
|
input_scale = input_scale * 2.0
|
|
return weight, weight_scale, input_scale
|
|
|
|
|
|
def fp8_quantize(
    weight: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    scale_upper_bound: Optional[torch.Tensor] = None,
    qdtype: torch.dtype = torch.float8_e4m3fn,
    scalar: bool = False,
):
    """
    Quantize ``weight`` to the FP8 dtype ``qdtype``.

    This function returns a reciprocal of the scale, so that a tensor can be unscaled
    by multiplying it with the returned scale. If a scale is given through the `scale`
    argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
    be used without modification).

    Args:
        weight: tensor to quantize. On the marlin_kernels path it is flattened
            to 2-D over its last dimension and reshaped back afterwards.
        scale: optional precomputed (reciprocal) scale. When ``None``, a
            dynamic scale is computed.
        scale_upper_bound: upper bound on the absolute maximum used for a
            dynamic scale.
        qdtype: target FP8 dtype.
        scalar: passed to marlin_kernels as ``use_per_token_if_dynamic=not
            scalar``. The Torch fallback always computes a single per-tensor
            dynamic scale and ignores this flag.

    Returns:
        ``(qweight, scale)``: the quantized tensor and the reciprocal scale.
        On ROCm, an e4m3fn result from the Torch fallback is reinterpreted as
        e4m3fnuz, and its scale is adjusted to match.
    """
    if marlin_kernels is not None:
        shape = weight.shape
        qweight, scale = marlin_kernels.scaled_fp8_quant(
            weight.reshape(-1, shape[-1]),
            dtype=qdtype,
            scale=scale,
            scale_ub=scale_upper_bound,
            # TODO: don't do this when we have to use the Torch kernel.
            use_per_token_if_dynamic=not scalar,
        )

        return qweight.reshape(shape), scale

    finfo = torch.finfo(qdtype)

    if scale is None:
        # Calculate the scale as dtype max divided by absmax.
        scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
        # Scale and clamp the tensor to bring it into the representable range
        # of the float8 data type (the default cast is unsaturated).
        qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
        scale = scale.float().reciprocal()
    else:
        # `scale` is a reciprocal; multiply by its inverse to quantize.
        qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)

    # Return both float8 data and the inverse scale (as float),
    # as both are required as inputs to torch._scaled_mm.
    qweight = qweight.to(qdtype)

    # normalize_e4m3fn_to_e4m3fnuz only accepts e4m3fn input. Any other
    # requested dtype (e.g. e4m3fnuz itself) is already in its final form and
    # must not be re-interpreted.
    if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
        qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)

    return qweight, scale
|
|
|
|
|
|
class HybridFP8UnquantLoader(WeightsLoader):
    """Weight loader that loads FP8 and unquantized Torch tensors.

    Checkpoint tensors stored as ``float8_e4m3fn`` are wrapped in
    ``Fp8Weight`` together with their ``weight_scale`` (and ``input_scale``
    when present). Any other tensor is wrapped in a scale-less ``Fp8Weight``
    when ``to_fp8`` is set, and in ``UnquantizedWeight`` otherwise.
    """

    def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
        # Forwarded as `activation_scale_ub` to every FP8 weight with scales.
        self.activation_scale_ub = activation_scale_ub
        # When set, non-FP8 checkpoint tensors become Fp8Weight (without scales).
        self.to_fp8 = to_fp8

    def _wrap_non_fp8(self, weights: "Weights", w: torch.Tensor):
        """Wrap a tensor that is not stored as FP8 in the checkpoint."""
        if not self.to_fp8:
            return UnquantizedWeight(w)
        return Fp8Weight(weight=w, dtype=weights.dtype)

    def _wrap_fp8(self, weights: "Weights", w, weight_scale, input_scale):
        """Wrap an FP8 checkpoint tensor together with its scales."""
        return Fp8Weight(
            weight=w,
            weight_scale=weight_scale,
            input_scale=input_scale,
            activation_scale_ub=self.activation_scale_ub,
            dtype=weights.dtype,
        )

    @staticmethod
    def _load_row_weight_scale(weights: "Weights", prefix: str, rows: int):
        """Load `{prefix}.weight_scale` unsharded, flattened and expanded to `rows`."""
        flat = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False).reshape(-1)
        return flat.expand(rows)

    @staticmethod
    def _load_flat_input_scale(weights: "Weights", prefix: str):
        """Load `{prefix}.input_scale` flattened, or return None when absent."""
        name = f"{prefix}.input_scale"
        if not weights.has_tensor(name):
            return None
        return weights.get_tensor(name, to_dtype=False).reshape(-1)

    def get_weights(self, weights: "Weights", prefix: str):
        w = weights.get_tensor(f"{prefix}.weight")

        if w.dtype != torch.float8_e4m3fn:
            return self._wrap_non_fp8(weights, w)

        weight_scale = self._load_row_weight_scale(weights, prefix, w.shape[0])
        input_scale = self._load_flat_input_scale(weights, prefix)
        return self._wrap_fp8(weights, w, weight_scale, input_scale)

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        w = weights.get_packed_sharded(
            f"{prefix}.weight", dim=0, block_sizes=block_sizes
        )

        if w.dtype != torch.float8_e4m3fn:
            return self._wrap_non_fp8(weights, w)

        def load_possibly_packed(name: str):
            # A single-element scale is used as is; larger scales are sharded
            # with the same packed block layout as the weight.
            tensor = weights.get_tensor(name, to_dtype=False)
            if tensor.numel() > 1:
                tensor = weights.get_packed_sharded(
                    name,
                    dim=0,
                    block_sizes=block_sizes,
                    to_dtype=False,
                )
            return tensor

        weight_scale = (
            load_possibly_packed(f"{prefix}.weight_scale")
            .reshape(-1)
            .expand(w.shape[0])
        )

        input_scale = None
        if weights.has_tensor(f"{prefix}.input_scale"):
            # Collapse to a single activation scale by taking the maximum.
            input_scale = (
                load_possibly_packed(f"{prefix}.input_scale").reshape(-1).max()
            )

        return self._wrap_fp8(weights, w, weight_scale, input_scale)

    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
        parts = [
            weights.get_sharded(f"{p}.weight", dim=0, to_device=False)
            for p in prefixes
        ]
        part_shapes = [part.shape for part in parts]

        # Concat then send to the device
        w = torch.cat(parts, dim=dim).to(weights.device)

        if w.dtype != torch.float8_e4m3fn:
            return self._wrap_non_fp8(weights, w)

        weight_scale = torch.cat(
            [
                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
                for p, shape in zip(prefixes, part_shapes)
            ],
            dim=0,
        ).reshape(-1)

        input_scales = [
            _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
            for p, shape in zip(prefixes, part_shapes)
            if weights.has_tensor(f"{p}.input_scale")
        ]
        # Either every fused prefix has an input scale, or none of them has.
        assert len(input_scales) == 0 or len(input_scales) == len(prefixes)
        # The fused weight gets a single activation scale: the maximum of all.
        input_scale = (
            torch.cat(input_scales, dim=0).reshape(-1).max() if input_scales else None
        )

        return self._wrap_fp8(weights, w, weight_scale, input_scale)

    def get_weights_row(self, weights: "Weights", prefix: str):
        w = weights.get_sharded(f"{prefix}.weight", dim=1)

        if w.dtype != torch.float8_e4m3fn:
            return self._wrap_non_fp8(weights, w)

        # The weight is sharded along dim 1; the scale is loaded unsharded and
        # expanded to the number of rows.
        weight_scale = self._load_row_weight_scale(weights, prefix, w.shape[0])
        input_scale = self._load_flat_input_scale(weights, prefix)
        return self._wrap_fp8(weights, w, weight_scale, input_scale)
|
|
|
|
|
|
@dataclass
class Fp8Weight(Weight):
    """FP8 (or to-be-quantized) weight that builds an FP8 linear layer.

    When ``weight_scale`` is ``None``, ``weight`` is handed to the layer's
    ``from_unquant`` constructor. Otherwise it is treated as FP8 data and
    passed to ``from_fp8`` together with its scales.
    """

    # Weight tensor: FP8 data when weight_scale is set, otherwise unquantized.
    weight: torch.Tensor
    # Output dtype of the linear layer.
    dtype: torch.dtype
    # Reciprocal weight scale, or None when the weight is not yet quantized.
    weight_scale: Optional[torch.Tensor] = None
    # Optional static activation scale from the checkpoint.
    input_scale: Optional[torch.Tensor] = None
    # Optional upper bound used when computing dynamic activation scales.
    activation_scale_ub: Optional[float] = None
    # Forwarded to get_fp8_linear to choose the layer implementation.
    force_w8a16: bool = False

    def get_linear(self, bias: torch.Tensor):
        linear_cls = get_fp8_linear(force_w8a16=self.force_w8a16)

        if self.weight_scale is None:
            return linear_cls.from_unquant(self.weight, bias, self.dtype)

        # This is not checked by the fbgemm kernels, but they require contiguous
        # memory. Can be non-contiguous when we e.g. expand from scalars.
        self.weight_scale = self.weight_scale.contiguous()

        return linear_cls.from_fp8(
            weight=self.weight,
            scale=self.weight_scale,
            dtype=self.dtype,
            bias=bias,
            input_scale=self.input_scale,
            scale_upper_bound=self.activation_scale_ub,
        )
|
|
|
|
|
|
class Fp8Linear(torch.nn.Module):
    """FP8 linear layer.

    When ``CUTLASS_FP8_AVAILABLE`` is set, activations are quantized
    dynamically (non-scalar scales) and multiplied with the cutlass kernel
    from marlin_kernels. Otherwise activations are quantized with a single
    scale (or the static ``input_scale``) and ``torch._scaled_mm`` is used.
    On ROCm with non-per-tensor scales, the scales are applied manually after
    an unscaled float32 GEMM.
    """

    # Per-device cache of the dummy unit scale passed to torch._scaled_mm.
    _device_identity_cache = {}

    def __init__(
        self,
        qweight: torch.Tensor,
        scale: torch.Tensor,
        dtype: torch.dtype,
        bias: Optional[torch.Tensor] = None,
        input_scale: Optional[torch.Tensor] = None,
        scale_upper_bound: Optional[float] = None,
    ) -> None:
        """
        Args:
            qweight: FP8 weight. On ROCm, e4m3fn weights are reinterpreted
                as e4m3fnuz.
            scale: reciprocal weight scale (stored as float32).
            dtype: output dtype of the layer.
            bias: optional bias added to the output.
            input_scale: optional static activation scale (stored as float32).
                It is not used on the cutlass path.
            scale_upper_bound: optional upper bound for dynamic activation
                scales. It is converted to a float32 tensor on the weight's
                device when cutlass is used.
        """
        super().__init__()
        if CUTLASS_FP8_AVAILABLE:
            log_once(logger.info, "Using cutlass w8a8 kernels")
        if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
            qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                weight=qweight, weight_scale=scale
            )

        self.dtype = dtype
        self.qweight = qweight
        self.scale = scale.float()
        self.input_scale = input_scale.float() if input_scale is not None else None

        if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None:
            self.scale_upper_bound = torch.tensor(
                scale_upper_bound, dtype=torch.float32, device=qweight.device
            )
        else:
            self.scale_upper_bound = scale_upper_bound

        self.bias = bias

    @classmethod
    def from_unquant(cls, weight, bias, dtype):
        """Quantize an unquantized ``weight`` to FP8 and build the layer."""
        qweight, scale = fp8_quantize(weight, scalar=not CUTLASS_FP8_AVAILABLE)
        return cls(
            qweight=qweight,
            scale=scale,
            dtype=dtype,
            bias=bias,
            input_scale=None,
            scale_upper_bound=None,
        )

    @classmethod
    def from_fp8(
        cls,
        weight: torch.Tensor,
        scale: torch.Tensor,
        dtype: torch.dtype,
        bias: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> "Fp8Linear":
        """Build the layer from an already-quantized FP8 ``weight``.

        Recognized kwargs: ``input_scale`` and ``scale_upper_bound``. Both
        default to ``None``.
        """
        input_scale = kwargs.get("input_scale", None)
        scale_upper_bound = kwargs.get("scale_upper_bound", None)

        return cls(
            qweight=weight,
            scale=scale,
            input_scale=input_scale,
            scale_upper_bound=scale_upper_bound,
            bias=bias,
            dtype=dtype,
        )

    @classmethod
    def get_shared_device_identity(cls, device):
        # Input scaling factors are no longer optional in _scaled_mm starting
        # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
        if device not in cls._device_identity_cache:
            cls._device_identity_cache[device] = torch.ones(1, device=device)
        return cls._device_identity_cache[device]

    @staticmethod
    def _unwrap_scaled_mm(output):
        # Some PyTorch versions return a 2-tuple from _scaled_mm (presumably
        # (output, amax)); keep only the output tensor.
        if isinstance(output, tuple) and len(output) == 2:
            return output[0]
        return output

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if CUTLASS_FP8_AVAILABLE:
            # cutlass FP8 supports per-token scales, so get non-scalar scales.
            qinput, scale = fp8_quantize(
                input, scale_upper_bound=self.scale_upper_bound, scalar=False
            )
            return marlin_kernels.cutlass_scaled_mm(
                qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
            )

        qinput, scale = fp8_quantize(
            input,
            self.input_scale,
            scale_upper_bound=self.scale_upper_bound,
            scalar=True,
        )

        per_tensor_weights = self.scale.numel() == 1
        per_tensor_activations = scale.numel() == 1

        if SYSTEM != "rocm" or (per_tensor_weights and per_tensor_activations):
            return self._unwrap_scaled_mm(
                torch._scaled_mm(
                    qinput,
                    self.qweight.t(),
                    out_dtype=self.dtype,
                    scale_a=scale,
                    scale_b=self.scale,
                    bias=self.bias,
                )
            )

        # Only reachable on ROCm with non-per-tensor scales: run an unscaled
        # FP8 GEMM in float32 (unit scales), then apply the activation and
        # weight scales and the bias manually.
        device_identity = self.get_shared_device_identity(self.qweight.device)
        output = self._unwrap_scaled_mm(
            torch._scaled_mm(
                qinput,
                self.qweight.t(),
                scale_a=device_identity,
                scale_b=device_identity,
                out_dtype=torch.float32,
            )
        )

        output = output * scale * self.scale.t()
        if self.bias is not None:
            output = output + self.bias

        return output.to(dtype=self.dtype)
|
|
|
|
|
|
def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
    """
    Load the scale tensor named ``prefix`` and expand it over ``shape[0]`` rows.

    A single-element scale is used unsharded. A larger scale is re-read
    sharded along dim 0.

    Returns:
        A 1-D tensor of length ``shape[0]`` (an expanded, possibly
        non-contiguous view).
    """
    full_scale = weights.get_tensor(prefix, to_dtype=False)
    if full_scale.numel() <= 1:
        chosen = full_scale
    else:
        chosen = weights.get_sharded(prefix, dim=0, to_dtype=False)
    return chosen.reshape(-1).expand(shape[0])
|