(fix) quantize=fp8
parent 8ee9823d3b
commit 8cc2febdb6
@@ -296,8 +296,7 @@ class Fp8Linear(torch.nn.Module):
         super().__init__()
         if FBGEMM_MM_AVAILABLE:
             log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
-
-        if SYSTEM == "rocm":
+        if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
             qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                 weight=qweight, weight_scale=scale
             )
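
The extra dtype check matters on ROCm: pre-quantized fp8 checkpoints typically ship weights as torch.float8_e4m3fn, while MI300-class hardware computes in torch.float8_e4m3fnuz, whose exponent bias differs by one, so the conversion should only run when the weight really is e4m3fn. Roughly, a normalization like normalize_e4m3fn_to_e4m3fnuz reinterprets the bits and doubles the scale; the sketch below (function name and NaN handling included) is an illustrative assumption, not the helper's actual implementation.

import torch

def normalize_e4m3fn_to_e4m3fnuz_sketch(weight: torch.Tensor, weight_scale: torch.Tensor):
    # Illustrative sketch only. e4m3fnuz uses exponent bias 8 vs. 7 for e4m3fn,
    # so the same bit pattern encodes half the value; doubling the scale keeps
    # scale * weight numerically unchanged after reinterpreting the bits.
    assert weight.dtype == torch.float8_e4m3fn
    bits = weight.view(torch.uint8)
    # 0x80 is -0.0 in e4m3fn but NaN in e4m3fnuz; map it to +0.0 first.
    bits = torch.where(bits == 0x80, torch.zeros_like(bits), bits)
    return bits.view(torch.float8_e4m3fnuz), weight_scale * 2.0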
@@ -326,7 +325,12 @@ class Fp8Linear(torch.nn.Module):
     def from_unquant(cls, weight, bias, dtype):
         qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
         return cls(
-            qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
+            qweight=qweight,
+            scale=scale,
+            input_scale=None,
+            scale_upper_bound=None,
+            bias=bias,
+            dtype=dtype,
         )

     @classmethod
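
from_unquant quantizes a full-precision weight on the fly and now passes the constructor's input_scale argument explicitly (None here, since a freshly quantized weight has no pre-computed activation scale). For context, per-tensor fp8 quantization of the kind fp8_quantize performs looks roughly like the sketch below; this is an illustrative assumption, not the actual TGI implementation.

import torch

def fp8_quantize_sketch(weight: torch.Tensor):
    # Illustrative per-tensor fp8 quantization: choose a scale so the largest
    # weight magnitude maps to the fp8 maximum, clamp, then cast.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = weight.abs().max().clamp(min=1e-12) / finfo.max
    qweight = (weight / scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
    # The float32 scale is multiplied back in during the fp8 matmul.
    return qweight, scale.float()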