From 8cc2febdb6523611fe218724a7090b8d8a255084 Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Mon, 30 Sep 2024 12:07:38 +0000
Subject: [PATCH] (fix) quantize=fp8

---
 server/text_generation_server/layers/fp8.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index 8aabc990..dca7fa95 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -296,8 +296,7 @@ class Fp8Linear(torch.nn.Module):
         super().__init__()
         if FBGEMM_MM_AVAILABLE:
             log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
-
-        if SYSTEM == "rocm":
+        if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
             qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                 weight=qweight, weight_scale=scale
             )
@@ -326,7 +325,12 @@ class Fp8Linear(torch.nn.Module):
     def from_unquant(cls, weight, bias, dtype):
         qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
         return cls(
-            qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
+            qweight=qweight,
+            scale=scale,
+            input_scale=None,
+            scale_upper_bound=None,
+            bias=bias,
+            dtype=dtype,
         )
 
     @classmethod
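
Note on the fix: ROCm GPUs implement float8 as e4m3fnuz rather than e4m3fn, so e4m3fn weights must be converted exactly once; the new dtype guard keeps the constructor from re-normalizing weights that are already e4m3fnuz (e.g. when loading an already-quantized checkpoint). The second hunk threads input_scale=None through from_unquant so the call matches the constructor's keyword arguments. Below is a minimal sketch of what normalize_e4m3fn_to_e4m3fnuz does; the helper exists in this file, but this body is an illustration based on the common vLLM-style conversion, not necessarily the exact implementation:

    import torch

    def normalize_e4m3fn_to_e4m3fnuz(weight, weight_scale, input_scale=None):
        # Only valid for e4m3fn inputs; running this on weights that are
        # already e4m3fnuz would double the scale a second time, which is
        # the failure mode the new dtype guard prevents.
        assert weight.dtype == torch.float8_e4m3fn
        # The bit pattern 0b10000000 (-128 as int8) means -0.0 in e4m3fn
        # but NaN in e4m3fnuz, so zero it out before reinterpreting bits.
        as_int8 = weight.view(torch.int8)
        as_int8[as_int8 == -128] = 0
        weight = as_int8.view(torch.float8_e4m3fnuz)
        # For identical bits, an e4m3fnuz value is half the e4m3fn value
        # (exponent bias 8 vs. 7), so double the scales to compensate.
        weight_scale = weight_scale * 2.0
        if input_scale is not None:
            input_scale = input_scale * 2.0
        return weight, weight_scale, input_scale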