diff --git a/server/text_generation_server/layers/compressed_tensors/w8an_fp.py b/server/text_generation_server/layers/compressed_tensors/w8an_fp.py
index 959fd5b3..15bdce08 100644
--- a/server/text_generation_server/layers/compressed_tensors/w8an_fp.py
+++ b/server/text_generation_server/layers/compressed_tensors/w8an_fp.py
@@ -7,6 +7,7 @@ from text_generation_server.layers.fp8 import (
     Fp8Weight,
     _load_scalar_or_matrix_scale,
     requantize_with_max_scale,
+    normalize_e4m3fn_to_e4m3fnuz,
 )
 from text_generation_server.utils.weights import Weights, WeightsLoader
 from text_generation_server.utils.import_utils import SYSTEM
@@ -132,12 +133,6 @@ class W8ANFpLoader(WeightsLoader):
         ]
         weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)
 
-        if weight_scale.numel() == len(prefixes):
-            logical_widths = [x[0] for x in shapes]
-            w, weight_scale = requantize_with_max_scale(
-                w, weight_scale.to(weights.device), logical_widths
-            )
-
         input_scale = None
         if self.load_input_scale:
             input_scale = [
@@ -152,6 +147,17 @@ class W8ANFpLoader(WeightsLoader):
             else None
         )
 
+        if self.load_weight_scale or SYSTEM == "rocm":
+            w, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                w, weight_scale, input_scale
+            )
+
+        if weight_scale.numel() == len(prefixes):
+            logical_widths = [x[0] for x in shapes]
+            w, weight_scale = requantize_with_max_scale(
+                w, weight_scale.to(weights.device), logical_widths, weights.dtype
+            )
+
         return Fp8Weight(
             weight=w,
             weight_scale=weight_scale,
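
For context: the diff moves the `requantize_with_max_scale` call so it runs *after* the ROCm `normalize_e4m3fn_to_e4m3fnuz` step, which appears to be the point of the change, since normalization rewrites both the FP8 weights and their scales, so computing the shared max scale over pre-normalization values would be inconsistent. Below is a minimal sketch of what `normalize_e4m3fn_to_e4m3fnuz` typically does; the real helper lives in `text_generation_server.layers.fp8` and its exact body may differ (this follows the common vLLM-style implementation, and the `_sketch` name is ours):

```python
from typing import Optional, Tuple

import torch


def normalize_e4m3fn_to_e4m3fnuz_sketch(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Sketch of the e4m3fn -> e4m3fnuz conversion used on ROCm (assumed)."""
    assert weight.dtype == torch.float8_e4m3fn
    # In e4m3fn the bit pattern 0b10000000 (-128 as int8) encodes -0.0,
    # but in e4m3fnuz it encodes NaN, so remap it to +0.0 first.
    as_int8 = weight.view(torch.int8)
    as_int8[as_int8 == -128] = 0
    weight = as_int8.view(torch.float8_e4m3fnuz)
    # For the same bit pattern, an e4m3fnuz value is half the e4m3fn value,
    # so double the scales to keep dequantized values unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale
```

Since the scales are doubled here, requantizing to a single max scale only makes sense once this rewrite has happened; the extra `weights.dtype` argument presumably threads the target dtype through to the updated `requantize_with_max_scale` signature.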