Make handling of FP8 scales more consistent (#2666)
Change `fp8_quantize` so that reciprocals are passed around everywhere, keeping scales in the checkpoint format throughout. I also noticed that we ignore any input scale we might already have when fbgemm is available; skip this path when a scale is provided.
parent 153ff3740b
commit 5e0fb46821
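The convention this commit settles on: the scale that travels with an FP8 tensor is the reciprocal of the factor applied during quantization, which is also how FP8 checkpoints store their scales. Dequantization then becomes a single multiply, and a checkpoint scale can be forwarded without modification. A minimal sketch of the contract (illustrative code, not taken from the diff):

import torch

def sketch_quantize(weight: torch.Tensor):
    # Illustration of the fp8_quantize contract, not the actual implementation.
    finfo = torch.finfo(torch.float8_e4m3fn)
    # The scale maps the tensor's absmax onto the FP8 representable maximum.
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
    # Return the reciprocal: this is the "checkpoint format" for scales.
    return qweight.to(torch.float8_e4m3fn), scale.float().reciprocal()

qweight, scale = sketch_quantize(torch.randn(4, 8))
# Unscaling is a plain multiplication with the returned scale.
dequant = qweight.to(torch.float32) * scale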
@@ -76,9 +76,19 @@ def normalize_e4m3fn_to_e4m3fnuz(
 
 
 def fp8_quantize(
-    weight, scale=None, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
+    weight: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+    scale_upper_bound: Optional[torch.Tensor] = None,
+    qdtype: torch.dtype = torch.float8_e4m3fn,
+    scalar: bool = False,
 ):
-    if FBGEMM_DYN_AVAILABLE and not scalar:
+    """
+    This function returns a reciprocal of the scale, so that a tensor can be unscaled
+    by multiplying it with the returned scale. If a scale is given through the `scale`
+    argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
+    be used without modification).
+    """
+    if FBGEMM_DYN_AVAILABLE and not scalar and not scale:
         qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
             weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
         )
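With the added `not scale` condition, the fbgemm dynamic-quantization path now runs only when no scale was supplied, so a scale that was already loaded from an FP8 checkpoint is no longer silently discarded. Illustrative call patterns (the variable names here are hypothetical):

# No scale supplied: with fbgemm available, per-row dynamic quantization
# runs and fp8_quantize returns the freshly computed (reciprocal) scale.
qweight, scale = fp8_quantize(weight)

# Scale already loaded from an FP8 checkpoint (stored as a reciprocal):
# the fbgemm path is skipped and the supplied scale is used instead of
# being recomputed.
qweight, scale = fp8_quantize(weight, scale=checkpoint_scale)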
@@ -90,15 +100,18 @@ def fp8_quantize(
     if scale is None:
         # Calculate the scale as dtype max divided by absmax
         scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
-    # scale and clamp the tensor to bring it to
-    # the representative range of float8 data type
-    # (as default cast is unsaturated)
-    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
+        # scale and clamp the tensor to bring it to
+        # the representative range of float8 data type
+        # (as default cast is unsaturated)
+        qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
+        scale = scale.float().reciprocal()
+    else:
+        # Use reciprocal to avoid more expensive division.
+        qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
 
     # Return both float8 data and the inverse scale (as float),
     # as both required as inputs to torch._scaled_mm
     qweight = qweight.to(qdtype)
-    scale = scale.float().reciprocal()
 
     if SYSTEM == "rocm":
         qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)
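The inverse scale is returned because `torch._scaled_mm` consumes inverse scales directly, as the comment in the hunk above notes. A minimal sketch of the downstream call (assumes PyTorch >= 2.4, where `torch._scaled_mm` returns a single tensor, plus FP8-capable hardware; the tensor names are illustrative):

import torch

a = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda")
b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")

# scalar=True requests per-tensor scales, which this call pattern expects.
qa, inv_scale_a = fp8_quantize(a, scalar=True)
qb, inv_scale_b = fp8_quantize(b, scalar=True)

out = torch._scaled_mm(
    qa,                   # fp8 operand, row-major
    qb.t(),               # fp8 operand, column-major view
    scale_a=inv_scale_a,  # inverse scales, exactly as fp8_quantize returns them
    scale_b=inv_scale_b,
    out_dtype=torch.bfloat16,
)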
@@ -307,9 +320,7 @@ class Fp8Linear(torch.nn.Module):
         self.dtype = dtype
         self.qweight = qweight
         self.scale = scale.float()
-        self.input_scale = (
-            input_scale.float().reciprocal() if input_scale is not None else None
-        )
+        self.input_scale = input_scale.float() if input_scale is not None else None
 
         if FBGEMM_MM_AVAILABLE:
             self.scale_upper_bound = (