nm changes

This commit is contained in:
Mohit Sharma 2024-12-18 12:15:28 +00:00
parent 19688a0617
commit f8771d0a83
1 changed file with 12 additions and 6 deletions

View File

@@ -7,6 +7,7 @@ from text_generation_server.layers.fp8 import (
Fp8Weight,
_load_scalar_or_matrix_scale,
requantize_with_max_scale,
normalize_e4m3fn_to_e4m3fnuz,
)
from text_generation_server.utils.weights import Weights, WeightsLoader
from text_generation_server.utils.import_utils import SYSTEM
@@ -132,12 +133,6 @@ class W8ANFpLoader(WeightsLoader):
]
weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)
if weight_scale.numel() == len(prefixes):
logical_widths = [x[0] for x in shapes]
w, weight_scale = requantize_with_max_scale(
w, weight_scale.to(weights.device), logical_widths
)
input_scale = None
if self.load_input_scale:
input_scale = [
@@ -152,6 +147,17 @@ class W8ANFpLoader(WeightsLoader):
else None
)
if self.load_weight_scale or SYSTEM == "rocm":
w, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
w, weight_scale, input_scale
)
if weight_scale.numel() == len(prefixes):
logical_widths = [x[0] for x in shapes]
w, weight_scale = requantize_with_max_scale(
w, weight_scale.to(weights.device), logical_widths, weights.dtype
)
return Fp8Weight(
weight=w,
weight_scale=weight_scale,