nm changes
This commit is contained in:
parent
19688a0617
commit
f8771d0a83
|
@ -7,6 +7,7 @@ from text_generation_server.layers.fp8 import (
|
|||
Fp8Weight,
|
||||
_load_scalar_or_matrix_scale,
|
||||
requantize_with_max_scale,
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
)
|
||||
from text_generation_server.utils.weights import Weights, WeightsLoader
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
@ -132,12 +133,6 @@ class W8ANFpLoader(WeightsLoader):
|
|||
]
|
||||
weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)
|
||||
|
||||
if weight_scale.numel() == len(prefixes):
|
||||
logical_widths = [x[0] for x in shapes]
|
||||
w, weight_scale = requantize_with_max_scale(
|
||||
w, weight_scale.to(weights.device), logical_widths
|
||||
)
|
||||
|
||||
input_scale = None
|
||||
if self.load_input_scale:
|
||||
input_scale = [
|
||||
|
@ -152,6 +147,17 @@ class W8ANFpLoader(WeightsLoader):
|
|||
else None
|
||||
)
|
||||
|
||||
if self.load_weight_scale or SYSTEM == "rocm":
|
||||
w, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
w, weight_scale, input_scale
|
||||
)
|
||||
|
||||
if weight_scale.numel() == len(prefixes):
|
||||
logical_widths = [x[0] for x in shapes]
|
||||
w, weight_scale = requantize_with_max_scale(
|
||||
w, weight_scale.to(weights.device), logical_widths, weights.dtype
|
||||
)
|
||||
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=weight_scale,
|
||||
|
|
Loading…
Reference in New Issue