Add support for scalar FP8 weight scales (#2550)
* Add support for scalar FP8 weight scales

* Support LLM compressor FP8 checkpoints on H100

  On H100, we use fbgemm-gpu, which requires bfloat16 as the input dtype.
  However, we wouldn't pick up fp8 quantization for models quantized with
  LLM compressor. This change adds enough parsing to detect if models have
  FP8-quantized weights.

* Remove stray debug print
parent 0ff6ff60ad
commit c29dc89c18
@@ -87,9 +87,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_tensor(
-                f"{prefix}.weight_scale", to_dtype=False
-            ).reshape(-1)
+            scale = (
+                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(w.shape[0])
+            )
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
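Note: the new `.reshape(-1).expand(w.shape[0])` chain broadcasts a per-tensor (scalar) weight scale to one entry per output row, while leaving an already per-row scale unchanged. A minimal illustration in plain PyTorch, outside this loader (the values are made up):

import torch

out_features = 4

# Scalar (per-tensor) scale, as some FP8 checkpoints store it.
scalar_scale = torch.tensor(0.02)
# Per-row scale, one entry per output feature.
row_scale = torch.tensor([0.01, 0.02, 0.03, 0.04])

# After reshape(-1).expand(out_features), both layouts look the same downstream.
print(scalar_scale.reshape(-1).expand(out_features))  # tensor([0.0200, 0.0200, 0.0200, 0.0200])
print(row_scale.reshape(-1).expand(out_features))     # tensor([0.0100, 0.0200, 0.0300, 0.0400])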
@@ -113,9 +115,16 @@ class HybridFP8UnquantLoader(WeightsLoader):
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_packed_sharded(
-                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
-            ).reshape(-1)
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            if scale.numel() > 1:
+                scale = weights.get_packed_sharded(
+                    f"{prefix}.weight_scale",
+                    dim=0,
+                    block_sizes=block_sizes,
+                    to_dtype=False,
+                )
+            scale = scale.reshape(-1).expand(w.shape[0])
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -132,16 +141,19 @@ class HybridFP8UnquantLoader(WeightsLoader):
         w = [
             weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
         ]
+        shapes = [x.shape for x in w]
 
         # Concat then send to the device
         w = torch.cat(w, dim=dim).to(weights.device)
 
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
             scale = [
-                weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
-                for p in prefixes
+                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
+                for p, shape in zip(prefixes, shapes)
             ]
             scale = torch.cat(scale, dim=0).reshape(-1)
 
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -157,9 +169,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
-            scale = weights.get_tensor(
-                f"{prefix}.weight_scale", to_dtype=False
-            ).reshape(-1)
+            scale = (
+                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(w.shape[0])
+            )
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -182,6 +196,9 @@ class Fp8Weight(Weight):
     def get_linear(self, bias: torch.Tensor):
         if self.weight_scale is None:
             return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
+        # This is not checked by the fbgemm kernels, but they require contiguous
+        # memory. Can be non-contiguous when we e.g. expand from scalars.
+        self.weight_scale = self.weight_scale.contiguous()
         return get_fp8_linear().from_fp8(
             self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
         )
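Note: the `.contiguous()` call added above matters because the new scalar path builds the per-row scale with `.expand()`, which returns a stride-0 view rather than freshly allocated memory. A small illustration in plain PyTorch (values are arbitrary):

import torch

scale = torch.tensor(0.0123)             # per-tensor (scalar) weight scale
per_row = scale.reshape(-1).expand(8)    # one entry per output row, as the loader does

print(per_row.is_contiguous())               # False: expand() yields a stride-0 view
print(per_row.contiguous().is_contiguous())  # True: contiguous() copies into real memory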
@@ -222,6 +239,9 @@ class Fp8Linear(torch.nn.Module):
 
     @classmethod
     def from_fp8(cls, weight, scale, input_scale, bias, dtype):
+        if FBGEMM_DYN_AVAILABLE:
+            # fbgemm needs float32 scales.
+            scale = scale.float()
         return cls(
             qweight=weight,
             scale=scale,
@@ -256,3 +276,10 @@ class Fp8Linear(torch.nn.Module):
             bias=self.bias,
         )
         return output
+
+
+def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
+    scale = weights.get_tensor(prefix, to_dtype=False)
+    if scale.numel() > 1:
+        scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
+    return scale.reshape(-1).expand(shape[0])
@@ -334,6 +334,7 @@ def get_model(
     model_type = config_dict.get("model_type", None)
 
     quantization_config = config_dict.get("quantization_config", None)
+    compression_config = config_dict.get("compression_config", None)
     if quantization_config is not None and quantize is None:
         method = quantization_config.get("quant_method", None)
         if method in {"gptq", "awq", "exl2"}:
@@ -344,6 +345,23 @@ def get_model(
             quantize = "fp8"
         else:
             log_master(logger.warning, f"Unknown quantization method {method}")
+    elif compression_config is not None:
+        # TODO: at some point we should probably fully parse the compression
+        # configuration to know which parameters are compressed.
+        config_groups = compression_config.get("config_groups")
+        if config_groups is not None:
+            for _, group in config_groups.items():
+                weights_config = group.get("weights")
+                if weights_config is not None:
+                    if (
+                        weights_config["type"] == "float"
+                        and weights_config["num_bits"] == 8
+                    ):
+                        log_master(
+                            logger.info, "Auto selecting quantization method fp8"
+                        )
+                        quantize = "fp8"
+                        break
 
     if dtype is None:
         if quantize in ["awq", "exl2", "gptq", "marlin"]:
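Note: the new branch only inspects compression_config["config_groups"][...]["weights"]. A hypothetical, heavily trimmed config fragment (written as a Python dict) that this detection would match; real LLM compressor checkpoints carry many more fields, and only the keys read above come from this diff:

config_dict = {
    "model_type": "llama",
    "compression_config": {
        "config_groups": {
            "group_0": {  # group name is arbitrary; only the "weights" entry is read
                "weights": {"type": "float", "num_bits": 8},
            },
        },
    },
}

With such a config and no explicit quantize argument, get_model now auto-selects quantize = "fp8", so the HybridFP8UnquantLoader above picks up the FP8 weights and their (possibly scalar) weight scales.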
@@ -768,7 +786,6 @@ def get_model(
         )
 
     elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
-        print(f">>> model_type: {model_type}")
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,