(feat) fp8 fnuz support for rocm

commit 8ee9823d3b
parent 2401fdc889
Author: Mohit Sharma
Date: 2024-09-30 11:43:45 +00:00

3 changed files with 159 additions and 24 deletions

View File

@@ -1,7 +1,7 @@
 import torch
 from dataclasses import dataclass
-from typing import Optional, Union, List
+from typing import Optional, Tuple, Union, List
 from loguru import logger
 from text_generation_server.utils.import_utils import SYSTEM
@@ -51,8 +51,32 @@ def get_fp8_linear() -> torch.nn.Module:
     return Fp8Linear
 
 
+def normalize_e4m3fn_to_e4m3fnuz(
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    assert weight.dtype == torch.float8_e4m3fn
+    # The bits pattern 10000000(-128) represents zero in e4m3fn
+    # but NaN in e4m3fnuz. So here we set it to 0.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_as_int8 = weight.view(torch.int8)
+    ROCM_FP8_NAN_AS_INT = -128
+    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
+    weight = weight_as_int8.view(torch.float8_e4m3fnuz)
+
+    # For the same bits representation, e4m3fnuz value is half of
+    # the e4m3fn value, so we should double the scaling factor to
+    # get the same dequantized value.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_scale = weight_scale * 2.0
+    if input_scale is not None:
+        input_scale = input_scale * 2.0
+    return weight, weight_scale, input_scale
+
+
 def fp8_quantize(
-    weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
+    weight, scale=None, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
 ):
     if FBGEMM_DYN_AVAILABLE and not scalar:
         qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
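For context (not part of the diff): the factor of two above comes from the exponent bias difference between the two fp8 formats, so the same bit pattern reads as half the value in e4m3fnuz. A minimal sketch of that invariant, assuming a PyTorch build whose CPU ops support both float8 dtypes:

    import torch

    # Reinterpreting e4m3fn bits as e4m3fnuz halves the stored value, so doubling
    # the dequantization scale keeps the dequantized result unchanged.
    w_fn = torch.tensor([1.0]).to(torch.float8_e4m3fn)
    w_fnuz = w_fn.view(torch.int8).view(torch.float8_e4m3fnuz)
    scale = torch.tensor(0.5)
    print(w_fn.float() * scale)            # tensor([0.5000])
    print(w_fnuz.float() * (scale * 2.0))  # tensor([0.5000]), same dequantized value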
@@ -62,8 +86,11 @@ def fp8_quantize(
     # weight, scale = quant_weights(weight, torch.int8, False)
     finfo = torch.finfo(qdtype)
-    # Calculate the scale as dtype max divided by absmax
-    scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
+
+    if scale is None:
+        # Calculate the scale as dtype max divided by absmax
+        scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
+
     # scale and clamp the tensor to bring it to
     # the representative range of float8 data type
     # (as default cast is unsaturated)
@@ -72,6 +99,10 @@ def fp8_quantize(
     # as both required as inputs to torch._scaled_mm
     qweight = qweight.to(qdtype)
     scale = scale.float().reciprocal()
+
+    if SYSTEM == "rocm":
+        qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)
+
     return qweight, scale
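A usage note, illustrative rather than part of the commit: the scale returned by fp8_quantize is the dequantization scale (the reciprocal of the multiplier applied before the cast, already doubled on ROCm), so a simple round trip recovers an approximation of the input:

    import torch

    w = torch.randn(16, 16, dtype=torch.float16)
    qw, dequant_scale = fp8_quantize(w, scalar=True)  # fp8_quantize as defined above
    w_approx = qw.float() * dequant_scale             # close to w, up to fp8 rounding error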
@@ -92,9 +123,17 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 .reshape(-1)
                 .expand(w.shape[0])
             )
+            try:
+                input_scale = weights.get_tensor(
+                    f"{prefix}.input_scale", to_dtype=False
+                ).reshape(-1)
+            except Exception:
+                input_scale = None
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
+                input_scale=input_scale,
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
@ -124,10 +163,25 @@ class HybridFP8UnquantLoader(WeightsLoader):
to_dtype=False, to_dtype=False,
) )
scale = scale.reshape(-1).expand(w.shape[0]) scale = scale.reshape(-1).expand(w.shape[0])
try:
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
)
if input_scale.numel() > 1:
input_scale = weights.get_packed_sharded(
f"{prefix}.input_scale",
dim=0,
block_sizes=block_sizes,
to_dtype=False,
)
input_scale = input_scale.reshape(-1).max()
except Exception:
input_scale = None
return Fp8Weight( return Fp8Weight(
weight=w, weight=w,
weight_scale=scale, weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub, activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype, dtype=weights.dtype,
) )
@@ -153,10 +207,19 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 for p, shape in zip(prefixes, shapes)
             ]
             scale = torch.cat(scale, dim=0).reshape(-1)
+
+            try:
+                input_scale = [
+                    _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
+                    for p, shape in zip(prefixes, shapes)
+                ]
+                input_scale = torch.cat(input_scale, dim=0).reshape(-1).max()
+            except Exception:
+                input_scale = None
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
+                input_scale=input_scale,
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
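A side note on the max() reduction above; the rationale is an assumption, not stated in the commit: when several prefixes are fused into one matrix, only one static activation scale can be kept, and keeping the largest dequantization scale avoids extra clipping for any of the fused projections. A toy illustration with made-up values:

    import torch

    # Hypothetical per-projection input scales found in a checkpoint (e.g. q/k/v):
    input_scales = [torch.tensor([0.021]), torch.tensor([0.017]), torch.tensor([0.025])]
    # Keep one scale for the fused input, as in get_multi_weights_col above.
    fused_input_scale = torch.cat(input_scales, dim=0).reshape(-1).max()  # tensor(0.0250)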
@@ -174,9 +237,17 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 .reshape(-1)
                 .expand(w.shape[0])
             )
+            try:
+                input_scale = weights.get_tensor(
+                    f"{prefix}.input_scale", to_dtype=False
+                ).reshape(-1)
+            except Exception:
+                input_scale = None
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
+                input_scale=input_scale,
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
@@ -191,6 +262,7 @@ class Fp8Weight(Weight):
     weight: torch.Tensor
     dtype: torch.dtype
     weight_scale: Optional[torch.Tensor] = None
+    input_scale: Optional[torch.Tensor] = None
     activation_scale_ub: Optional[float] = None
 
     def get_linear(self, bias: torch.Tensor):
@@ -200,15 +272,23 @@ class Fp8Weight(Weight):
         # memory. Can be non-contiguous when we e.g. expand from scalars.
         self.weight_scale = self.weight_scale.contiguous()
         return get_fp8_linear().from_fp8(
-            self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
+            self.weight,
+            self.weight_scale,
+            self.input_scale,
+            self.activation_scale_ub,
+            bias,
+            self.dtype,
         )
 
 
 class Fp8Linear(torch.nn.Module):
+    _device_identity_cache = {}
+
     def __init__(
         self,
         qweight,
         scale,
+        input_scale,
         scale_upper_bound,
         bias,
         dtype,
@@ -217,17 +297,29 @@ class Fp8Linear(torch.nn.Module):
         if FBGEMM_MM_AVAILABLE:
             log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
 
+        if SYSTEM == "rocm":
+            qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                weight=qweight, weight_scale=scale
+            )
+
         self.dtype = dtype
         self.qweight = qweight
-        self.scale = scale
-        self.scale_upper_bound = (
-            torch.tensor(
-                [scale_upper_bound], dtype=torch.float32, device=qweight.device
-            )
-            if scale_upper_bound is not None
-            else None
+        self.scale = scale.float()
+        self.input_scale = (
+            input_scale.float().reciprocal() if input_scale is not None else None
         )
 
+        if FBGEMM_MM_AVAILABLE:
+            self.scale_upper_bound = (
+                torch.tensor(
+                    [scale_upper_bound], dtype=torch.float32, device=qweight.device
+                )
+                if scale_upper_bound is not None
+                else None
+            )
+        else:
+            self.scale_upper_bound = scale_upper_bound
+
         self.bias = bias if bias is not None else None
 
     @classmethod
@@ -238,18 +330,27 @@ class Fp8Linear(torch.nn.Module):
         )
 
     @classmethod
-    def from_fp8(cls, weight, scale, input_scale, bias, dtype):
+    def from_fp8(cls, weight, scale, input_scale, scale_upper_bound, bias, dtype):
         if FBGEMM_DYN_AVAILABLE:
             # fbgemm needs float32 scales.
             scale = scale.float()
         return cls(
             qweight=weight,
             scale=scale,
-            scale_upper_bound=input_scale,
+            input_scale=input_scale,
+            scale_upper_bound=scale_upper_bound,
             bias=bias,
             dtype=dtype,
         )
 
+    @classmethod
+    def get_shared_device_identity(cls, device):
+        # Input scaling factors are no longer optional in _scaled_mm starting
+        # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
+        if device not in cls._device_identity_cache:
+            cls._device_identity_cache[device] = torch.ones(1, device=device)
+        return cls._device_identity_cache[device]
+
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if FBGEMM_MM_AVAILABLE:
             qinput, scale = fp8_quantize(
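To illustrate how the extended from_fp8 signature is meant to be called, a sketch with placeholder tensors (not code from the repository; running the layer itself needs a GPU whose torch._scaled_mm supports fp8):

    import torch

    # Placeholder tensors for illustration only; real values come from an fp8 checkpoint.
    qweight = torch.randn(128, 256, dtype=torch.float16).to(torch.float8_e4m3fn)
    weight_scale = torch.tensor(0.01)   # per-tensor dequantization scale
    input_scale = torch.tensor(0.02)    # optional static activation scale (may be None)

    linear = Fp8Linear.from_fp8(
        weight=qweight,
        scale=weight_scale,
        input_scale=input_scale,
        scale_upper_bound=None,          # activation_scale_ub; tensor-ized only on the FBGEMM path
        bias=None,
        dtype=torch.float16,
    )
    # y = linear(x)  # with x of shape (batch, 256) on the same device as the weights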
@@ -266,15 +367,49 @@ class Fp8Linear(torch.nn.Module):
             )
             return y.to(self.dtype)
 
-        qinput, scale = fp8_quantize(input, scalar=True)
-        output, _ = torch._scaled_mm(
-            qinput,
-            self.qweight.t(),
-            out_dtype=self.dtype,
-            scale_a=scale,
-            scale_b=self.scale,
-            bias=self.bias,
+        qinput, scale = fp8_quantize(
+            input,
+            self.input_scale,
+            scale_upper_bound=self.scale_upper_bound,
+            scalar=True,
         )
+
+        per_tensor_weights = self.scale.numel() == 1
+        per_tensor_activations = scale.numel() == 1
+
+        if SYSTEM != "rocm" or (per_tensor_weights and per_tensor_activations):
+            output = torch._scaled_mm(
+                qinput,
+                self.qweight.t(),
+                out_dtype=self.dtype,
+                scale_a=scale,
+                scale_b=self.scale,
+                bias=self.bias,
+            )
+
+            if type(output) is tuple and len(output) == 2:
+                output = output[0]
+        else:
+            device_identity = None
+            if SYSTEM == "rocm":
+                device_identity = self.get_shared_device_identity(
+                    self.qweight.device
+                )
+
+            output = torch._scaled_mm(
+                qinput,
+                self.qweight.t(),
+                scale_a=device_identity,
+                scale_b=device_identity,
+                out_dtype=torch.float32,
+            )
+            if type(output) is tuple and len(output) == 2:
+                output = output[0]
+
+            output = output * scale * self.scale.t()
+            if self.bias is not None:
+                output = output + self.bias
+
+            output = output.to(dtype=self.dtype)
+
         return output
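The rescaling in the ROCm branch above relies on the identity that a per-tensor activation scale and a per-output-row weight scale can be pulled out of the matmul and applied afterwards. A small float32 check of that identity, independent of the fp8 kernels:

    import torch

    A = torch.randn(4, 8)        # dequantized activations would be A * sa
    W = torch.randn(3, 8)        # dequantized weights would be W * sw, one scale per output row
    sa = torch.rand(())          # per-tensor activation scale
    sw = torch.rand(3)           # per-row weight scale

    full = (A * sa) @ (W * sw[:, None]).t()   # matmul of dequantized operands
    fast = (A @ W.t()) * sa * sw              # raw matmul, rescaled afterwards (as in forward)
    assert torch.allclose(full, fast, atol=1e-5)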

View File

@@ -62,7 +62,7 @@ class GPTQMarlinFP8Linear(nn.Module):
         return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)
 
     @classmethod
-    def from_fp8(cls, weight, scale, _input_scale, bias, dtype):
+    def from_fp8(cls, weight, scale, _input_scale, _scale_upper_bound, bias, dtype):
         return cls(qweight=weight, scales=scale.to(dtype), bias=bias)
 
     def forward(self, A: torch.Tensor) -> torch.Tensor:

View File

@@ -340,7 +340,7 @@ def get_model(
         if method in {"gptq", "awq", "exl2"}:
             log_master(logger.info, f"Auto selecting quantization method {method}")
             quantize = method
-        elif method == "fbgemm_fp8":
+        elif method == "fbgemm_fp8" or method == "fp8":
             log_master(logger.info, "Auto selecting quantization method fp8")
             quantize = "fp8"
         else: