FP8 e4m3fnuz support for ROCm (#2588)

* (feat) FP8 fnuz support for ROCm

* (review comments) fix compression_config load, type hints

* (bug) update all has_tensor call sites

* (review comments) fix typo and add comments

* (nit) improve comment
Mohit Sharma authored on 2024-10-16 13:24:50 +05:30, committed by GitHub
commit 704a58c807 (parent ffe05ccd05)
6 changed files with 211 additions and 37 deletions

View File

@@ -1,7 +1,7 @@
 import torch
 from dataclasses import dataclass
-from typing import Optional, Union, List
+from typing import Optional, Tuple, Union, List
 from loguru import logger
 from text_generation_server.utils.import_utils import SYSTEM
@@ -51,8 +51,32 @@ def get_fp8_linear() -> torch.nn.Module:
     return Fp8Linear


+def normalize_e4m3fn_to_e4m3fnuz(
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    assert weight.dtype == torch.float8_e4m3fn
+    # The bits pattern 10000000(-128) represents zero in e4m3fn
+    # but NaN in e4m3fnuz. So here we set it to 0.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_as_int8 = weight.view(torch.int8)
+    ROCM_FP8_NAN_AS_INT = -128
+    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
+    weight = weight_as_int8.view(torch.float8_e4m3fnuz)
+
+    # For the same bits representation, e4m3fnuz value is half of
+    # the e4m3fn value, so we should double the scaling factor to
+    # get the same dequantized value.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_scale = weight_scale * 2.0
+    if input_scale is not None:
+        input_scale = input_scale * 2.0
+    return weight, weight_scale, input_scale
+
+
 def fp8_quantize(
-    weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
+    weight, scale=None, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
 ):
     if FBGEMM_DYN_AVAILABLE and not scalar:
         qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
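Why the conversion above works (a worked example, not part of the diff): e4m3fn uses an exponent bias of 7 while ROCm's e4m3fnuz uses a bias of 8, so the same byte decodes to half the value under fnuz, and the byte 0x80 that encodes negative zero in e4m3fn is the NaN encoding in fnuz. A minimal sketch, assuming a PyTorch build that exposes both FP8 dtypes:

import torch

# The byte 0b0'0111'000 (0x38) decodes as 1.0 in e4m3fn (bias 7) but as
# 0.5 in e4m3fnuz (bias 8); doubling weight_scale compensates for the halving.
b = torch.tensor([0x38], dtype=torch.uint8)
print(b.view(torch.float8_e4m3fn).to(torch.float32).item())    # 1.0
print(b.view(torch.float8_e4m3fnuz).to(torch.float32).item())  # 0.5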
@@ -62,8 +86,11 @@ def fp8_quantize(
     # weight, scale = quant_weights(weight, torch.int8, False)
     finfo = torch.finfo(qdtype)
-    # Calculate the scale as dtype max divided by absmax
-    scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
+
+    if scale is None:
+        # Calculate the scale as dtype max divided by absmax
+        scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
+
     # scale and clamp the tensor to bring it to
     # the representative range of float8 data type
     # (as default cast is unsaturated)
@@ -72,6 +99,10 @@ def fp8_quantize(
     # as both required as inputs to torch._scaled_mm
     qweight = qweight.to(qdtype)
     scale = scale.float().reciprocal()
+
+    if SYSTEM == "rocm":
+        qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)
+
     return qweight, scale
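For context on the dynamic path above, a rough standalone sketch (not the library function itself): the scale is dtype-max over absmax, the tensor is clamped and cast, and the reciprocal is stored so that dequantization, and torch._scaled_mm, only need a multiply.

import torch

def toy_fp8_quantize(w, qdtype=torch.float8_e4m3fn):
    # dtype max / absmax, saturating cast, then store the reciprocal
    finfo = torch.finfo(qdtype)
    scale = finfo.max / w.abs().max().clamp(min=1e-12)
    qw = (w * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    return qw, scale.float().reciprocal()

w = torch.randn(4, 8)
qw, scale = toy_fp8_quantize(w)
print((w - qw.to(torch.float32) * scale).abs().max())  # small quantization error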
@@ -92,9 +123,17 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 .reshape(-1)
                 .expand(w.shape[0])
             )
+
+            input_scale = None
+            if weights.has_tensor(f"{prefix}.input_scale"):
+                input_scale = weights.get_tensor(
+                    f"{prefix}.input_scale", to_dtype=False
+                ).reshape(-1)
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
+                input_scale=input_scale,
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
@@ -125,9 +164,24 @@ class HybridFP8UnquantLoader(WeightsLoader):
             )
             scale = scale.reshape(-1).expand(w.shape[0])
+
+            input_scale = None
+            if weights.has_tensor(f"{prefix}.input_scale"):
+                input_scale = weights.get_tensor(
+                    f"{prefix}.input_scale", to_dtype=False
+                )
+                if input_scale.numel() > 1:
+                    input_scale = weights.get_packed_sharded(
+                        f"{prefix}.input_scale",
+                        dim=0,
+                        block_sizes=block_sizes,
+                        to_dtype=False,
+                    )
+                input_scale = input_scale.reshape(-1).max()
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
+                input_scale=input_scale,
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
@@ -154,9 +208,22 @@ class HybridFP8UnquantLoader(WeightsLoader):
             ]
             scale = torch.cat(scale, dim=0).reshape(-1)
+
+            input_scale = [
+                _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
+                for p, shape in zip(prefixes, shapes)
+                if weights.has_tensor(f"{p}.input_scale")
+            ]
+            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
+            input_scale = (
+                torch.cat(input_scale, dim=0).reshape(-1).max()
+                if len(input_scale) != 0
+                else None
+            )
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
+                input_scale=input_scale,
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
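A note on the max() reduction above (illustrative values, not from the diff): input_scale is a per-tensor activation scale, so when a fused layer is assembled from several shards the scalar scales are reduced with max(), which keeps every shard's activations representable at the cost of slightly coarser quantization for the others.

import torch

shard_scales = [torch.tensor([0.011]), torch.tensor([0.014]), torch.tensor([0.009])]
input_scale = torch.cat(shard_scales, dim=0).reshape(-1).max()  # tensor(0.0140)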
@@ -174,9 +241,16 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 .reshape(-1)
                 .expand(w.shape[0])
             )
+
+            input_scale = None
+            if weights.has_tensor(f"{prefix}.input_scale"):
+                input_scale = weights.get_tensor(
+                    f"{prefix}.input_scale", to_dtype=False
+                ).reshape(-1)
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
+                input_scale=input_scale,
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
@@ -191,6 +265,7 @@ class Fp8Weight(Weight):
     weight: torch.Tensor
     dtype: torch.dtype
     weight_scale: Optional[torch.Tensor] = None
+    input_scale: Optional[torch.Tensor] = None
     activation_scale_ub: Optional[float] = None

     def get_linear(self, bias: torch.Tensor):
@@ -200,26 +275,43 @@ class Fp8Weight(Weight):
             # memory. Can be non-contiguous when we e.g. expand from scalars.
             self.weight_scale = self.weight_scale.contiguous()
         return get_fp8_linear().from_fp8(
-            self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
+            weight=self.weight,
+            scale=self.weight_scale,
+            dtype=self.dtype,
+            bias=bias,
+            input_scale=self.input_scale,
+            scale_upper_bound=self.activation_scale_ub,
         )


 class Fp8Linear(torch.nn.Module):
+    _device_identity_cache = {}
+
     def __init__(
         self,
-        qweight,
-        scale,
-        scale_upper_bound,
-        bias,
-        dtype,
+        qweight: torch.Tensor,
+        scale: torch.Tensor,
+        dtype: torch.dtype,
+        bias: Optional[torch.Tensor] = None,
+        input_scale: Optional[torch.Tensor] = None,
+        scale_upper_bound: Optional[float] = None,
     ) -> None:
         super().__init__()
         if FBGEMM_MM_AVAILABLE:
             log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
+        if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
+            qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                weight=qweight, weight_scale=scale
+            )
+
         self.dtype = dtype
         self.qweight = qweight
-        self.scale = scale
+        self.scale = scale.float()
+        self.input_scale = (
+            input_scale.float().reciprocal() if input_scale is not None else None
+        )
+
+        if FBGEMM_MM_AVAILABLE:
             self.scale_upper_bound = (
                 torch.tensor(
                     [scale_upper_bound], dtype=torch.float32, device=qweight.device
@@ -227,6 +319,8 @@ class Fp8Linear(torch.nn.Module):
                 if scale_upper_bound is not None
                 else None
             )
+        else:
+            self.scale_upper_bound = scale_upper_bound

         self.bias = bias if bias is not None else None
@@ -234,22 +328,46 @@ class Fp8Linear(torch.nn.Module):
     def from_unquant(cls, weight, bias, dtype):
         qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
         return cls(
-            qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
+            qweight=qweight,
+            scale=scale,
+            dtype=dtype,
+            bias=bias,
+            input_scale=None,
+            scale_upper_bound=None,
         )

     @classmethod
-    def from_fp8(cls, weight, scale, input_scale, bias, dtype):
+    def from_fp8(
+        cls,
+        weight: torch.Tensor,
+        scale: torch.Tensor,
+        dtype: torch.dtype,
+        bias: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> "Fp8Linear":
+        input_scale = kwargs.get("input_scale", None)
+        scale_upper_bound = kwargs.get("scale_upper_bound", None)
+
         if FBGEMM_DYN_AVAILABLE:
             # fbgemm needs float32 scales.
             scale = scale.float()
         return cls(
             qweight=weight,
             scale=scale,
-            scale_upper_bound=input_scale,
+            input_scale=input_scale,
+            scale_upper_bound=scale_upper_bound,
             bias=bias,
             dtype=dtype,
         )

+    @classmethod
+    def get_shared_device_identity(cls, device):
+        # Input scaling factors are no longer optional in _scaled_mm starting
+        # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
+        if device not in cls._device_identity_cache:
+            cls._device_identity_cache[device] = torch.ones(1, device=device)
+        return cls._device_identity_cache[device]
+
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if FBGEMM_MM_AVAILABLE:
             qinput, scale = fp8_quantize(
@@ -266,8 +384,18 @@ class Fp8Linear(torch.nn.Module):
             )
             return y.to(self.dtype)

-        qinput, scale = fp8_quantize(input, scalar=True)
-        output, _ = torch._scaled_mm(
+        qinput, scale = fp8_quantize(
+            input,
+            self.input_scale,
+            scale_upper_bound=self.scale_upper_bound,
+            scalar=True,
+        )
+
+        per_tensor_weights = self.scale.numel() == 1
+        per_tensor_activations = scale.numel() == 1
+
+        if SYSTEM != "rocm" or (per_tensor_weights and per_tensor_activations):
+            output = torch._scaled_mm(
                 qinput,
                 self.qweight.t(),
                 out_dtype=self.dtype,
@@ -275,6 +403,30 @@ class Fp8Linear(torch.nn.Module):
                 scale_b=self.scale,
                 bias=self.bias,
             )
+            if isinstance(output, tuple) and len(output) == 2:
+                output = output[0]
+        else:
+            device_identity = None
+            if SYSTEM == "rocm":
+                device_identity = self.get_shared_device_identity(self.qweight.device)
+
+            output = torch._scaled_mm(
+                qinput,
+                self.qweight.t(),
+                scale_a=device_identity,
+                scale_b=device_identity,
+                out_dtype=torch.float32,
+            )
+            if isinstance(output, tuple) and len(output) == 2:
+                output = output[0]
+
+            output = output * scale * self.scale.t()
+            if self.bias is not None:
+                output = output + self.bias
+
+            output = output.to(dtype=self.dtype)
+
         return output
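The else branch above is the ROCm fallback for non per-tensor scales: the FP8 matmul runs with a shared identity scale in float32 and the real activation and per-channel weight scales are applied afterwards (the tuple check covers older PyTorch versions where _scaled_mm returned an (output, amax) pair). A small sketch of the algebra, using plain fp32 matmuls instead of _scaled_mm so it runs anywhere:

import torch

M, K, N = 4, 16, 8
deq_x = torch.randn(M, K)            # stands in for qinput * activation scale
deq_w = torch.randn(N, K)            # stands in for qweight * per-channel scale
act_scale = torch.tensor(0.5)        # per-tensor activation scale
w_scale = torch.rand(N, 1) + 0.5     # per-output-channel weight scales

qx = deq_x / act_scale               # "quantized" activations (kept in fp32 here)
qw = deq_w / w_scale                 # "quantized" weights

fused = deq_x @ deq_w.t()                            # scales folded into the matmul
deferred = (qx @ qw.t()) * act_scale * w_scale.t()   # scales applied after, as above
print(torch.allclose(fused, deferred, rtol=1e-4, atol=1e-5))  # True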

View File

@@ -392,7 +392,7 @@ class GPTQWeightsLoader(WeightsLoader):
         )

     def _get_gptq_params(self, weights: Weights):
-        if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
+        if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"):
             self.bits = weights.get_tensor("gptq_bits").item()
             self.groupsize = weights.get_tensor("gptq_groupsize").item()
             self.desc_act = False
@@ -400,7 +400,7 @@ class GPTQWeightsLoader(WeightsLoader):
             # before the `gptq_sym` setting tensor was added.
             self.sym = (
                 weights.get_tensor("gptq_sym").item()
-                if weights._has_tensor("gptq_sym")
+                if weights.has_tensor("gptq_sym")
                 else False
             )
             self.quant_method = "gptq"

View File

@@ -62,7 +62,14 @@ class GPTQMarlinFP8Linear(nn.Module):
         return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)

     @classmethod
-    def from_fp8(cls, weight, scale, _input_scale, bias, dtype):
+    def from_fp8(
+        cls,
+        weight: torch.Tensor,
+        scale: torch.Tensor,
+        bias: torch.Tensor,
+        dtype: torch.dtype,
+        **kwargs,
+    ):
         return cls(qweight=weight, scales=scale.to(dtype), bias=bias)

     def forward(self, A: torch.Tensor) -> torch.Tensor:

View File

@@ -231,7 +231,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
         )

     def _get_gptq_params(self, weights: Weights):
-        if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
+        if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"):
             self.bits = weights.get_tensor("gptq_bits").item()
             self.groupsize = weights.get_tensor("gptq_groupsize").item()
             self.desc_act = False
@@ -239,7 +239,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
             # before the `gptq_sym` setting tensor was added.
             self.sym = (
                 weights.get_tensor("gptq_sym").item()
-                if weights._has_tensor("gptq_sym")
+                if weights.has_tensor("gptq_sym")
                 else False
             )
             self.quant_method = "gptq"

View File

@@ -357,17 +357,32 @@ def get_model(
     compression_config = config_dict.get("compression_config", None)
     if quantization_config is not None and quantize is None:
         method = quantization_config.get("quant_method", None)
+        config_groups = quantization_config.get("config_groups", None)
         if method in {"gptq", "awq", "exl2"}:
             log_master(logger.info, f"Auto selecting quantization method {method}")
             quantize = method
-        elif method == "fbgemm_fp8":
+        elif method == "fbgemm_fp8" or method == "fp8":
             log_master(logger.info, "Auto selecting quantization method fp8")
             quantize = "fp8"
+        elif config_groups is not None:
+            # TODO: at some point we should probably fully parse the compression
+            # configuration to know which parameters are compressed.
+            for _, group in config_groups.items():
+                weights_config = group.get("weights")
+                if weights_config is not None:
+                    if (
+                        weights_config["type"] == "float"
+                        and weights_config["num_bits"] == 8
+                    ):
+                        log_master(
+                            logger.info, "Auto selecting quantization method fp8"
+                        )
+                        quantize = "fp8"
+                        break
         else:
             log_master(logger.warning, f"Unknown quantization method {method}")
     elif compression_config is not None:
-        # TODO: at some point we should probably fully parse the compression
-        # configuration to know which parameters are compressed.
+        # `compression_config` renamed to `quantization_config`; support retained
+        # for backward compatibility.
         config_groups = compression_config.get("config_groups")
         if config_groups is not None:
             for _, group in config_groups.items():
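For reference, a hand-written config_dict fragment (illustrative only, not taken from any real checkpoint; the quant_method value is an assumption) that the new config_groups branch would map to quantize = "fp8": the method is not one of the explicitly handled names, but the weights group describes an 8-bit float type.

config_dict = {
    "quantization_config": {
        "quant_method": "compressed-tensors",   # assumed value, not matched explicitly
        "config_groups": {
            "group_0": {
                "weights": {"type": "float", "num_bits": 8},
            },
        },
    },
}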

View File

@@ -197,7 +197,7 @@ class Weights:
         slice_ = f.get_slice(tensor_name)
         return slice_

-    def _has_tensor(self, tensor_name: str):
+    def has_tensor(self, tensor_name: str):
         try:
             self.get_filename(tensor_name)
         except Exception: