From 704a58c807a84e685542b5d9d7117317bcb8b39c Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Wed, 16 Oct 2024 13:24:50 +0530 Subject: [PATCH] Fp8 e4m3_fnuz support for rocm (#2588) * (feat) fp8 fnuz support for rocm * (review comments) Fix compression_config load, type hints * (bug) update all has_tensor * (review_comments) fix typo and added comments * (nit) improved comment --- server/text_generation_server/layers/fp8.py | 208 +++++++++++++++--- .../layers/gptq/__init__.py | 4 +- .../layers/marlin/fp8.py | 9 +- .../layers/marlin/gptq.py | 4 +- .../text_generation_server/models/__init__.py | 21 +- .../text_generation_server/utils/weights.py | 2 +- 6 files changed, 211 insertions(+), 37 deletions(-) diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index 61dd5115..c626ddb8 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -1,7 +1,7 @@ import torch from dataclasses import dataclass -from typing import Optional, Union, List +from typing import Optional, Tuple, Union, List from loguru import logger from text_generation_server.utils.import_utils import SYSTEM @@ -51,8 +51,32 @@ def get_fp8_linear() -> torch.nn.Module: return Fp8Linear +def normalize_e4m3fn_to_e4m3fnuz( + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + assert weight.dtype == torch.float8_e4m3fn + # The bits pattern 10000000(-128) represents zero in e4m3fn + # but NaN in e4m3fnuz. So here we set it to 0. + # https://onnx.ai/onnx/technical/float8.html + weight_as_int8 = weight.view(torch.int8) + ROCM_FP8_NAN_AS_INT = -128 + weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0 + weight = weight_as_int8.view(torch.float8_e4m3fnuz) + + # For the same bits representation, e4m3fnuz value is half of + # the e4m3fn value, so we should double the scaling factor to + # get the same dequantized value. 
+ # https://onnx.ai/onnx/technical/float8.html + weight_scale = weight_scale * 2.0 + if input_scale is not None: + input_scale = input_scale * 2.0 + return weight, weight_scale, input_scale + + def fp8_quantize( - weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False + weight, scale=None, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False ): if FBGEMM_DYN_AVAILABLE and not scalar: qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row( @@ -62,8 +86,11 @@ def fp8_quantize( # weight, scale = quant_weights(weight, torch.int8, False) finfo = torch.finfo(qdtype) - # Calculate the scale as dtype max divided by absmax - scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound) + + if scale is None: + # Calculate the scale as dtype max divided by absmax + scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound) + # scale and clamp the tensor to bring it to # the representative range of float8 data type # (as default cast is unsaturated) @@ -72,6 +99,10 @@ def fp8_quantize( # as both required as inputs to torch._scaled_mm qweight = qweight.to(qdtype) scale = scale.float().reciprocal() + + if SYSTEM == "rocm": + qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale) + return qweight, scale @@ -92,9 +123,17 @@ class HybridFP8UnquantLoader(WeightsLoader): .reshape(-1) .expand(w.shape[0]) ) + + input_scale = None + if weights.has_tensor(f"{prefix}.input_scale"): + input_scale = weights.get_tensor( + f"{prefix}.input_scale", to_dtype=False + ).reshape(-1) + return Fp8Weight( weight=w, weight_scale=scale, + input_scale=input_scale, activation_scale_ub=self.activation_scale_ub, dtype=weights.dtype, ) @@ -125,9 +164,24 @@ class HybridFP8UnquantLoader(WeightsLoader): ) scale = scale.reshape(-1).expand(w.shape[0]) + input_scale = None + if weights.has_tensor(f"{prefix}.input_scale"): + input_scale = weights.get_tensor( + f"{prefix}.input_scale", to_dtype=False + ) + if input_scale.numel() > 1: + input_scale = weights.get_packed_sharded( + f"{prefix}.input_scale", + dim=0, + block_sizes=block_sizes, + to_dtype=False, + ) + input_scale = input_scale.reshape(-1).max() + return Fp8Weight( weight=w, weight_scale=scale, + input_scale=input_scale, activation_scale_ub=self.activation_scale_ub, dtype=weights.dtype, ) @@ -154,9 +208,22 @@ class HybridFP8UnquantLoader(WeightsLoader): ] scale = torch.cat(scale, dim=0).reshape(-1) + input_scale = [ + _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape) + for p, shape in zip(prefixes, shapes) + if weights.has_tensor(f"{p}.input_scale") + ] + assert len(input_scale) == 0 or len(input_scale) == len(prefixes) + input_scale = ( + torch.cat(input_scale, dim=0).reshape(-1).max() + if len(input_scale) != 0 + else None + ) + return Fp8Weight( weight=w, weight_scale=scale, + input_scale=input_scale, activation_scale_ub=self.activation_scale_ub, dtype=weights.dtype, ) @@ -174,9 +241,16 @@ class HybridFP8UnquantLoader(WeightsLoader): .reshape(-1) .expand(w.shape[0]) ) + input_scale = None + if weights.has_tensor(f"{prefix}.input_scale"): + input_scale = weights.get_tensor( + f"{prefix}.input_scale", to_dtype=False + ).reshape(-1) + return Fp8Weight( weight=w, weight_scale=scale, + input_scale=input_scale, activation_scale_ub=self.activation_scale_ub, dtype=weights.dtype, ) @@ -191,6 +265,7 @@ class Fp8Weight(Weight): weight: torch.Tensor dtype: torch.dtype weight_scale: Optional[torch.Tensor] = None + input_scale: Optional[torch.Tensor] = None activation_scale_ub: Optional[float] = None 
def get_linear(self, bias: torch.Tensor): @@ -200,56 +275,99 @@ class Fp8Weight(Weight): # memory. Can be non-contiguous when we e.g. expand from scalars. self.weight_scale = self.weight_scale.contiguous() return get_fp8_linear().from_fp8( - self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype + weight=self.weight, + scale=self.weight_scale, + dtype=self.dtype, + bias=bias, + input_scale=self.input_scale, + scale_upper_bound=self.activation_scale_ub, ) class Fp8Linear(torch.nn.Module): + _device_identity_cache = {} + def __init__( self, - qweight, - scale, - scale_upper_bound, - bias, - dtype, + qweight: torch.Tensor, + scale: torch.Tensor, + dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, + input_scale: Optional[torch.Tensor] = None, + scale_upper_bound: Optional[float] = None, ) -> None: super().__init__() if FBGEMM_MM_AVAILABLE: log_once(logger.info, "Using FBGEMM fp8 optimized kernels") + if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn: + qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=qweight, weight_scale=scale + ) self.dtype = dtype self.qweight = qweight - self.scale = scale - self.scale_upper_bound = ( - torch.tensor( - [scale_upper_bound], dtype=torch.float32, device=qweight.device - ) - if scale_upper_bound is not None - else None + self.scale = scale.float() + self.input_scale = ( + input_scale.float().reciprocal() if input_scale is not None else None ) + if FBGEMM_MM_AVAILABLE: + self.scale_upper_bound = ( + torch.tensor( + [scale_upper_bound], dtype=torch.float32, device=qweight.device + ) + if scale_upper_bound is not None + else None + ) + else: + self.scale_upper_bound = scale_upper_bound + self.bias = bias if bias is not None else None @classmethod def from_unquant(cls, weight, bias, dtype): qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE) return cls( - qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype + qweight=qweight, + scale=scale, + dtype=dtype, + bias=bias, + input_scale=None, + scale_upper_bound=None, ) @classmethod - def from_fp8(cls, weight, scale, input_scale, bias, dtype): + def from_fp8( + cls, + weight: torch.Tensor, + scale: torch.Tensor, + dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, + **kwargs, + ) -> "Fp8Linear": + input_scale = kwargs.get("input_scale", None) + scale_upper_bound = kwargs.get("scale_upper_bound", None) + if FBGEMM_DYN_AVAILABLE: # fbgemm needs float32 scales. scale = scale.float() return cls( qweight=weight, scale=scale, - scale_upper_bound=input_scale, + input_scale=input_scale, + scale_upper_bound=scale_upper_bound, bias=bias, dtype=dtype, ) + @classmethod + def get_shared_device_identity(cls, device): + # Input scaling factors are no longer optional in _scaled_mm starting + # from pytorch 2.5. 
Allocating a dummy tensor to pass as input_scale + if device not in cls._device_identity_cache: + cls._device_identity_cache[device] = torch.ones(1, device=device) + return cls._device_identity_cache[device] + def forward(self, input: torch.Tensor) -> torch.Tensor: if FBGEMM_MM_AVAILABLE: qinput, scale = fp8_quantize( @@ -266,15 +384,49 @@ class Fp8Linear(torch.nn.Module): ) return y.to(self.dtype) - qinput, scale = fp8_quantize(input, scalar=True) - output, _ = torch._scaled_mm( - qinput, - self.qweight.t(), - out_dtype=self.dtype, - scale_a=scale, - scale_b=self.scale, - bias=self.bias, + qinput, scale = fp8_quantize( + input, + self.input_scale, + scale_upper_bound=self.scale_upper_bound, + scalar=True, ) + + per_tensor_weights = self.scale.numel() == 1 + per_tensor_activations = scale.numel() == 1 + + if SYSTEM != "rocm" or (per_tensor_weights and per_tensor_activations): + output = torch._scaled_mm( + qinput, + self.qweight.t(), + out_dtype=self.dtype, + scale_a=scale, + scale_b=self.scale, + bias=self.bias, + ) + + if isinstance(output, tuple) and len(output) == 2: + output = output[0] + else: + device_identity = None + if SYSTEM == "rocm": + device_identity = self.get_shared_device_identity(self.qweight.device) + + output = torch._scaled_mm( + qinput, + self.qweight.t(), + scale_a=device_identity, + scale_b=device_identity, + out_dtype=torch.float32, + ) + if isinstance(output, tuple) and len(output) == 2: + output = output[0] + + output = output * scale * self.scale.t() + if self.bias is not None: + output = output + self.bias + + output = output.to(dtype=self.dtype) + return output diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index 505caa59..1fd183fa 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -392,7 +392,7 @@ class GPTQWeightsLoader(WeightsLoader): ) def _get_gptq_params(self, weights: Weights): - if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"): + if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"): self.bits = weights.get_tensor("gptq_bits").item() self.groupsize = weights.get_tensor("gptq_groupsize").item() self.desc_act = False @@ -400,7 +400,7 @@ class GPTQWeightsLoader(WeightsLoader): # before the `gptq_sym` setting tensor was added. 
self.sym = ( weights.get_tensor("gptq_sym").item() - if weights._has_tensor("gptq_sym") + if weights.has_tensor("gptq_sym") else False ) self.quant_method = "gptq" diff --git a/server/text_generation_server/layers/marlin/fp8.py b/server/text_generation_server/layers/marlin/fp8.py index fe55a58a..49f5c480 100644 --- a/server/text_generation_server/layers/marlin/fp8.py +++ b/server/text_generation_server/layers/marlin/fp8.py @@ -62,7 +62,14 @@ class GPTQMarlinFP8Linear(nn.Module): return cls(qweight=qweight, scales=scales.to(dtype), bias=bias) @classmethod - def from_fp8(cls, weight, scale, _input_scale, bias, dtype): + def from_fp8( + cls, + weight: torch.Tensor, + scale: torch.Tensor, + bias: torch.Tensor, + dtype: torch.dtype, + **kwargs, + ): return cls(qweight=weight, scales=scale.to(dtype), bias=bias) def forward(self, A: torch.Tensor) -> torch.Tensor: diff --git a/server/text_generation_server/layers/marlin/gptq.py b/server/text_generation_server/layers/marlin/gptq.py index 7245431f..47341c0f 100644 --- a/server/text_generation_server/layers/marlin/gptq.py +++ b/server/text_generation_server/layers/marlin/gptq.py @@ -231,7 +231,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader): ) def _get_gptq_params(self, weights: Weights): - if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"): + if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"): self.bits = weights.get_tensor("gptq_bits").item() self.groupsize = weights.get_tensor("gptq_groupsize").item() self.desc_act = False @@ -239,7 +239,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader): # before the `gptq_sym` setting tensor was added. self.sym = ( weights.get_tensor("gptq_sym").item() - if weights._has_tensor("gptq_sym") + if weights.has_tensor("gptq_sym") else False ) self.quant_method = "gptq" diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 17eed976..019617d2 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -357,17 +357,32 @@ def get_model( compression_config = config_dict.get("compression_config", None) if quantization_config is not None and quantize is None: method = quantization_config.get("quant_method", None) + config_groups = quantization_config.get("config_groups", None) if method in {"gptq", "awq", "exl2"}: log_master(logger.info, f"Auto selecting quantization method {method}") quantize = method - elif method == "fbgemm_fp8": + elif method == "fbgemm_fp8" or method == "fp8": log_master(logger.info, "Auto selecting quantization method fp8") quantize = "fp8" + elif config_groups is not None: + # TODO: at some point we should probably fully parse the compression + # configuration to know which parameters are compressed. + for _, group in config_groups.items(): + weights_config = group.get("weights") + if weights_config is not None: + if ( + weights_config["type"] == "float" + and weights_config["num_bits"] == 8 + ): + log_master( + logger.info, "Auto selecting quantization method fp8" + ) + quantize = "fp8" + break else: log_master(logger.warning, f"Unknown quantization method {method}") elif compression_config is not None: - # TODO: at some point we should probably fully parse the compression - # configuration to know which parameters are compressed. + # `compression_config` renamed to `quantization_config`; support retained for backward compatibility. 
config_groups = compression_config.get("config_groups") if config_groups is not None: for _, group in config_groups.items(): diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 75e01f7c..548591e5 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -197,7 +197,7 @@ class Weights: slice_ = f.get_slice(tensor_name) return slice_ - def _has_tensor(self, tensor_name: str): + def has_tensor(self, tensor_name: str): try: self.get_filename(tensor_name) except Exception:
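
Reviewer note (not part of the patch): the e4m3fn -> e4m3fnuz normalization added in fp8.py relies on two properties of the ROCm fp8 format, which the standalone sketch below checks. It assumes a PyTorch build that exposes both float8 dtypes and allows 1-byte dtype views; tensor values are illustrative.

import torch

# The same bit pattern decodes to half the value in e4m3fnuz (exponent bias 8
# instead of 7), which is why normalize_e4m3fn_to_e4m3fnuz doubles the scale.
bits = torch.tensor([0x38], dtype=torch.uint8)       # 0x38 encodes 1.0 in e4m3fn
print(bits.view(torch.float8_e4m3fn).float())        # tensor([1.])
print(bits.view(torch.float8_e4m3fnuz).float())      # tensor([0.5])

# The pattern 0b10000000 is -0.0 in e4m3fn but NaN in e4m3fnuz, hence the
# remapping of -128 (viewed as int8) to 0 before reinterpreting the bits.
neg_zero = torch.tensor([-128], dtype=torch.int8)
print(neg_zero.view(torch.float8_e4m3fn).float())    # tensor([-0.])
print(neg_zero.view(torch.float8_e4m3fnuz).float())  # tensor([nan])

# End-to-end: dequantized values are preserved when the scale is doubled.
w = torch.randn(4, 8).to(torch.float8_e4m3fn)
scale = torch.tensor(0.05)
w_int8 = w.view(torch.int8)
w_int8[w_int8 == -128] = 0                           # clear e4m3fnuz NaN patterns
w_fnuz = w_int8.view(torch.float8_e4m3fnuz)
torch.testing.assert_close(w.float() * scale, w_fnuz.float() * (scale * 2.0))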
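
The new ROCm branch in Fp8Linear.forward (taken when either scale is per-channel) runs torch._scaled_mm with identity scales and applies the real scales to the fp32 output afterwards. A small sketch, using made-up shapes and float64 stand-ins for the fp8 tensors, of why `output * scale * self.scale.t()` recovers the fused result:

import torch

M, K, N = 4, 16, 8
qinput = torch.randn(M, K, dtype=torch.float64)    # stand-in for fp8 activations
qweight = torch.randn(N, K, dtype=torch.float64)   # stand-in for fp8 weights
scale_a = torch.rand(M, 1, dtype=torch.float64)    # per-row activation scales
scale_b = torch.rand(1, N, dtype=torch.float64)    # per-output-channel weight scales

# Fused: scales folded into the operands, as a scaled matmul would do natively.
fused = (qinput * scale_a) @ (qweight.t() * scale_b)

# Fallback: plain matmul with identity scales, then rescale the output,
# mirroring `output = output * scale * self.scale.t()` in the patch.
fallback = (qinput @ qweight.t()) * scale_a * scale_b

torch.testing.assert_close(fused, fallback)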
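
The config_groups handling added to get_model mirrors the existing compression_config branch. A hypothetical, trimmed quantization_config (field names beyond "type" and "num_bits" are illustrative, compressed-tensors style) showing what makes the new branch auto-select fp8:

quantization_config = {
    "quant_method": "compressed-tensors",   # not one of the directly handled methods
    "config_groups": {
        "group_0": {
            # Only "type" and "num_bits" are inspected by the new branch.
            "weights": {"type": "float", "num_bits": 8},
            "input_activations": {"type": "float", "num_bits": 8},
        },
    },
}

quantize = None
config_groups = quantization_config.get("config_groups", None)
if config_groups is not None:
    for _, group in config_groups.items():
        weights_config = group.get("weights")
        if weights_config is not None:
            if weights_config["type"] == "float" and weights_config["num_bits"] == 8:
                quantize = "fp8"   # same auto-selection the patch performs
                break

assert quantize == "fp8"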