diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index 59b08b55..61dd5115 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -87,9 +87,11 @@ class HybridFP8UnquantLoader(WeightsLoader):

         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_tensor(
-                f"{prefix}.weight_scale", to_dtype=False
-            ).reshape(-1)
+            scale = (
+                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(w.shape[0])
+            )
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -113,9 +115,16 @@ class HybridFP8UnquantLoader(WeightsLoader):

         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_packed_sharded(
-                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
-            ).reshape(-1)
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            if scale.numel() > 1:
+                scale = weights.get_packed_sharded(
+                    f"{prefix}.weight_scale",
+                    dim=0,
+                    block_sizes=block_sizes,
+                    to_dtype=False,
+                )
+            scale = scale.reshape(-1).expand(w.shape[0])
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -132,16 +141,19 @@ class HybridFP8UnquantLoader(WeightsLoader):
         w = [
             weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
         ]
+        shapes = [x.shape for x in w]
+
+        # Concat then send to the device
         w = torch.cat(w, dim=dim).to(weights.device)

         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
             scale = [
-                weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
-                for p in prefixes
+                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
+                for p, shape in zip(prefixes, shapes)
             ]
             scale = torch.cat(scale, dim=0).reshape(-1)
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -157,9 +169,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
-            scale = weights.get_tensor(
-                f"{prefix}.weight_scale", to_dtype=False
-            ).reshape(-1)
+            scale = (
+                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(w.shape[0])
+            )
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -182,6 +196,9 @@ class Fp8Weight(Weight):
     def get_linear(self, bias: torch.Tensor):
         if self.weight_scale is None:
             return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
+        # This is not checked by the fbgemm kernels, but they require contiguous
+        # memory. Can be non-contiguous when we e.g. expand from scalars.
+        self.weight_scale = self.weight_scale.contiguous()
         return get_fp8_linear().from_fp8(
             self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
         )
@@ -222,6 +239,9 @@ class Fp8Linear(torch.nn.Module):

     @classmethod
     def from_fp8(cls, weight, scale, input_scale, bias, dtype):
+        if FBGEMM_DYN_AVAILABLE:
+            # fbgemm needs float32 scales.
+            scale = scale.float()
         return cls(
             qweight=weight,
             scale=scale,
@@ -256,3 +276,10 @@ class Fp8Linear(torch.nn.Module):
             bias=self.bias,
         )
         return output
+
+
+def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
+    scale = weights.get_tensor(prefix, to_dtype=False)
+    if scale.numel() > 1:
+        scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
+    return scale.reshape(-1).expand(shape[0])
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index fc530b38..e5e5aabb 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -334,6 +334,7 @@ def get_model(

     model_type = config_dict.get("model_type", None)
     quantization_config = config_dict.get("quantization_config", None)
+    compression_config = config_dict.get("compression_config", None)
     if quantization_config is not None and quantize is None:
         method = quantization_config.get("quant_method", None)
         if method in {"gptq", "awq", "exl2"}:
@@ -344,6 +345,23 @@ def get_model(
             quantize = "fp8"
         else:
             log_master(logger.warning, f"Unknown quantization method {method}")
+    elif compression_config is not None:
+        # TODO: at some point we should probably fully parse the compression
+        # configuration to know which parameters are compressed.
+        config_groups = compression_config.get("config_groups")
+        if config_groups is not None:
+            for _, group in config_groups.items():
+                weights_config = group.get("weights")
+                if weights_config is not None:
+                    if (
+                        weights_config["type"] == "float"
+                        and weights_config["num_bits"] == 8
+                    ):
+                        log_master(
+                            logger.info, "Auto selecting quantization method fp8"
+                        )
+                        quantize = "fp8"
+                        break

     if dtype is None:
         if quantize in ["awq", "exl2", "gptq", "marlin"]:
@@ -768,7 +786,6 @@ def get_model(
         )

     elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
-        print(f">>> model_type: {model_type}")
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
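
Note on the weight-scale handling in the fp8.py hunks: checkpoints quantized per tensor store weight_scale as a scalar, while the fbgemm kernels expect one scale per output row in contiguous float32 memory. The standalone Python sketch below only illustrates that pattern; expand_weight_scale and the sizes used are hypothetical and not part of the patch or of text_generation_server.

# Illustrative sketch only: mirrors how a scalar FP8 weight scale is broadcast
# to one scale per output row and made contiguous/float32 before reaching the
# fbgemm kernels. `expand_weight_scale` and the sizes below are hypothetical.
import torch


def expand_weight_scale(scale: torch.Tensor, num_rows: int) -> torch.Tensor:
    """Return a contiguous float32 scale tensor with one entry per output row."""
    scale = scale.reshape(-1)
    if scale.numel() == 1:
        # Per-tensor (scalar) scale: broadcast to all rows. expand() returns a
        # non-contiguous view, hence the explicit contiguous() below, mirroring
        # the contiguous() call added in Fp8Weight.get_linear.
        scale = scale.expand(num_rows)
    return scale.contiguous().float()


# Example: a per-tensor scale of 0.02 expanded for a weight with 4096 output rows.
per_row = expand_weight_scale(torch.tensor([0.02]), 4096)
assert per_row.shape == (4096,) and per_row.is_contiguous()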
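
For reference, the compression_config branch added to get_model keys off config entries shaped as in the abbreviated, synthetic dict below (only the fields the new code reads are shown); the snippet re-runs the same check outside the server to show when fp8 is auto-selected.

# Synthetic, abbreviated example of a compression_config that triggers the new
# fp8 auto-selection; only config_groups -> group -> "weights" -> type/num_bits
# are relied on by the check above. Everything else in real configs is omitted.
compression_config = {
    "config_groups": {
        "group_0": {
            "weights": {"type": "float", "num_bits": 8},
        }
    }
}

quantize = None
config_groups = compression_config.get("config_groups")
if config_groups is not None:
    for _, group in config_groups.items():
        weights_config = group.get("weights")
        if weights_config is not None:
            if weights_config["type"] == "float" and weights_config["num_bits"] == 8:
                quantize = "fp8"
                break

assert quantize == "fp8"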