Make handling of FP8 scales more consisent (#2666)

Change `fp8_quantize` so that we can pass around reciprocals everywhere,
so scales are always passed around in the checkpoint format.

I also noticed that we ignore any input scales that we might have when
fbgemm is available. Skip this path if we already have a scale.
This commit is contained in:
Daniël de Kok 2024-10-19 09:05:01 +02:00 committed by GitHub
parent 153ff3740b
commit 5e0fb46821
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 21 additions and 10 deletions

View File

@ -76,9 +76,19 @@ def normalize_e4m3fn_to_e4m3fnuz(
def fp8_quantize(
weight, scale=None, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
weight: torch.Tensor,
scale: Optional[torch.Tensor] = None,
scale_upper_bound: Optional[torch.Tensor] = None,
qdtype: torch.dtype = torch.float8_e4m3fn,
scalar: bool = False,
):
if FBGEMM_DYN_AVAILABLE and not scalar:
"""
This function returns a reciprocal of the scale, so that a tensor can be unscaled
by multiplying it with the returned scale. If a scale is given through the `scale`
argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
be used without modification).
"""
if FBGEMM_DYN_AVAILABLE and not scalar and not scale:
qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
)
@ -90,15 +100,18 @@ def fp8_quantize(
if scale is None:
# Calculate the scale as dtype max divided by absmax
scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
# scale and clamp the tensor to bring it to
# the representative range of float8 data type
# (as default cast is unsaturated)
qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
scale = scale.float().reciprocal()
else:
# Use reciprocal to avoid more expensive division.
qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
# scale and clamp the tensor to bring it to
# the representative range of float8 data type
# (as default cast is unsaturated)
qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
# Return both float8 data and the inverse scale (as float),
# as both required as inputs to torch._scaled_mm
qweight = qweight.to(qdtype)
scale = scale.float().reciprocal()
if SYSTEM == "rocm":
qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)
@ -307,9 +320,7 @@ class Fp8Linear(torch.nn.Module):
self.dtype = dtype
self.qweight = qweight
self.scale = scale.float()
self.input_scale = (
input_scale.float().reciprocal() if input_scale is not None else None
)
self.input_scale = input_scale.float() if input_scale is not None else None
if FBGEMM_MM_AVAILABLE:
self.scale_upper_bound = (