diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index c626ddb8..18a40afa 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -76,9 +76,19 @@ def normalize_e4m3fn_to_e4m3fnuz(
 
 
 def fp8_quantize(
-    weight, scale=None, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
+    weight: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+    scale_upper_bound: Optional[torch.Tensor] = None,
+    qdtype: torch.dtype = torch.float8_e4m3fn,
+    scalar: bool = False,
 ):
-    if FBGEMM_DYN_AVAILABLE and not scalar:
+    """
+    This function returns a reciprocal of the scale, so that a tensor can be unscaled
+    by multiplying it with the returned scale. If a scale is given through the `scale`
+    argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
+    be used without modification).
+    """
+    if FBGEMM_DYN_AVAILABLE and not scalar and not scale:
         qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
             weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
         )
@@ -90,15 +100,18 @@ def fp8_quantize(
     if scale is None:
         # Calculate the scale as dtype max divided by absmax
         scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
+        # scale and clamp the tensor to bring it to
+        # the representative range of float8 data type
+        # (as default cast is unsaturated)
+        qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
+        scale = scale.float().reciprocal()
+    else:
+        # Use reciprocal to avoid more expensive division.
+        qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
 
-    # scale and clamp the tensor to bring it to
-    # the representative range of float8 data type
-    # (as default cast is unsaturated)
-    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
     # Return both float8 data and the inverse scale (as float),
     # as both required as inputs to torch._scaled_mm
     qweight = qweight.to(qdtype)
-    scale = scale.float().reciprocal()
 
     if SYSTEM == "rocm":
         qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)
@@ -307,9 +320,7 @@ class Fp8Linear(torch.nn.Module):
         self.dtype = dtype
         self.qweight = qweight
         self.scale = scale.float()
-        self.input_scale = (
-            input_scale.float().reciprocal() if input_scale is not None else None
-        )
+        self.input_scale = input_scale.float() if input_scale is not None else None
 
         if FBGEMM_MM_AVAILABLE:
             self.scale_upper_bound = (
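For reference, a minimal sketch of the reciprocal-scale convention the new docstring describes (the helper name fp8_quantize_sketch is hypothetical, not the patched function, and a PyTorch build with torch.float8_e4m3fn support is assumed): the returned scale is the inverse of the factor applied during quantization, so dequantization (and torch._scaled_mm, per the comment in the hunk) only needs a multiplication.

    import torch


    # Hypothetical standalone illustration; the real fp8_quantize in this patch also
    # handles scale_upper_bound, an FBGEMM fast path, and a user-supplied scale.
    def fp8_quantize_sketch(weight: torch.Tensor, qdtype: torch.dtype = torch.float8_e4m3fn):
        finfo = torch.finfo(qdtype)
        # Scale as dtype max divided by absmax (the 1e-12 clamp mirrors the hunk above).
        scale = finfo.max / weight.abs().max().clamp(min=1e-12)
        # Scale and clamp to the representable FP8 range before the (unsaturated) cast.
        qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
        # Return the reciprocal so callers unscale by multiplying, matching the patch.
        return qweight, scale.float().reciprocal()


    weight = torch.randn(4, 8, dtype=torch.float16)
    qweight, inv_scale = fp8_quantize_sketch(weight)
    # Dequantize with a single multiplication by the returned (reciprocal) scale.
    dequant = qweight.to(torch.float16) * inv_scale.to(torch.float16)
    print((weight - dequant).abs().max())  # small FP8 rounding error

Keeping every stored scale in reciprocal form is also why the Fp8Linear change above drops the extra .reciprocal() on input_scale: scales loaded from an FP8 checkpoint can then be used without modification.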