(fix) quantize=fp8
parent 8ee9823d3b
commit 8cc2febdb6
@@ -296,8 +296,7 @@ class Fp8Linear(torch.nn.Module):
         super().__init__()
         if FBGEMM_MM_AVAILABLE:
             log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
-
-        if SYSTEM == "rocm":
+        if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
             qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                 weight=qweight, weight_scale=scale
             )
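Why the dtype guard matters: on ROCm the fp8 kernels expect weights in the float8_e4m3fnuz format, and normalize_e4m3fn_to_e4m3fnuz only makes sense for (and typically asserts on) float8_e4m3fn input. Without the check, weights that are already e4m3fnuz would be re-normalized and their scales corrupted. A minimal sketch of what the helper does, modeled on the vLLM-style implementation; the exact body below is an assumption, not TGI's verbatim code:

import torch

def normalize_e4m3fn_to_e4m3fnuz_sketch(weight: torch.Tensor, weight_scale: torch.Tensor):
    # Sketch under the assumption that the helper mirrors vLLM's version.
    assert weight.dtype == torch.float8_e4m3fn  # the condition the commit now checks at the call site
    as_int8 = weight.view(torch.int8)
    # Bit pattern 0b10000000 (-128) encodes -0.0 in e4m3fn but NaN in
    # e4m3fnuz, so it must be zeroed before reinterpreting the bits.
    as_int8[as_int8 == -128] = 0
    weight = as_int8.view(torch.float8_e4m3fnuz)
    # For the same bit pattern, an e4m3fnuz value is half the e4m3fn
    # value, so the weight scale is doubled to compensate.
    return weight, weight_scale * 2.0, None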
@@ -326,7 +325,12 @@ class Fp8Linear(torch.nn.Module):
     def from_unquant(cls, weight, bias, dtype):
         qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
         return cls(
-            qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
+            qweight=qweight,
+            scale=scale,
+            input_scale=None,
+            scale_upper_bound=None,
+            bias=bias,
+            dtype=dtype,
         )

     @classmethod
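The second hunk is the actual quantize=fp8 fix: Fp8Linear's constructor takes an input_scale argument (the calibrated activation scale that pre-quantized fp8 checkpoints provide), so the dynamically-quantizing from_unquant path must pass input_scale=None explicitly for the keyword arguments to line up. A sketch of the constructor shape the updated call site implies; only the parameter names come from the diff, and the real __init__ also performs the ROCm normalization shown above:

import torch

class Fp8LinearSketch(torch.nn.Module):
    # Hypothetical, trimmed-down stand-in for TGI's Fp8Linear.
    def __init__(self, qweight, scale, input_scale, scale_upper_bound, bias, dtype) -> None:
        super().__init__()
        self.qweight = qweight
        self.scale = scale
        # None on the from_unquant path: weights are quantized on the fly,
        # so there is no calibrated activation scale to apply.
        self.input_scale = input_scale
        self.scale_upper_bound = scale_upper_bound
        self.bias = bias
        self.dtype = dtype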