diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index 8aabc990..dca7fa95 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -296,8 +296,7 @@ class Fp8Linear(torch.nn.Module):
         super().__init__()
         if FBGEMM_MM_AVAILABLE:
             log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
-
-        if SYSTEM == "rocm":
+        if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
             qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                 weight=qweight, weight_scale=scale
             )
@@ -326,7 +325,12 @@ class Fp8Linear(torch.nn.Module):
     def from_unquant(cls, weight, bias, dtype):
         qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
         return cls(
-            qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
+            qweight=qweight,
+            scale=scale,
+            input_scale=None,
+            scale_upper_bound=None,
+            bias=bias,
+            dtype=dtype,
         )
 
     @classmethod