nm changes

parent 19688a0617
commit f8771d0a83
@@ -7,6 +7,7 @@ from text_generation_server.layers.fp8 import (
     Fp8Weight,
     _load_scalar_or_matrix_scale,
     requantize_with_max_scale,
+    normalize_e4m3fn_to_e4m3fnuz,
 )
 from text_generation_server.utils.weights import Weights, WeightsLoader
 from text_generation_server.utils.import_utils import SYSTEM
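Note: normalize_e4m3fn_to_e4m3fnuz converts OCP float8 e4m3fn tensors to the e4m3fnuz variant used by ROCm hardware. As a rough sketch of what the imported helper does (following the vLLM-derived convention; the exact body in text_generation_server.layers.fp8 may differ):

    from typing import Optional, Tuple

    import torch


    def normalize_e4m3fn_to_e4m3fnuz(
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        input_scale: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        assert weight.dtype == torch.float8_e4m3fn
        # The bit pattern 0b10000000 (-128 as int8) is -0.0 in e4m3fn
        # but NaN in e4m3fnuz, so remap it to zero before the cast.
        weight_as_int8 = weight.view(torch.int8)
        weight_as_int8[weight_as_int8 == -128] = 0
        weight = weight_as_int8.view(torch.float8_e4m3fnuz)

        # For the same bit pattern, the e4m3fnuz value is half the
        # e4m3fn value, so double the scales to keep dequantized
        # values unchanged.
        weight_scale = weight_scale * 2.0
        if input_scale is not None:
            input_scale = input_scale * 2.0
        return weight, weight_scale, input_scale

Because the same bit pattern encodes half the value in e4m3fnuz, doubling the scales preserves the dequantized weights exactly.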
@@ -132,12 +133,6 @@ class W8ANFpLoader(WeightsLoader):
         ]
         weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)
 
-        if weight_scale.numel() == len(prefixes):
-            logical_widths = [x[0] for x in shapes]
-            w, weight_scale = requantize_with_max_scale(
-                w, weight_scale.to(weights.device), logical_widths
-            )
-
         input_scale = None
         if self.load_input_scale:
             input_scale = [
@@ -152,6 +147,17 @@ class W8ANFpLoader(WeightsLoader):
                 else None
             )
 
+        if self.load_weight_scale or SYSTEM == "rocm":
+            w, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                w, weight_scale, input_scale
+            )
+
+        if weight_scale.numel() == len(prefixes):
+            logical_widths = [x[0] for x in shapes]
+            w, weight_scale = requantize_with_max_scale(
+                w, weight_scale.to(weights.device), logical_widths, weights.dtype
+            )
+
         return Fp8Weight(
             weight=w,
             weight_scale=weight_scale,
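Net effect of this hunk: the per-shard scales are first normalized to e4m3fnuz (with the 2x scale adjustment above), and only then collapsed to a single max scale, so the requantization sees the already-adjusted scales; requantize_with_max_scale also now receives weights.dtype for the intermediate dequantized tensor. A minimal sketch of a compatible requantize_with_max_scale, assuming simple per-tensor dequantize/requantize (not the exact library implementation):

    from typing import List, Tuple

    import torch


    def requantize_with_max_scale(
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        logical_widths: List[int],
        dtype: torch.dtype,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Collapse per-shard scales (e.g. fused q/k/v) to one max scale
        # so the fused weight can be used with a single per-tensor scale.
        max_w_scale = weight_scale.max().float()
        fp8_dtype = weight.dtype
        fp8_min = torch.finfo(fp8_dtype).min
        fp8_max = torch.finfo(fp8_dtype).max

        start = 0
        for idx, width in enumerate(logical_widths):
            end = start + width
            # Dequantize the shard with its own scale...
            shard = weight[start:end].to(dtype) * weight_scale[idx]
            # ...and requantize it with the shared max scale.
            weight[start:end] = (
                (shard / max_w_scale).clamp(fp8_min, fp8_max).to(fp8_dtype)
            )
            start = end

        return weight, max_w_scale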