Make handling of FP8 scales more consistent (#2666)
Change `fp8_quantize` so that reciprocals are passed around everywhere, meaning scales always travel in the checkpoint format. I also noticed that when fbgemm is available, we ignore any input scale that we might already have; skip the fbgemm path if a scale is already given.
parent 153ff3740b
commit 5e0fb46821
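For illustration, here is a minimal sketch (not the repository's code) of the scale convention the message refers to: FP8 checkpoints store the dequantization scale, i.e. the reciprocal of the factor applied during quantization, so unscaling is a plain multiplication.

```python
import torch

# Illustrative sketch of the checkpoint scale convention (assumes an
# FP8-capable PyTorch build with torch.float8_e4m3fn).
finfo = torch.finfo(torch.float8_e4m3fn)

weight = torch.randn(16, 16, dtype=torch.float16)

# Quantize: scale up to the float8 range, clamp, then cast.
quant_scale = finfo.max / weight.abs().max().clamp(min=1e-12)
qweight = (weight * quant_scale).clamp(min=finfo.min, max=finfo.max)
qweight = qweight.to(torch.float8_e4m3fn)

# Checkpoint format: store the reciprocal, so dequantization is a multiply.
scale = quant_scale.float().reciprocal()
dequantized = qweight.to(torch.float16) * scale
```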
@@ -76,9 +76,19 @@ def normalize_e4m3fn_to_e4m3fnuz(
 def fp8_quantize(
-    weight, scale=None, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
+    weight: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+    scale_upper_bound: Optional[torch.Tensor] = None,
+    qdtype: torch.dtype = torch.float8_e4m3fn,
+    scalar: bool = False,
 ):
-    if FBGEMM_DYN_AVAILABLE and not scalar:
+    """
+    This function returns a reciprocal of the scale, so that a tensor can be unscaled
+    by multiplying it with the returned scale. If a scale is given through the `scale`
+    argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
+    be used without modification).
+    """
+    if FBGEMM_DYN_AVAILABLE and not scalar and not scale:
         qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
             weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
         )
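Assuming the function returns the `(qweight, scale)` pair its new docstring describes, call sites under this contract might look like the following sketch (hypothetical usage, not part of the diff):

```python
# First quantization: no scale given, so one is computed and its
# reciprocal is returned (checkpoint format).
qweight, scale = fp8_quantize(weight)

# Re-quantizing with a scale loaded from an FP8 checkpoint: the scale
# is already a reciprocal and is used as-is, per the docstring.
qweight2, scale2 = fp8_quantize(weight, scale=scale)
```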
@@ -90,15 +100,18 @@ def fp8_quantize(
     if scale is None:
         # Calculate the scale as dtype max divided by absmax
         scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
-    # scale and clamp the tensor to bring it to
-    # the representative range of float8 data type
-    # (as default cast is unsaturated)
-    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
+        # scale and clamp the tensor to bring it to
+        # the representative range of float8 data type
+        # (as default cast is unsaturated)
+        qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
+        scale = scale.float().reciprocal()
+    else:
+        # Use reciprocal to avoid more expensive division.
+        qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
 
     # Return both float8 data and the inverse scale (as float),
     # as both required as inputs to torch._scaled_mm
     qweight = qweight.to(qdtype)
-    scale = scale.float().reciprocal()
 
     if SYSTEM == "rocm":
         qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)
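The comment in this hunk notes that `torch._scaled_mm` consumes both the float8 data and the inverse scale. As a hedged sketch of that downstream use (`_scaled_mm` is a private API whose signature has varied across PyTorch releases; this follows the ~2.4 form and requires recent CUDA hardware):

```python
import torch

def quantize(t: torch.Tensor):
    # Same scheme as above: quantize and return data plus inverse scale.
    finfo = torch.finfo(torch.float8_e4m3fn)
    s = finfo.max / t.abs().max().clamp(min=1e-12)
    q = (t * s).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
    return q, s.float().reciprocal()

x = torch.randn(8, 64, device="cuda", dtype=torch.float16)
w = torch.randn(32, 64, device="cuda", dtype=torch.float16)
qx, x_scale = quantize(x)
qw, w_scale = quantize(w)

# Both inverse scales go straight into the scaled matmul.
out = torch._scaled_mm(
    qx, qw.t(), scale_a=x_scale, scale_b=w_scale, out_dtype=torch.float16
)
```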
@@ -307,9 +320,7 @@ class Fp8Linear(torch.nn.Module):
         self.dtype = dtype
         self.qweight = qweight
         self.scale = scale.float()
-        self.input_scale = (
-            input_scale.float().reciprocal() if input_scale is not None else None
-        )
+        self.input_scale = input_scale.float() if input_scale is not None else None
 
         if FBGEMM_MM_AVAILABLE:
             self.scale_upper_bound = (
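The `Fp8Linear` change follows from the same convention: a serialized `input_scale` is already a reciprocal, so it is no longer inverted a second time. A hedged sketch of loading such a scale (the file name and tensor key below are hypothetical, for illustration only):

```python
from safetensors import safe_open

# Hypothetical checkpoint path and tensor key.
with safe_open("model.safetensors", framework="pt") as f:
    input_scale = f.get_tensor("model.layers.0.self_attn.q_proj.input_scale")

# Checkpoint scales are reciprocals already: cast, but do not invert again.
input_scale = input_scale.float()
```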