From 8ee9823d3b7ebd98168ad91f9e05686aa4f81bcf Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Mon, 30 Sep 2024 11:43:45 +0000 Subject: [PATCH] (feat) fp8 fnuz support for rocm --- server/text_generation_server/layers/fp8.py | 179 +++++++++++++++--- .../layers/marlin/fp8.py | 2 +- .../text_generation_server/models/__init__.py | 2 +- 3 files changed, 159 insertions(+), 24 deletions(-) diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index 61dd5115..8aabc990 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -1,7 +1,7 @@ import torch from dataclasses import dataclass -from typing import Optional, Union, List +from typing import Optional, Tuple, Union, List from loguru import logger from text_generation_server.utils.import_utils import SYSTEM @@ -51,8 +51,32 @@ def get_fp8_linear() -> torch.nn.Module: return Fp8Linear +def normalize_e4m3fn_to_e4m3fnuz( + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + assert weight.dtype == torch.float8_e4m3fn + # The bits pattern 10000000(-128) represents zero in e4m3fn + # but NaN in e4m3fnuz. So here we set it to 0. + # https://onnx.ai/onnx/technical/float8.html + weight_as_int8 = weight.view(torch.int8) + ROCM_FP8_NAN_AS_INT = -128 + weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0 + weight = weight_as_int8.view(torch.float8_e4m3fnuz) + + # For the same bits representation, e4m3fnuz value is half of + # the e4m3fn value, so we should double the scaling factor to + # get the same dequantized value. + # https://onnx.ai/onnx/technical/float8.html + weight_scale = weight_scale * 2.0 + if input_scale is not None: + input_scale = input_scale * 2.0 + return weight, weight_scale, input_scale + + def fp8_quantize( - weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False + weight, scale=None, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False ): if FBGEMM_DYN_AVAILABLE and not scalar: qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row( @@ -62,8 +86,11 @@ def fp8_quantize( # weight, scale = quant_weights(weight, torch.int8, False) finfo = torch.finfo(qdtype) - # Calculate the scale as dtype max divided by absmax - scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound) + + if scale is None: + # Calculate the scale as dtype max divided by absmax + scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound) + # scale and clamp the tensor to bring it to # the representative range of float8 data type # (as default cast is unsaturated) @@ -72,6 +99,10 @@ def fp8_quantize( # as both required as inputs to torch._scaled_mm qweight = qweight.to(qdtype) scale = scale.float().reciprocal() + + if SYSTEM == "rocm": + qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale) + return qweight, scale @@ -92,9 +123,17 @@ class HybridFP8UnquantLoader(WeightsLoader): .reshape(-1) .expand(w.shape[0]) ) + try: + input_scale = weights.get_tensor( + f"{prefix}.input_scale", to_dtype=False + ).reshape(-1) + except Exception: + input_scale = None + return Fp8Weight( weight=w, weight_scale=scale, + input_scale=input_scale, activation_scale_ub=self.activation_scale_ub, dtype=weights.dtype, ) @@ -124,10 +163,25 @@ class HybridFP8UnquantLoader(WeightsLoader): to_dtype=False, ) scale = scale.reshape(-1).expand(w.shape[0]) + try: + input_scale = weights.get_tensor( + 
f"{prefix}.input_scale", to_dtype=False + ) + if input_scale.numel() > 1: + input_scale = weights.get_packed_sharded( + f"{prefix}.input_scale", + dim=0, + block_sizes=block_sizes, + to_dtype=False, + ) + input_scale = input_scale.reshape(-1).max() + except Exception: + input_scale = None return Fp8Weight( weight=w, weight_scale=scale, + input_scale=input_scale, activation_scale_ub=self.activation_scale_ub, dtype=weights.dtype, ) @@ -153,10 +207,19 @@ class HybridFP8UnquantLoader(WeightsLoader): for p, shape in zip(prefixes, shapes) ] scale = torch.cat(scale, dim=0).reshape(-1) + try: + input_scale = [ + _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape) + for p, shape in zip(prefixes, shapes) + ] + input_scale = torch.cat(input_scale, dim=0).reshape(-1).max() + except Exception: + input_scale = None return Fp8Weight( weight=w, weight_scale=scale, + input_scale=input_scale, activation_scale_ub=self.activation_scale_ub, dtype=weights.dtype, ) @@ -174,9 +237,17 @@ class HybridFP8UnquantLoader(WeightsLoader): .reshape(-1) .expand(w.shape[0]) ) + try: + input_scale = weights.get_tensor( + f"{prefix}.input_scale", to_dtype=False + ).reshape(-1) + except Exception: + input_scale = None + return Fp8Weight( weight=w, weight_scale=scale, + input_scale=input_scale, activation_scale_ub=self.activation_scale_ub, dtype=weights.dtype, ) @@ -191,6 +262,7 @@ class Fp8Weight(Weight): weight: torch.Tensor dtype: torch.dtype weight_scale: Optional[torch.Tensor] = None + input_scale: Optional[torch.Tensor] = None activation_scale_ub: Optional[float] = None def get_linear(self, bias: torch.Tensor): @@ -200,15 +272,23 @@ class Fp8Weight(Weight): # memory. Can be non-contiguous when we e.g. expand from scalars. self.weight_scale = self.weight_scale.contiguous() return get_fp8_linear().from_fp8( - self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype + self.weight, + self.weight_scale, + self.input_scale, + self.activation_scale_ub, + bias, + self.dtype, ) class Fp8Linear(torch.nn.Module): + _device_identity_cache = {} + def __init__( self, qweight, scale, + input_scale, scale_upper_bound, bias, dtype, @@ -217,17 +297,29 @@ class Fp8Linear(torch.nn.Module): if FBGEMM_MM_AVAILABLE: log_once(logger.info, "Using FBGEMM fp8 optimized kernels") + if SYSTEM == "rocm": + qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=qweight, weight_scale=scale + ) + self.dtype = dtype self.qweight = qweight - self.scale = scale - self.scale_upper_bound = ( - torch.tensor( - [scale_upper_bound], dtype=torch.float32, device=qweight.device - ) - if scale_upper_bound is not None - else None + self.scale = scale.float() + self.input_scale = ( + input_scale.float().reciprocal() if input_scale is not None else None ) + if FBGEMM_MM_AVAILABLE: + self.scale_upper_bound = ( + torch.tensor( + [scale_upper_bound], dtype=torch.float32, device=qweight.device + ) + if scale_upper_bound is not None + else None + ) + else: + self.scale_upper_bound = scale_upper_bound + self.bias = bias if bias is not None else None @classmethod @@ -238,18 +330,27 @@ class Fp8Linear(torch.nn.Module): ) @classmethod - def from_fp8(cls, weight, scale, input_scale, bias, dtype): + def from_fp8(cls, weight, scale, input_scale, scale_upper_bound, bias, dtype): if FBGEMM_DYN_AVAILABLE: # fbgemm needs float32 scales. 
scale = scale.float() return cls( qweight=weight, scale=scale, - scale_upper_bound=input_scale, + input_scale=input_scale, + scale_upper_bound=scale_upper_bound, bias=bias, dtype=dtype, ) + @classmethod + def get_shared_device_identity(cls, device): + # Input scaling factors are no longer optional in _scaled_mm starting + # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale + if device not in cls._device_identity_cache: + cls._device_identity_cache[device] = torch.ones(1, device=device) + return cls._device_identity_cache[device] + def forward(self, input: torch.Tensor) -> torch.Tensor: if FBGEMM_MM_AVAILABLE: qinput, scale = fp8_quantize( @@ -266,15 +367,49 @@ class Fp8Linear(torch.nn.Module): ) return y.to(self.dtype) - qinput, scale = fp8_quantize(input, scalar=True) - output, _ = torch._scaled_mm( - qinput, - self.qweight.t(), - out_dtype=self.dtype, - scale_a=scale, - scale_b=self.scale, - bias=self.bias, + qinput, scale = fp8_quantize( + input, + self.input_scale, + scale_upper_bound=self.scale_upper_bound, + scalar=True, ) + + per_tensor_weights = self.scale.numel() == 1 + per_tensor_activations = scale.numel() == 1 + + if SYSTEM != "rocm" or (per_tensor_weights and per_tensor_activations): + output = torch._scaled_mm( + qinput, + self.qweight.t(), + out_dtype=self.dtype, + scale_a=scale, + scale_b=self.scale, + bias=self.bias, + ) + + if type(output) is tuple and len(output) == 2: + output = output[0] + else: + device_identity = None + if SYSTEM == "rocm": + device_identity = self.get_shared_device_identity(self.qweight.device) + + output = torch._scaled_mm( + qinput, + self.qweight.t(), + scale_a=device_identity, + scale_b=device_identity, + out_dtype=torch.float32, + ) + if type(output) is tuple and len(output) == 2: + output = output[0] + + output = output * scale * self.scale.t() + if self.bias is not None: + output = output + self.bias + + output = output.to(dtype=self.dtype) + return output diff --git a/server/text_generation_server/layers/marlin/fp8.py b/server/text_generation_server/layers/marlin/fp8.py index fe55a58a..dac109cf 100644 --- a/server/text_generation_server/layers/marlin/fp8.py +++ b/server/text_generation_server/layers/marlin/fp8.py @@ -62,7 +62,7 @@ class GPTQMarlinFP8Linear(nn.Module): return cls(qweight=qweight, scales=scales.to(dtype), bias=bias) @classmethod - def from_fp8(cls, weight, scale, _input_scale, bias, dtype): + def from_fp8(cls, weight, scale, _input_scale, _scale_upper_bound, bias, dtype): return cls(qweight=weight, scales=scale.to(dtype), bias=bias) def forward(self, A: torch.Tensor) -> torch.Tensor: diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index e5e5aabb..6c537394 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -340,7 +340,7 @@ def get_model( if method in {"gptq", "awq", "exl2"}: log_master(logger.info, f"Auto selecting quantization method {method}") quantize = method - elif method == "fbgemm_fp8": + elif method == "fbgemm_fp8" or method == "fp8": log_master(logger.info, "Auto selecting quantization method fp8") quantize = "fp8" else:
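
Note on normalize_e4m3fn_to_e4m3fnuz, for reviewers unfamiliar with the fnuz format: the bit pattern 0x80 is negative zero in e4m3fn but NaN in e4m3fnuz, and a given bit pattern decodes to half the value in fnuz, which is why the helper remaps that pattern and doubles the scales. A minimal standalone sketch of the same idea, assuming a PyTorch build that exposes both torch.float8_e4m3fn and torch.float8_e4m3fnuz (the round-trip check at the end is illustrative only, not part of the patch):

import torch

def normalize_e4m3fn_to_e4m3fnuz(weight, weight_scale, input_scale=None):
    # 0x80 (-128 as int8) is negative zero in e4m3fn but NaN in e4m3fnuz,
    # so remap that bit pattern to +0.0 before reinterpreting the storage.
    as_int8 = weight.view(torch.int8)
    as_int8[as_int8 == -128] = 0
    weight = as_int8.view(torch.float8_e4m3fnuz)
    # For identical bits, the e4m3fnuz value is half the e4m3fn value,
    # so both scales are doubled to keep dequantized values unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale

# Round-trip check: dequantized weights match before and after the conversion.
w = torch.randn(4, 8)
scale = w.abs().max() / torch.finfo(torch.float8_e4m3fn).max
w_fn = (w / scale).to(torch.float8_e4m3fn)
w_fnuz, scale_fnuz, _ = normalize_e4m3fn_to_e4m3fnuz(w_fn.clone(), scale.clone())
assert torch.allclose(w_fn.to(torch.float32) * scale,
                      w_fnuz.to(torch.float32) * scale_fnuz)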
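
The fp8_quantize change adds an optional precomputed scale so that a static input_scale loaded from the checkpoint (whose reciprocal is taken in Fp8Linear.__init__) can bypass the dynamic absmax computation. For reference, the convention used throughout the file, shown as a stripped-down per-tensor sketch rather than the patch's exact function: the scale argument is the quantization multiplier, and the returned scale is its reciprocal, i.e. the dequantization factor that torch._scaled_mm expects.

import torch

def fp8_quantize_sketch(weight, scale=None, qdtype=torch.float8_e4m3fn):
    finfo = torch.finfo(qdtype)
    if scale is None:
        # Dynamic scale: dtype max divided by absmax, clamped to avoid division by zero.
        scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    # Saturate before casting, since the default float8 cast is unsaturated.
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    # Hand back the dequantization scale (the reciprocal of the multiplier).
    return qweight, scale.float().reciprocal()

w = torch.randn(16, 32)
qw, dq_scale = fp8_quantize_sketch(w)
print((w - qw.to(torch.float32) * dq_scale).abs().max())  # small quantization error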
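
The non-FBGEMM forward path now distinguishes per-tensor from per-channel scales: on ROCm, when either scale has more than one element, the matmul runs through torch._scaled_mm with a shared all-ones identity scale and the real scales are applied to the fp32 accumulator afterwards. The algebra behind that rewrite, as a plain float32 reference with illustrative names rather than the patch's API:

import torch

def rowwise_scaled_mm_reference(q_a, scale_a, q_b, scale_b, bias=None,
                                out_dtype=torch.float32):
    # (q_a * scale_a) @ (q_b * scale_b).t() == (q_a @ q_b.t()) * scale_a * scale_b.t(),
    # so the matmul itself can run with unit scales and the row-wise scales
    # can be folded into the fp32 result afterwards.
    acc = q_a.float() @ q_b.float().t()
    out = acc * scale_a * scale_b.t()  # (M, 1) and (1, N) broadcast over (M, N)
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)

# q_a: (M, K) activations with per-row scales, q_b: (N, K) weights with per-row scales.
q_a, scale_a = torch.randn(2, 8), torch.rand(2, 1)
q_b, scale_b = torch.randn(3, 8), torch.rand(3, 1)
expected = (q_a * scale_a) @ (q_b * scale_b).t()
assert torch.allclose(rowwise_scaled_mm_reference(q_a, scale_a, q_b, scale_b),
                      expected, atol=1e-5)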