(fix) quantize=fp8

Mohit Sharma 2024-09-30 12:07:38 +00:00
parent 8ee9823d3b
commit 8cc2febdb6
1 changed file with 7 additions and 3 deletions


@@ -296,8 +296,7 @@ class Fp8Linear(torch.nn.Module):
         super().__init__()
         if FBGEMM_MM_AVAILABLE:
             log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
-        if SYSTEM == "rocm":
+        if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
             qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                 weight=qweight, weight_scale=scale
             )
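Review note: the added `qweight.dtype == torch.float8_e4m3fn` guard matters because ROCm GPUs use the FNUZ flavour of fp8 (`torch.float8_e4m3fnuz`), which has a different exponent bias and encodes NaN where e4m3fn encodes -0; weights that are already in e4m3fnuz (or any other dtype) must not be re-normalized. Below is a rough, hypothetical sketch of what such a conversion does, not the repo's actual `normalize_e4m3fn_to_e4m3fnuz`:

```python
# Hypothetical sketch of an e4m3fn -> e4m3fnuz conversion (illustrative only).
# FNUZ uses an exponent bias of 8 instead of 7, so the same bit pattern decodes
# to half the value; doubling the scale compensates. The 0x80 pattern is -0 in
# e4m3fn but NaN in e4m3fnuz, so it has to be cleared before reinterpreting.
import torch

def e4m3fn_to_e4m3fnuz_sketch(weight: torch.Tensor, weight_scale: torch.Tensor):
    assert weight.dtype == torch.float8_e4m3fn
    bits = weight.view(torch.int8)
    bits[bits == -128] = 0                      # -0 (e4m3fn) would decode as NaN (e4m3fnuz)
    weight = bits.view(torch.float8_e4m3fnuz)   # same bits, half the value
    return weight, weight_scale * 2.0           # double the scale to compensate
```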
@@ -326,7 +325,12 @@ class Fp8Linear(torch.nn.Module):
     def from_unquant(cls, weight, bias, dtype):
         qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
         return cls(
-            qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
+            qweight=qweight,
+            scale=scale,
+            input_scale=None,
+            scale_upper_bound=None,
+            bias=bias,
+            dtype=dtype,
         )

     @classmethod
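The second hunk passes the new `input_scale` argument explicitly in `from_unquant`. Since the weights are quantized on the fly from an unquantized checkpoint, there is no calibrated activation scale to hand over, so it is `None` and the forward pass has to fall back to dynamic activation quantization. A rough sketch of that distinction, using made-up helper names rather than the repo's API:

```python
# Hypothetical sketch (illustrative names, not the repo's API): with
# input_scale=None the activation scale is computed per call; with a
# calibrated scale from an fp8 checkpoint it is reused as-is.
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def quantize_input(x: torch.Tensor, input_scale: torch.Tensor | None):
    if input_scale is None:
        # dynamic quantization: derive the scale from this batch
        input_scale = x.abs().amax().float() / FP8_MAX
    x_fp8 = (x / input_scale).clamp(min=-FP8_MAX, max=FP8_MAX).to(torch.float8_e4m3fn)
    return x_fp8, input_scale
```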