Fp8 e4m3_fnuz support for rocm (#2588)
* (feat) fp8 fnuz support for rocm
* (review comments) Fix compression_config load, type hints
* (bug) update all has_tensor
* (review_comments) fix typo and added comments
* (nit) improved comment
parent ffe05ccd05
commit 704a58c807
@@ -1,7 +1,7 @@
 import torch

 from dataclasses import dataclass
-from typing import Optional, Union, List
+from typing import Optional, Tuple, Union, List
 from loguru import logger

 from text_generation_server.utils.import_utils import SYSTEM
@@ -51,8 +51,32 @@ def get_fp8_linear() -> torch.nn.Module:
     return Fp8Linear


+def normalize_e4m3fn_to_e4m3fnuz(
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    assert weight.dtype == torch.float8_e4m3fn
+    # The bits pattern 10000000(-128) represents zero in e4m3fn
+    # but NaN in e4m3fnuz. So here we set it to 0.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_as_int8 = weight.view(torch.int8)
+    ROCM_FP8_NAN_AS_INT = -128
+    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
+    weight = weight_as_int8.view(torch.float8_e4m3fnuz)
+
+    # For the same bits representation, e4m3fnuz value is half of
+    # the e4m3fn value, so we should double the scaling factor to
+    # get the same dequantized value.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_scale = weight_scale * 2.0
+    if input_scale is not None:
+        input_scale = input_scale * 2.0
+    return weight, weight_scale, input_scale
+
+
 def fp8_quantize(
-    weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
+    weight, scale=None, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
 ):
     if FBGEMM_DYN_AVAILABLE and not scalar:
         qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
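Background for the new helper: AMD's fp8 hardware format is e4m3fnuz, whose exponent bias differs from e4m3fn by one, so an identical bit pattern decodes to half the e4m3fn value, and the pattern 0x80 means NaN instead of negative zero. The snippet below is a standalone illustration, not part of the diff, and assumes a PyTorch build where both torch.float8_e4m3fn and torch.float8_e4m3fnuz support CPU casts; it shows why reinterpreting the bits and doubling the scale leaves dequantized values unchanged.

import torch

def demo_fnuz_halving():
    # Values exactly representable in e4m3fn (illustrative choice).
    values = torch.tensor([0.5, 1.0, 2.0, 448.0])
    fn = values.to(torch.float8_e4m3fn)
    # Reinterpret the raw bits as e4m3fnuz, as normalize_e4m3fn_to_e4m3fnuz does.
    fnuz = fn.view(torch.int8).view(torch.float8_e4m3fnuz)
    # Same bits, half the magnitude: doubling the scale restores the value,
    # i.e. fnuz * (2 * scale) == fn * scale for any dequantization scale.
    assert torch.equal(fnuz.float() * 2.0, fn.float())

demo_fnuz_halving()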
@@ -62,8 +86,11 @@ def fp8_quantize(
     # weight, scale = quant_weights(weight, torch.int8, False)
     finfo = torch.finfo(qdtype)
 
+    if scale is None:
         # Calculate the scale as dtype max divided by absmax
         scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
 
     # scale and clamp the tensor to bring it to
     # the representative range of float8 data type
     # (as default cast is unsaturated)
@@ -72,6 +99,10 @@ def fp8_quantize(
     # as both required as inputs to torch._scaled_mm
     qweight = qweight.to(qdtype)
     scale = scale.float().reciprocal()
+
+    if SYSTEM == "rocm":
+        qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)
+
     return qweight, scale
 
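With the new optional scale argument, fp8_quantize can reuse a static activation scale from the checkpoint instead of deriving one from the tensor's absmax, and on ROCm the quantized tensor is re-normalized to e4m3fnuz before returning. A simplified, standalone sketch of the non-FBGEMM branch (hypothetical function name; FBGEMM and ROCm handling omitted):

import torch
from typing import Optional, Tuple

def fp8_quantize_sketch(
    weight: torch.Tensor, scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    finfo = torch.finfo(torch.float8_e4m3fn)
    if scale is None:
        # Per-tensor scale: dtype max divided by absmax, clamped to avoid /0.
        scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    # Saturated cast into the fp8 range.
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
    qweight = qweight.to(torch.float8_e4m3fn)
    # Callers such as torch._scaled_mm expect the dequantization scale,
    # i.e. the reciprocal of the multiplier used above.
    return qweight, scale.float().reciprocal()

w = torch.randn(16, 32)
qw, s = fp8_quantize_sketch(w)
# Round-trips within fp8 e4m3 precision.
assert torch.allclose(qw.float() * s, w, rtol=0.0625, atol=1e-3)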
@@ -92,9 +123,17 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 .reshape(-1)
                 .expand(w.shape[0])
             )
+
+            input_scale = None
+            if weights.has_tensor(f"{prefix}.input_scale"):
+                input_scale = weights.get_tensor(
+                    f"{prefix}.input_scale", to_dtype=False
+                ).reshape(-1)
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
+                input_scale=input_scale,
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
@@ -125,9 +164,24 @@ class HybridFP8UnquantLoader(WeightsLoader):
             )
             scale = scale.reshape(-1).expand(w.shape[0])
 
+            input_scale = None
+            if weights.has_tensor(f"{prefix}.input_scale"):
+                input_scale = weights.get_tensor(
+                    f"{prefix}.input_scale", to_dtype=False
+                )
+                if input_scale.numel() > 1:
+                    input_scale = weights.get_packed_sharded(
+                        f"{prefix}.input_scale",
+                        dim=0,
+                        block_sizes=block_sizes,
+                        to_dtype=False,
+                    )
+                input_scale = input_scale.reshape(-1).max()
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
+                input_scale=input_scale,
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
@@ -154,9 +208,22 @@ class HybridFP8UnquantLoader(WeightsLoader):
             ]
             scale = torch.cat(scale, dim=0).reshape(-1)
+
+            input_scale = [
+                _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
+                for p, shape in zip(prefixes, shapes)
+                if weights.has_tensor(f"{p}.input_scale")
+            ]
+            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
+            input_scale = (
+                torch.cat(input_scale, dim=0).reshape(-1).max()
+                if len(input_scale) != 0
+                else None
+            )
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
+                input_scale=input_scale,
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
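For fused, column-sharded weights (several prefixes loaded as one matrix), each prefix may carry its own static input_scale, but the fused layer sees a single activation tensor, so the loader reduces them to one scalar with max: quantizing with the largest calibrated range cannot overflow any of the fused shards. A small hypothetical sketch of that reduction (names are illustrative):

import torch
from typing import List, Optional

def reduce_input_scales(scales: List[torch.Tensor]) -> Optional[torch.Tensor]:
    # No shard published an input_scale: fall back to dynamic activation scaling.
    if not scales:
        return None
    # One scalar has to serve all fused shards; max is the conservative choice.
    return torch.cat(scales, dim=0).reshape(-1).max()

# Example: q/k/v projections calibrated separately, fused at load time.
print(reduce_input_scales([torch.tensor([0.021]), torch.tensor([0.018]), torch.tensor([0.025])]))
# tensor(0.0250)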
@@ -174,9 +241,16 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 .reshape(-1)
                 .expand(w.shape[0])
             )
+            input_scale = None
+            if weights.has_tensor(f"{prefix}.input_scale"):
+                input_scale = weights.get_tensor(
+                    f"{prefix}.input_scale", to_dtype=False
+                ).reshape(-1)
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
+                input_scale=input_scale,
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
@@ -191,6 +265,7 @@ class Fp8Weight(Weight):
     weight: torch.Tensor
     dtype: torch.dtype
     weight_scale: Optional[torch.Tensor] = None
+    input_scale: Optional[torch.Tensor] = None
     activation_scale_ub: Optional[float] = None
 
     def get_linear(self, bias: torch.Tensor):
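Statically quantized FP8 checkpoints conventionally store, next to each fp8 weight, a per-tensor weight_scale and optionally an input_scale for the activations; the new field lets Fp8Weight carry the activation scale through to get_linear. A minimal local stand-in for the dataclass, only to show how such tensors line up with the fields (hypothetical values, not the real class):

import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class Fp8WeightSketch:
    # Mirrors the fields used in the diff; a stand-in, not the real Fp8Weight.
    weight: torch.Tensor
    dtype: torch.dtype
    weight_scale: Optional[torch.Tensor] = None
    input_scale: Optional[torch.Tensor] = None
    activation_scale_ub: Optional[float] = None

w = Fp8WeightSketch(
    weight=torch.empty(256, 256, dtype=torch.float8_e4m3fn),
    dtype=torch.bfloat16,
    weight_scale=torch.tensor([0.0071]),
    # Present only for checkpoints calibrated with static activation scales.
    input_scale=torch.tensor([0.0213]),
)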
@@ -200,26 +275,43 @@ class Fp8Weight(Weight):
         # memory. Can be non-contiguous when we e.g. expand from scalars.
         self.weight_scale = self.weight_scale.contiguous()
         return get_fp8_linear().from_fp8(
-            self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
+            weight=self.weight,
+            scale=self.weight_scale,
+            dtype=self.dtype,
+            bias=bias,
+            input_scale=self.input_scale,
+            scale_upper_bound=self.activation_scale_ub,
         )
 
 
 class Fp8Linear(torch.nn.Module):
+    _device_identity_cache = {}
+
     def __init__(
         self,
-        qweight,
-        scale,
-        scale_upper_bound,
-        bias,
-        dtype,
+        qweight: torch.Tensor,
+        scale: torch.Tensor,
+        dtype: torch.dtype,
+        bias: Optional[torch.Tensor] = None,
+        input_scale: Optional[torch.Tensor] = None,
+        scale_upper_bound: Optional[float] = None,
     ) -> None:
         super().__init__()
         if FBGEMM_MM_AVAILABLE:
             log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
+        if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
+            qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                weight=qweight, weight_scale=scale
+            )
 
         self.dtype = dtype
         self.qweight = qweight
-        self.scale = scale
+        self.scale = scale.float()
+        self.input_scale = (
+            input_scale.float().reciprocal() if input_scale is not None else None
+        )
+
+        if FBGEMM_MM_AVAILABLE:
             self.scale_upper_bound = (
                 torch.tensor(
                     [scale_upper_bound], dtype=torch.float32, device=qweight.device
@@ -227,6 +319,8 @@ class Fp8Linear(torch.nn.Module):
                 if scale_upper_bound is not None
                 else None
             )
+        else:
+            self.scale_upper_bound = scale_upper_bound
 
         self.bias = bias if bias is not None else None
 
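Note the reciprocal taken in __init__: checkpoints conventionally store input_scale as a dequantization scale (x ≈ x_fp8 * input_scale), while fp8_quantize multiplies its input by the scale it is given, so the constructor stores 1 / input_scale once rather than recomputing it on every forward pass. A small standalone illustration of the two conventions (hypothetical values):

import torch

# Dequantization convention: x ≈ x_fp8 * input_scale (as stored on disk).
input_scale = torch.tensor(0.0213)
quant_scale = input_scale.reciprocal()   # what the quantizer multiplies by

x = torch.tensor([0.5, -1.25, 2.0])      # activations inside the calibrated range
x_fp8 = (x * quant_scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
x_back = x_fp8.float() * input_scale

# Round-trips within fp8 e4m3 precision.
assert torch.allclose(x_back, x, rtol=0.07, atol=1e-3)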
@@ -234,22 +328,46 @@ class Fp8Linear(torch.nn.Module):
     def from_unquant(cls, weight, bias, dtype):
         qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
         return cls(
-            qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
+            qweight=qweight,
+            scale=scale,
+            dtype=dtype,
+            bias=bias,
+            input_scale=None,
+            scale_upper_bound=None,
         )
 
     @classmethod
-    def from_fp8(cls, weight, scale, input_scale, bias, dtype):
+    def from_fp8(
+        cls,
+        weight: torch.Tensor,
+        scale: torch.Tensor,
+        dtype: torch.dtype,
+        bias: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> "Fp8Linear":
+        input_scale = kwargs.get("input_scale", None)
+        scale_upper_bound = kwargs.get("scale_upper_bound", None)
+
         if FBGEMM_DYN_AVAILABLE:
             # fbgemm needs float32 scales.
             scale = scale.float()
         return cls(
             qweight=weight,
             scale=scale,
-            scale_upper_bound=input_scale,
+            input_scale=input_scale,
+            scale_upper_bound=scale_upper_bound,
             bias=bias,
             dtype=dtype,
         )
 
+    @classmethod
+    def get_shared_device_identity(cls, device):
+        # Input scaling factors are no longer optional in _scaled_mm starting
+        # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
+        if device not in cls._device_identity_cache:
+            cls._device_identity_cache[device] = torch.ones(1, device=device)
+        return cls._device_identity_cache[device]
+
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if FBGEMM_MM_AVAILABLE:
             qinput, scale = fp8_quantize(
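get_shared_device_identity exists because, as the comment in the diff notes, torch._scaled_mm no longer accepts missing scaling factors from PyTorch 2.5 on; caching a ones tensor per device avoids allocating a dummy scale on every forward call. A hypothetical standalone sketch of the same lazy per-device cache pattern:

import torch
from typing import Dict

class DeviceIdentityCache:
    # One dummy scale tensor per device, allocated lazily and then reused.
    _cache: Dict[torch.device, torch.Tensor] = {}

    @classmethod
    def get(cls, device: torch.device) -> torch.Tensor:
        if device not in cls._cache:
            cls._cache[device] = torch.ones(1, device=device)
        return cls._cache[device]

# Repeated lookups return the very same tensor object, with no reallocation.
a = DeviceIdentityCache.get(torch.device("cpu"))
b = DeviceIdentityCache.get(torch.device("cpu"))
assert a is b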
@@ -266,8 +384,18 @@ class Fp8Linear(torch.nn.Module):
             )
             return y.to(self.dtype)
 
-        qinput, scale = fp8_quantize(input, scalar=True)
-        output, _ = torch._scaled_mm(
+        qinput, scale = fp8_quantize(
+            input,
+            self.input_scale,
+            scale_upper_bound=self.scale_upper_bound,
+            scalar=True,
+        )
+
+        per_tensor_weights = self.scale.numel() == 1
+        per_tensor_activations = scale.numel() == 1
+
+        if SYSTEM != "rocm" or (per_tensor_weights and per_tensor_activations):
+            output = torch._scaled_mm(
                 qinput,
                 self.qweight.t(),
                 out_dtype=self.dtype,
@@ -275,6 +403,30 @@ class Fp8Linear(torch.nn.Module):
                 scale_b=self.scale,
                 bias=self.bias,
             )
+
+            if isinstance(output, tuple) and len(output) == 2:
+                output = output[0]
+        else:
+            device_identity = None
+            if SYSTEM == "rocm":
+                device_identity = self.get_shared_device_identity(self.qweight.device)
+
+            output = torch._scaled_mm(
+                qinput,
+                self.qweight.t(),
+                scale_a=device_identity,
+                scale_b=device_identity,
+                out_dtype=torch.float32,
+            )
+            if isinstance(output, tuple) and len(output) == 2:
+                output = output[0]
+
+            output = output * scale * self.scale.t()
+            if self.bias is not None:
+                output = output + self.bias
+
+            output = output.to(dtype=self.dtype)
+
         return output
 
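In the new forward path, ROCm only hands real scales to torch._scaled_mm when both the weight scale and the activation scale are per-tensor scalars; otherwise the matmul runs in float32 under identity scales and the row/column scales are applied to the result. That is valid because diagonal scaling factors out of a matrix product: diag(s_a)·A·B·diag(s_b) = diag(s_a)·(A·B)·diag(s_b). A CPU sketch of that identity with plain float tensors standing in for the fp8 operands (torch._scaled_mm itself needs fp8-capable hardware):

import torch

torch.manual_seed(0)
qa = torch.randn(4, 8)            # stands in for the quantized activations
qb = torch.randn(8, 3)            # stands in for self.qweight.t()
scale_a = torch.rand(4, 1) + 0.5  # per-row activation scales
scale_b = torch.rand(1, 3) + 0.5  # per-column weight scales

# Scale the operands first, then multiply ...
scaled_first = (qa * scale_a) @ (qb * scale_b)
# ... or multiply under identity scales and rescale the result afterwards,
# which is what the fallback does via `output * scale * self.scale.t()`.
scaled_after = (qa @ qb) * scale_a * scale_b

assert torch.allclose(scaled_first, scaled_after, atol=1e-5)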
@@ -392,7 +392,7 @@ class GPTQWeightsLoader(WeightsLoader):
         )
 
     def _get_gptq_params(self, weights: Weights):
-        if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
+        if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"):
             self.bits = weights.get_tensor("gptq_bits").item()
             self.groupsize = weights.get_tensor("gptq_groupsize").item()
             self.desc_act = False
@@ -400,7 +400,7 @@ class GPTQWeightsLoader(WeightsLoader):
             # before the `gptq_sym` setting tensor was added.
             self.sym = (
                 weights.get_tensor("gptq_sym").item()
-                if weights._has_tensor("gptq_sym")
+                if weights.has_tensor("gptq_sym")
                 else False
             )
             self.quant_method = "gptq"
@@ -62,7 +62,14 @@ class GPTQMarlinFP8Linear(nn.Module):
         return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)
 
     @classmethod
-    def from_fp8(cls, weight, scale, _input_scale, bias, dtype):
+    def from_fp8(
+        cls,
+        weight: torch.Tensor,
+        scale: torch.Tensor,
+        bias: torch.Tensor,
+        dtype: torch.dtype,
+        **kwargs,
+    ):
         return cls(qweight=weight, scales=scale.to(dtype), bias=bias)
 
     def forward(self, A: torch.Tensor) -> torch.Tensor:
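Fp8Linear.from_fp8 and GPTQMarlinFP8Linear.from_fp8 both accept the extra parameters through **kwargs, so Fp8Weight.get_linear can call whichever backend get_fp8_linear() selected with the same keyword arguments, and backends that have no use for input_scale or scale_upper_bound (the Marlin one here) simply ignore them. A minimal sketch of that pattern with hypothetical classes:

import torch
from typing import Optional

class BackendA:
    @classmethod
    def from_fp8(cls, weight, scale, dtype, bias=None, **kwargs):
        # This backend consumes the optional static activation scale.
        input_scale: Optional[torch.Tensor] = kwargs.get("input_scale", None)
        return ("A", input_scale)

class BackendB:
    @classmethod
    def from_fp8(cls, weight, scale, dtype, bias=None, **kwargs):
        # This backend ignores input_scale without breaking the shared call.
        return ("B", None)

common_args = dict(
    weight=torch.empty(2, 2, dtype=torch.float8_e4m3fn),
    scale=torch.tensor([1.0]),
    dtype=torch.float16,
    bias=None,
    input_scale=torch.tensor([0.02]),
)
for backend in (BackendA, BackendB):
    print(backend.from_fp8(**common_args))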
@@ -231,7 +231,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
         )
 
     def _get_gptq_params(self, weights: Weights):
-        if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
+        if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"):
             self.bits = weights.get_tensor("gptq_bits").item()
             self.groupsize = weights.get_tensor("gptq_groupsize").item()
             self.desc_act = False
@@ -239,7 +239,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
             # before the `gptq_sym` setting tensor was added.
             self.sym = (
                 weights.get_tensor("gptq_sym").item()
-                if weights._has_tensor("gptq_sym")
+                if weights.has_tensor("gptq_sym")
                 else False
             )
             self.quant_method = "gptq"
@@ -357,17 +357,32 @@ def get_model(
     compression_config = config_dict.get("compression_config", None)
     if quantization_config is not None and quantize is None:
         method = quantization_config.get("quant_method", None)
+        config_groups = quantization_config.get("config_groups", None)
         if method in {"gptq", "awq", "exl2"}:
             log_master(logger.info, f"Auto selecting quantization method {method}")
             quantize = method
-        elif method == "fbgemm_fp8":
+        elif method == "fbgemm_fp8" or method == "fp8":
             log_master(logger.info, "Auto selecting quantization method fp8")
             quantize = "fp8"
+        elif config_groups is not None:
+            # TODO: at some point we should probably fully parse the compression
+            # configuration to know which parameters are compressed.
+            for _, group in config_groups.items():
+                weights_config = group.get("weights")
+                if weights_config is not None:
+                    if (
+                        weights_config["type"] == "float"
+                        and weights_config["num_bits"] == 8
+                    ):
+                        log_master(
+                            logger.info, "Auto selecting quantization method fp8"
+                        )
+                        quantize = "fp8"
+                        break
         else:
             log_master(logger.warning, f"Unknown quantization method {method}")
     elif compression_config is not None:
-        # TODO: at some point we should probably fully parse the compression
-        # configuration to know which parameters are compressed.
+        # `compression_config` renamed to `quantization_config`; support retained for backward compatibility.
         config_groups = compression_config.get("config_groups")
         if config_groups is not None:
             for _, group in config_groups.items():
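The config_groups probe lets get_model auto-select fp8 for checkpoints whose quantization_config describes weights per group (compressed-tensors style) instead of via a plain quant_method string. A trimmed-down, hypothetical example of such a config and of the detection logic added above:

# Shape of a config_groups-style quantization config (illustrative values).
quantization_config = {
    "config_groups": {
        "group_0": {
            "weights": {"type": "float", "num_bits": 8, "strategy": "tensor"},
            "targets": ["Linear"],
        }
    },
}

def detect_fp8(quantization_config: dict) -> bool:
    # Mirrors the loop above: any group quantizing weights to 8-bit floats
    # is treated as an fp8 checkpoint.
    for group in quantization_config.get("config_groups", {}).values():
        weights_config = group.get("weights")
        if (
            weights_config is not None
            and weights_config["type"] == "float"
            and weights_config["num_bits"] == 8
        ):
            return True
    return False

print(detect_fp8(quantization_config))  # True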
@@ -197,7 +197,7 @@ class Weights:
             slice_ = f.get_slice(tensor_name)
             return slice_
 
-    def _has_tensor(self, tensor_name: str):
+    def has_tensor(self, tensor_name: str):
         try:
             self.get_filename(tensor_name)
         except Exception:
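has_tensor loses its leading underscore now that code outside Weights depends on it: the fp8 loader probes {prefix}.input_scale with it and the GPTQ/Marlin loaders use it for the gptq_* metadata tensors. Its shape (return True if the tensor can be located in any shard, False otherwise) can be mimicked with safetensors headers alone; a rough standalone sketch, assuming a directory of .safetensors shards rather than this class's internals:

from pathlib import Path
from safetensors import safe_open

def has_tensor(weights_dir: str, tensor_name: str) -> bool:
    # Scan the shard headers only; no tensor data is read.
    for shard in Path(weights_dir).glob("*.safetensors"):
        with safe_open(str(shard), framework="pt") as f:
            if tensor_name in f.keys():
                return True
    return False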