diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py
index 4fb151a9..f6616d3e 100644
--- a/server/text_generation_server/layers/gptq/__init__.py
+++ b/server/text_generation_server/layers/gptq/__init__.py
@@ -124,50 +124,7 @@ class GPTQWeightsLoader(WeightsLoader):
         self.sym = sym
 
     def get_weights(self, weights: Weights, prefix: str):
-        from text_generation_server.layers.marlin import (
-            can_use_gptq_marlin,
-            repack_gptq_for_marlin,
-        )
-
         self._get_gptq_params(weights)
-        if can_use_gptq_marlin(
-            bits=self.bits,
-            groupsize=self.groupsize,
-            quant_method=self.quant_method,
-            quantize=self.quantize,
-            sym=self.sym,
-        ):
-            log_once(logger.info, "Using GPTQ-Marlin kernels")
-            try:
-                qweight = weights.get_tensor(f"{prefix}.qweight")
-            except RuntimeError:
-                raise RuntimeError(
-                    f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
-                )
-
-            if not self.sym:
-                qzeros = weights.get_tensor(f"{prefix}.qzeros")
-            else:
-                qzeros = None
-
-            if self.quant_method == "awq":
-                g_idx = None
-            else:
-                g_idx = weights.get_tensor(f"{prefix}.g_idx")
-            scales = weights.get_tensor(f"{prefix}.scales")
-
-            return repack_gptq_for_marlin(
-                qweight=qweight,
-                scales=scales,
-                qzeros=qzeros,
-                g_idx=g_idx,
-                bits=self.bits,
-                desc_act=self.desc_act,
-                groupsize=self.groupsize,
-                quant_method=self.quant_method,
-                sym=self.sym,
-                sharded_infeatures=False,
-            )
 
         use_exllama = True
         if self.bits != 4:
@@ -248,11 +205,6 @@ class GPTQWeightsLoader(WeightsLoader):
         prefix: str,
         block_sizes: Union[int, List[int]],
     ):
-        from text_generation_server.layers.marlin import (
-            can_use_gptq_marlin,
-            repack_gptq_for_marlin,
-        )
-
         try:
             qweight = weights.get_packed_sharded(
                 f"{prefix}.qweight", dim=1, block_sizes=block_sizes
@@ -267,36 +219,6 @@ class GPTQWeightsLoader(WeightsLoader):
         scales = scales.to(dtype=weights.dtype)
 
         self._get_gptq_params(weights)
-        if can_use_gptq_marlin(
-            bits=self.bits,
-            groupsize=self.groupsize,
-            quant_method=self.quant_method,
-            quantize=self.quantize,
-            sym=self.sym,
-        ):
-            if not self.sym:
-                qzeros = weights.get_packed_sharded(
-                    f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
-                )
-            else:
-                qzeros = None
-
-            if self.quant_method == "awq":
-                g_idx = None
-            else:
-                g_idx = weights.get_tensor(f"{prefix}.g_idx")
-            return repack_gptq_for_marlin(
-                qweight=qweight,
-                scales=scales,
-                qzeros=qzeros,
-                g_idx=g_idx,
-                bits=self.bits,
-                desc_act=self.desc_act,
-                groupsize=self.groupsize,
-                quant_method=self.quant_method,
-                sym=self.sym,
-                sharded_infeatures=False,
-            )
 
         qzeros = weights.get_packed_sharded(
             f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
@@ -334,11 +256,6 @@ class GPTQWeightsLoader(WeightsLoader):
         )
 
     def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
-        from text_generation_server.layers.marlin import (
-            can_use_gptq_marlin,
-            repack_gptq_for_marlin,
-        )
-
         try:
             qweight = torch.cat(
                 [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
@@ -353,41 +270,6 @@ class GPTQWeightsLoader(WeightsLoader):
             )
 
         self._get_gptq_params(weights)
-        if can_use_gptq_marlin(
-            bits=self.bits,
-            groupsize=self.groupsize,
-            quant_method=self.quant_method,
-            quantize=self.quantize,
-            sym=self.sym,
-        ):
-
-            if not self.sym:
-                qzeros = torch.cat(
-                    [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
-                )
-            else:
-                qzeros = None
-
-            if self.quant_method == "awq":
-                g_idx = None
-            else:
-                w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
-                for w2 in w[1:]:
-                    torch.testing.assert_close(w2, w[0])
-                g_idx = w[0]
-
-            return repack_gptq_for_marlin(
-                qweight=qweight,
-                scales=scales,
-                qzeros=qzeros,
-                g_idx=g_idx,
-                bits=self.bits,
-                desc_act=self.desc_act,
-                groupsize=self.groupsize,
-                quant_method=self.quant_method,
-                sym=self.sym,
-                sharded_infeatures=False,
-            )
 
         qzeros = torch.cat(
             [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
@@ -441,59 +323,7 @@ class GPTQWeightsLoader(WeightsLoader):
         )
 
     def get_weights_row(self, weights: Weights, prefix: str):
-        from text_generation_server.layers.marlin import (
-            can_use_gptq_marlin,
-            repack_gptq_for_marlin,
-        )
-
         self._get_gptq_params(weights)
-        if can_use_gptq_marlin(
-            bits=self.bits,
-            groupsize=self.groupsize,
-            quant_method=self.quant_method,
-            quantize=self.quantize,
-            sym=self.sym,
-        ):
-            log_once(logger.info, "Using GPTQ-Marlin kernels")
-            try:
-                qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
-            except RuntimeError:
-                raise RuntimeError(
-                    f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
-                )
-
-            if not self.sym:
-                if self.desc_act or self.groupsize == -1:
-                    qzeros = weights.get_tensor(f"{prefix}.qzeros")
-                else:
-                    qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
-            else:
-                qzeros = None
-
-            if self.quant_method == "awq":
-                g_idx = None
-            else:
-                g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
-
-            if self.desc_act or self.groupsize == -1:
-                scales = weights.get_tensor(f"{prefix}.scales")
-            else:
-                scales = weights.get_sharded(f"{prefix}.scales", dim=0)
-
-            sharded_in_features = weights.process_group.size() > 1
-
-            return repack_gptq_for_marlin(
-                qweight=qweight,
-                scales=scales,
-                qzeros=qzeros,
-                g_idx=g_idx,
-                bits=self.bits,
-                desc_act=self.desc_act,
-                groupsize=self.groupsize,
-                quant_method=self.quant_method,
-                sym=self.sym,
-                sharded_infeatures=sharded_in_features,
-            )
 
         use_exllama = True
         if self.bits != 4:
diff --git a/server/text_generation_server/layers/marlin/__init__.py b/server/text_generation_server/layers/marlin/__init__.py
index 8a143f96..3ff3ed58 100644
--- a/server/text_generation_server/layers/marlin/__init__.py
+++ b/server/text_generation_server/layers/marlin/__init__.py
@@ -1,7 +1,6 @@
 from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear
 from text_generation_server.layers.marlin.gptq import (
-    GPTQMarlinLinear,
-    GPTQMarlinWeight,
+    GPTQMarlinWeightsLoader,
     can_use_gptq_marlin,
     repack_gptq_for_marlin,
 )
@@ -9,8 +8,7 @@ from text_generation_server.layers.marlin.marlin import MarlinWeightsLoader
 
 __all__ = [
     "GPTQMarlinFP8Linear",
-    "GPTQMarlinLinear",
-    "GPTQMarlinWeight",
+    "GPTQMarlinWeightsLoader",
     "MarlinWeightsLoader",
     "can_use_gptq_marlin",
     "repack_gptq_for_marlin",
diff --git a/server/text_generation_server/layers/marlin/gptq.py b/server/text_generation_server/layers/marlin/gptq.py
index 80674b14..c7663b60 100644
--- a/server/text_generation_server/layers/marlin/gptq.py
+++ b/server/text_generation_server/layers/marlin/gptq.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Optional
+from typing import List, Optional, Union
 
 import numpy
 import torch
@@ -13,7 +13,7 @@ from text_generation_server.layers.marlin.util import (
 )
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.log import log_once
-from text_generation_server.utils.weights import Weight
+from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
 
 try:
     import marlin_kernels
@@ -48,6 +48,204 @@ def can_use_gptq_marlin(
     )
 
 
+class GPTQMarlinWeightsLoader(WeightsLoader):
+    """
+    Loader for using GPTQ- and AWQ-quantized weights with Marlin kernels.
+    """
+
+    def __init__(
+        self,
+        *,
+        bits: int,
+        desc_act: bool,
+        groupsize: int,
+        quant_method: str,
+        quantize: str,
+        sym: bool,
+    ):
+        self.bits = bits
+        self.desc_act = desc_act
+        self.groupsize = groupsize
+        self.quant_method = quant_method
+        self.quantize = quantize
+        self.sym = sym
+
+    def get_weights(self, weights: Weights, prefix: str):
+        log_once(logger.info, "Using GPTQ-Marlin kernels")
+        try:
+            qweight = weights.get_tensor(f"{prefix}.qweight")
+        except RuntimeError:
+            raise RuntimeError(
+                f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
+            )
+
+        if not self.sym:
+            qzeros = weights.get_tensor(f"{prefix}.qzeros")
+        else:
+            qzeros = None
+
+        if self.quant_method == "awq":
+            g_idx = None
+        else:
+            g_idx = weights.get_tensor(f"{prefix}.g_idx")
+        scales = weights.get_tensor(f"{prefix}.scales")
+
+        return repack_gptq_for_marlin(
+            qweight=qweight,
+            scales=scales,
+            qzeros=qzeros,
+            g_idx=g_idx,
+            bits=self.bits,
+            desc_act=self.desc_act,
+            groupsize=self.groupsize,
+            quant_method=self.quant_method,
+            sym=self.sym,
+            sharded_infeatures=False,
+        )
+
+    def get_weights_col_packed(
+        self,
+        weights: Weights,
+        prefix: str,
+        block_sizes: Union[int, List[int]],
+    ):
+
+        try:
+            qweight = weights.get_packed_sharded(
+                f"{prefix}.qweight", dim=1, block_sizes=block_sizes
+            )
+        except RuntimeError:
+            raise RuntimeError(
+                f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
+            )
+        scales = weights.get_packed_sharded(
+            f"{prefix}.scales", dim=1, block_sizes=block_sizes
+        )
+        scales = scales.to(dtype=weights.dtype)
+
+        if not self.sym:
+            qzeros = weights.get_packed_sharded(
+                f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
+            )
+        else:
+            qzeros = None
+
+        if self.quant_method == "awq":
+            g_idx = None
+        else:
+            g_idx = weights.get_tensor(f"{prefix}.g_idx")
+        return repack_gptq_for_marlin(
+            qweight=qweight,
+            scales=scales,
+            qzeros=qzeros,
+            g_idx=g_idx,
+            bits=self.bits,
+            desc_act=self.desc_act,
+            groupsize=self.groupsize,
+            quant_method=self.quant_method,
+            sym=self.sym,
+            sharded_infeatures=False,
+        )
+
+    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
+        try:
+            qweight = torch.cat(
+                [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
+            )
+        except RuntimeError:
+            raise RuntimeError(
+                f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
+            )
+
+        scales = torch.cat(
+            [weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
+        )
+
+        if not self.sym:
+            qzeros = torch.cat(
+                [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
+            )
+        else:
+            qzeros = None
+
+        if self.quant_method == "awq":
+            g_idx = None
+        else:
+            w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
+            for w2 in w[1:]:
+                torch.testing.assert_close(w2, w[0])
+            g_idx = w[0]
+
+        return repack_gptq_for_marlin(
+            qweight=qweight,
+            scales=scales,
+            qzeros=qzeros,
+            g_idx=g_idx,
+            bits=self.bits,
+            desc_act=self.desc_act,
+            groupsize=self.groupsize,
+            quant_method=self.quant_method,
+            sym=self.sym,
+            sharded_infeatures=False,
+        )
+
+    def get_weights_row(self, weights: Weights, prefix: str):
+        log_once(logger.info, "Using GPTQ-Marlin kernels")
+        try:
+            qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
+        except RuntimeError:
+            raise RuntimeError(
+                f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
+            )
+
+        if not self.sym:
+            if self.desc_act or self.groupsize == -1:
+                qzeros = weights.get_tensor(f"{prefix}.qzeros")
+            else:
+                qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
+        else:
+            qzeros = None
+
+        if self.quant_method == "awq":
+            g_idx = None
+        else:
+            g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
+
+        if self.desc_act or self.groupsize == -1:
+            scales = weights.get_tensor(f"{prefix}.scales")
+        else:
+            scales = weights.get_sharded(f"{prefix}.scales", dim=0)
+
+        sharded_in_features = weights.process_group.size() > 1
+
+        return repack_gptq_for_marlin(
+            qweight=qweight,
+            scales=scales,
+            qzeros=qzeros,
+            g_idx=g_idx,
+            bits=self.bits,
+            desc_act=self.desc_act,
+            groupsize=self.groupsize,
+            quant_method=self.quant_method,
+            sym=self.sym,
+            sharded_infeatures=sharded_in_features,
+        )
+
+    def _get_gptq_params(self, weights: Weights):
+        if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
+            self.bits = weights.get_tensor("gptq_bits").item()
+            self.groupsize = weights.get_tensor("gptq_groupsize").item()
+            self.desc_act = False
+            # `server quantize` used asymmetric quantization unconditionally
+            # before the `gptq_sym` setting tensor was added.
+            self.sym = (
+                weights.get_tensor("gptq_sym").item()
+                if weights._has_tensor("gptq_sym")
+                else False
+            )
+            self.quant_method = "gptq"
+
+
 @dataclass
 class GPTQMarlinWeight(Weight):
     """
diff --git a/server/text_generation_server/utils/quantization.py b/server/text_generation_server/utils/quantization.py
index cfb8b1db..ee561acc 100644
--- a/server/text_generation_server/utils/quantization.py
+++ b/server/text_generation_server/utils/quantization.py
@@ -4,6 +4,7 @@ from dataclasses import dataclass
 from typing import Optional
 
 from huggingface_hub import hf_hub_download
+from text_generation_server.layers.marlin.gptq import can_use_gptq_marlin
 from text_generation_server.utils.weights import (
     DefaultWeightsLoader,
     WeightsLoader,
@@ -128,14 +129,32 @@ def get_loader(
                 f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
             )
 
-        return GPTQWeightsLoader(
+        if can_use_gptq_marlin(
             bits=quantizer_config.bits,
-            desc_act=quantizer_config.desc_act,
             groupsize=quantizer_config.groupsize,
             quant_method=quantizer_config.quant_method,
             quantize=quantize,
             sym=quantizer_config.sym,
-        )
+        ):
+            from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader
+
+            return GPTQMarlinWeightsLoader(
+                bits=quantizer_config.bits,
+                desc_act=quantizer_config.desc_act,
+                groupsize=quantizer_config.groupsize,
+                quant_method=quantizer_config.quant_method,
+                quantize=quantize,
+                sym=quantizer_config.sym,
+            )
+        else:
+            return GPTQWeightsLoader(
+                bits=quantizer_config.bits,
+                desc_act=quantizer_config.desc_act,
+                groupsize=quantizer_config.groupsize,
+                quant_method=quantizer_config.quant_method,
+                quantize=quantize,
+                sym=quantizer_config.sym,
+            )
     elif quantize == "bitsandbytes":
         from text_generation_server.layers.bnb import BNBWeight
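
Note (reviewer sketch, not part of the patch): after this change the GPTQ -> Marlin repacking path lives entirely in the new GPTQMarlinWeightsLoader, and the choice between it and GPTQWeightsLoader is made once in get_loader() rather than re-checked inside every loader method. A minimal usage sketch, assuming the existing get_loader(quantize, model_id, revision) entry point; the model id below is a placeholder:

    # Hypothetical illustration of the new dispatch; "org/model-gptq" is not a real model.
    from text_generation_server.utils.quantization import get_loader

    loader = get_loader(quantize="gptq", model_id="org/model-gptq", revision=None)
    # Roughly: on CUDA with marlin_kernels installed, SM >= 8.0, a supported
    # bits/groupsize combination, and symmetric quantization (or AWQ), this returns
    # GPTQMarlinWeightsLoader; otherwise it falls back to GPTQWeightsLoader.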