diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json
index dcb4d063..bf981e4f 100644
--- a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json
+++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json
@@ -11,12 +11,12 @@
       },
       {
         "id": 2323,
-        "logprob": -9.421875,
+        "logprob": -9.5625,
         "text": "Test"
       },
       {
         "id": 1715,
-        "logprob": -10.546875,
+        "logprob": -10.375,
         "text": " request"
       }
     ],
@@ -24,66 +24,66 @@
     "tokens": [
       {
         "id": 25,
-        "logprob": -0.8535156,
+        "logprob": -0.8984375,
         "special": false,
         "text": ":"
       },
       {
         "id": 2209,
-        "logprob": -2.4804688,
+        "logprob": -2.78125,
         "special": false,
         "text": " Is"
       },
       {
         "id": 279,
-        "logprob": -0.7167969,
+        "logprob": -0.6328125,
         "special": false,
         "text": " the"
       },
       {
         "id": 734,
-        "logprob": -2.625,
+        "logprob": -2.703125,
         "special": false,
         "text": " function"
       },
       {
         "id": 330,
-        "logprob": -0.35131836,
+        "logprob": -0.34179688,
         "special": false,
         "text": " \""
       },
       {
         "id": 4110,
-        "logprob": -2.4101562,
+        "logprob": -2.359375,
         "special": false,
         "text": "Create"
       },
       {
-        "id": 264,
-        "logprob": -0.23181152,
+        "id": 7575,
+        "logprob": -2.1875,
         "special": false,
-        "text": " a"
-      },
-      {
-        "id": 502,
-        "logprob": -0.25512695,
-        "special": false,
-        "text": " new"
-      },
-      {
-        "id": 1052,
-        "logprob": -1.2792969,
-        "special": false,
-        "text": " file"
+        "text": "Process"
       },
       {
         "id": 1,
-        "logprob": -1.2529297,
+        "logprob": -0.07910156,
         "special": false,
         "text": "\""
+      },
+      {
+        "id": 304,
+        "logprob": -0.83203125,
+        "special": false,
+        "text": " in"
+      },
+      {
+        "id": 12468,
+        "logprob": -1.8203125,
+        "special": false,
+        "text": " Win"
       }
     ],
     "top_tokens": null
   },
-  "generated_text": "Test request: Is the function \"Create a new file\""
+  "generated_text": "Test request: Is the function \"CreateProcess\" in Win"
 }
diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index cdf16d6b..bf5a0989 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -76,7 +76,9 @@ class HybridFP8UnquantLoader(WeightsLoader):
 
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            scale = weights.get_tensor(
+                f"{prefix}.weight_scale", to_dtype=False
+            ).reshape(-1)
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -102,7 +104,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
             # FP8 branch
             scale = weights.get_packed_sharded(
                 f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
-            )
+            ).reshape(-1)
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -115,8 +117,12 @@ class HybridFP8UnquantLoader(WeightsLoader):
         return UnquantizedWeight(w)
 
     def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
-        w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
-        w = torch.cat(w, dim=dim)
+        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
+        w = [
+            weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
+        ]
+        # Concat then send to the device
+        w = torch.cat(w, dim=dim).to(weights.device)
 
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
@@ -124,7 +130,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
                 for p in prefixes
             ]
-            scale = torch.cat(scale, dim=0)
+            scale = torch.cat(scale, dim=0).reshape(-1)
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -140,7 +146,9 @@ class HybridFP8UnquantLoader(WeightsLoader):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
-            scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0, to_dtype=False)
+            scale = weights.get_tensor(
+                f"{prefix}.weight_scale", to_dtype=False
+            ).reshape(-1)
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
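
Every weight_scale loaded above is flattened to a 1-D tensor before it is wrapped in Fp8Weight, since checkpoints may store a per-tensor scale as a 0-D scalar while per-channel scales are already 1-D. A minimal sketch of that normalization, assuming a PyTorch build with float8 support; the shapes and values below are illustrative only:

import torch

# Hypothetical checkpoint contents: an FP8 weight plus its scale.
# Per-tensor scales are often serialized as 0-D scalars, while
# per-channel scales are 1-D with one entry per output row.
weight = torch.randn(64, 128).to(torch.float8_e4m3fn)
per_tensor_scale = torch.tensor(0.02)          # shape: torch.Size([])
per_channel_scale = torch.full((64,), 0.02)    # shape: torch.Size([64])

# reshape(-1) gives both layouts the same 1-D shape, so downstream code
# can treat "one scale" and "one scale per row" uniformly.
print(per_tensor_scale.reshape(-1).shape)      # torch.Size([1])
print(per_channel_scale.reshape(-1).shape)     # torch.Size([64])

# Dequantization then broadcasts the scale over the output rows.
dequantized = weight.to(torch.float32) * per_channel_scale.reshape(-1).unsqueeze(1)
print(dequantized.shape)                       # torch.Size([64, 128])

The to_device=False change in get_multi_weights_col applies the same idea to placement: per the FIXME in the hunk, torch.cat is not yet supported for fp8 weights on the device, so the shards are concatenated on the host and moved once afterwards.
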
diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py
index 40271c35..a28012da 100644
--- a/server/text_generation_server/layers/marlin.py
+++ b/server/text_generation_server/layers/marlin.py
@@ -504,7 +504,7 @@ class GPTQMarlinFP8Linear(nn.Module):
     def __init__(
         self,
         qweight: torch.Tensor,
-        scale: torch.Tensor,
+        scales: torch.Tensor,
         bias: Optional[torch.Tensor],
     ) -> None:
         super().__init__()
@@ -514,8 +514,11 @@ class GPTQMarlinFP8Linear(nn.Module):
 
         log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
 
-        scale = scale.to(torch.float16)
-        qweight, scales = repack_fp8_for_marlin(qweight, scale)
+        scales = scales.unsqueeze(0)
+        if scales.shape[1] == 1:
+            out_features, in_features = qweight.shape
+            scales = scales.repeat(1, out_features)
+        qweight, scales = repack_fp8_for_marlin(qweight, scales)
 
         in_features = qweight.shape[0] * MARLIN_TILE_SIZE
         out_features = scales.shape[1]
@@ -530,13 +533,13 @@ class GPTQMarlinFP8Linear(nn.Module):
         )
 
     @classmethod
-    def from_unquant(cls, weight, bias, _dtype):
-        qweight, scale = fp8_quantize(weight)
-        return cls(qweight=qweight, scale=scale, bias=bias)
+    def from_unquant(cls, weight, bias, dtype):
+        qweight, scales = fp8_quantize(weight)
+        return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)
 
     @classmethod
-    def from_fp8(cls, weight, scale, _input_scale, bias, _dtype):
-        return cls(qweight=weight, scale=scale, bias=bias)
+    def from_fp8(cls, weight, scale, _input_scale, bias, dtype):
+        return cls(qweight=weight, scales=scale.to(dtype), bias=bias)
 
     def forward(self, A: torch.Tensor) -> torch.Tensor:
         assert marlin_kernels is not None
@@ -591,7 +594,7 @@ def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
     return packed
 
 
-def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor):
+def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):
     """
     Repack FP8 tensor for GPTQ-Marlin.
     """
@@ -608,7 +611,6 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor):
        qweight, perm, in_features, out_features, 8
     )
 
-    scales = scale.reshape(1, 1).repeat(1, out_features)
     scales = permute_scales(scales)
 
     return repacked, scales
@@ -621,7 +623,7 @@ class MarlinWeight(Weight):
 
     Attributes:
         B (torch.Tensor): int4-quantized weights packed into int32.
-        s (torch.Tensor): float16 scales.
+        s (torch.Tensor): bfloat16/float16 scales.
     """
 
     B: torch.Tensor
@@ -629,7 +631,7 @@
 
     def __post_init__(self):
         assert self.B.dtype == torch.int32
-        assert self.s.dtype == torch.float16
+        assert self.s.dtype in [torch.float16, torch.bfloat16]
 
     def get_linear(self, bias: torch.Tensor):
         return MarlinLinear(weight=self, bias=bias)
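
GPTQMarlinFP8Linear now takes a 1-D scales tensor and expands it itself to the (1, out_features) layout that repack_fp8_for_marlin and permute_scales expect, instead of relying on the scalar reshape that is removed from repack_fp8_for_marlin. A self-contained sketch of that expansion, assuming the scales arrive as a 1-D tensor with either one element or one element per output feature; expand_fp8_scales is an illustrative name, not part of the module:

import torch

def expand_fp8_scales(scales: torch.Tensor, out_features: int) -> torch.Tensor:
    # Same shape handling as the patched __init__: lift to 2-D, then
    # broadcast a single per-tensor scale across every output feature.
    scales = scales.unsqueeze(0)               # (1, 1) or (1, out_features)
    if scales.shape[1] == 1:
        scales = scales.repeat(1, out_features)
    return scales

per_tensor = torch.tensor([0.05], dtype=torch.float16)
per_channel = torch.rand(256, dtype=torch.float16)

print(expand_fp8_scales(per_tensor, 256).shape)    # torch.Size([1, 256])
print(expand_fp8_scales(per_channel, 256).shape)   # torch.Size([1, 256])

MarlinWeight correspondingly relaxes its assertion so bfloat16 scales are accepted alongside float16.
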
""" B: torch.Tensor @@ -629,7 +631,7 @@ class MarlinWeight(Weight): def __post_init__(self): assert self.B.dtype == torch.int32 - assert self.s.dtype == torch.float16 + assert self.s.dtype in [torch.float16, torch.bfloat16] def get_linear(self, bias: torch.Tensor): return MarlinLinear(weight=self, bias=bias) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index a43cdfed..1cd13a2a 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -306,14 +306,32 @@ def get_model( max_input_tokens: int, ) -> Model: global FLASH_ATTENTION + + config_dict, _ = PretrainedConfig.get_config_dict( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + model_type = config_dict.get("model_type", None) + + quantization_config = config_dict.get("quantization_config", None) + if quantization_config is not None and quantize is None: + method = quantization_config.get("quant_method", None) + if method in {"gptq", "awq", "exl2"}: + log_master(logger.info, f"Auto selecting quantization method {method}") + quantize = method + elif method == "fbgemm_fp8": + log_master(logger.info, "Auto selecting quantization method fp8") + quantize = "fp8" + else: + log_master(logger.warning, f"Unknown quantization method {method}") + if dtype is None: if quantize in ["awq", "exl2", "gptq", "marlin"]: # These quantizers only work with float16 params. dtype = torch.float16 elif quantize == "fp8": - from text_generation_server.layers.fp8 import FBGEMM_MM_AVAILABLE + from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE - if FBGEMM_MM_AVAILABLE: + if FBGEMM_DYN_AVAILABLE: # fbgemm kernels are fp8xfp8->bf16 dtype = torch.bfloat16 else: @@ -332,11 +350,6 @@ def get_model( else: set_speculate(0) - config_dict, _ = PretrainedConfig.get_config_dict( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - model_type = config_dict.get("model_type", None) - speculator = None if "medusa_num_heads" in config_dict: medusa_model_id = model_id @@ -451,14 +464,6 @@ def get_model( raise RuntimeError( f"Could not determine model type for {model_id} revision {revision}" ) - quantization_config = config_dict.get("quantization_config", None) - if quantization_config is not None and quantize is None: - method = quantization_config.get("quant_method", None) - if method in {"gptq", "awq", "exl2"}: - log_master(logger.info, f"Auto selecting quantization method {method}") - quantize = method - else: - log_master(logger.warning, f"Unknown quantization method {method}") if quantize == "exl2" and sharded: raise RuntimeError( diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 66bb6051..108ced48 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -230,7 +230,9 @@ class Weights: tensor = tensor.to(device=self.device) return tensor - def get_partial_sharded(self, tensor_name: str, dim: int, to_dtype=True): + def get_partial_sharded( + self, tensor_name: str, dim: int, to_device=True, to_dtype=True + ): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) @@ -256,10 +258,11 @@ class Weights: and to_dtype ): tensor = tensor.to(dtype=self.dtype) - tensor = tensor.to(device=self.device) + if to_device: + tensor = tensor.to(device=self.device) return tensor - def get_sharded(self, tensor_name: str, dim: int, 
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 66bb6051..108ced48 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -230,7 +230,9 @@ class Weights:
         tensor = tensor.to(device=self.device)
         return tensor
 
-    def get_partial_sharded(self, tensor_name: str, dim: int, to_dtype=True):
+    def get_partial_sharded(
+        self, tensor_name: str, dim: int, to_device=True, to_dtype=True
+    ):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         slice_ = f.get_slice(tensor_name)
@@ -256,10 +258,11 @@ class Weights:
             and to_dtype
         ):
             tensor = tensor.to(dtype=self.dtype)
-        tensor = tensor.to(device=self.device)
+        if to_device:
+            tensor = tensor.to(device=self.device)
         return tensor
 
-    def get_sharded(self, tensor_name: str, dim: int, to_dtype=True):
+    def get_sharded(self, tensor_name: str, dim: int, to_device=True, to_dtype=True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         slice_ = f.get_slice(tensor_name)
@@ -268,7 +271,9 @@ class Weights:
         assert (
             size % world_size == 0
         ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
-        return self.get_partial_sharded(tensor_name, dim, to_dtype=to_dtype)
+        return self.get_partial_sharded(
+            tensor_name, dim, to_device=to_device, to_dtype=to_dtype
+        )
 
     def get_packed_sharded(
         self,
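
get_sharded and get_partial_sharded gain a to_device flag so callers such as HybridFP8UnquantLoader.get_multi_weights_col can keep shards on the host, concatenate them there, and pay for a single device transfer. A small sketch of that calling pattern, using plain CPU tensors in place of loaded shards; concat_then_move is an illustrative helper, not part of the Weights API:

import torch

def concat_then_move(shards, dim: int, device: torch.device) -> torch.Tensor:
    # Concatenate on the host first, then transfer once: the pattern that
    # to_device=False enables for dtypes that cannot be concatenated on
    # the accelerator.
    merged = torch.cat(shards, dim=dim)
    return merged.to(device)

shards = [torch.randn(16, 32) for _ in range(3)]
merged = concat_then_move(shards, dim=0, device=torch.device("cpu"))
print(merged.shape)  # torch.Size([48, 32])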