diff --git a/server/text_generation_server/layers/marlin/gptq.py b/server/text_generation_server/layers/marlin/gptq.py
index 7245431f..4da02da2 100644
--- a/server/text_generation_server/layers/marlin/gptq.py
+++ b/server/text_generation_server/layers/marlin/gptq.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from enum import Enum, auto
 from typing import List, Optional, Union
 
 import numpy
@@ -23,12 +24,15 @@ except ImportError:
 try:
     major, _minor = torch.cuda.get_device_capability()
     has_sm_8_0 = major >= 8
+    has_sm_9_0 = major >= 9
 except Exception:
     has_sm_8_0 = False
+    has_sm_9_0 = False
 
 
 GPTQ_MARLIN_BITS = [4, 8]
 GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
+MACHETE_GROUP_SIZES = [-1, 128]
 
 MARLIN_TILE_SIZE = 16
 
@@ -48,6 +52,24 @@ def can_use_gptq_marlin(
     )
 
 
+def _can_use_machete(*, desc_act: bool, group_size: int, num_bits: int) -> bool:
+    return (
+        (not desc_act or group_size == -1)
+        and has_sm_9_0
+        and group_size in MACHETE_GROUP_SIZES
+        and num_bits in [4, 8]
+    )
+
+
+class GPTQMarlinKernel(Enum):
+    """
+    Kernel to use for GPTQ-quantized weights.
+    """
+
+    GPTQ_MARLIN = auto()
+    MACHETE = auto()
+
+
 class GPTQMarlinWeightsLoader(WeightsLoader):
     """
     Loader for using GPTQ- and AWQ-quantized weights with Marlin kernels.
@@ -69,9 +91,20 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
         self.quant_method = quant_method
         self.quantize = quantize
         self.sym = sym
+        self.kernel = (
+            GPTQMarlinKernel.MACHETE
+            if _can_use_machete(desc_act=desc_act, group_size=groupsize, num_bits=bits)
+            else GPTQMarlinKernel.GPTQ_MARLIN
+        )
+
+    def _log_kernel(self):
+        if self.kernel == GPTQMarlinKernel.MACHETE:
+            log_once(logger.info, "Using Machete kernels")
+        else:
+            log_once(logger.info, "Using GPTQ-Marlin kernels")
 
     def get_weights(self, weights: Weights, prefix: str):
-        log_once(logger.info, "Using GPTQ-Marlin kernels")
+        self._log_kernel()
         try:
             qweight = weights.get_tensor(f"{prefix}.qweight")
         except RuntimeError:
@@ -98,6 +131,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
             bits=self.bits,
             desc_act=self.desc_act,
             groupsize=self.groupsize,
+            kernel=self.kernel,
             quant_method=self.quant_method,
             sym=self.sym,
             sharded_infeatures=False,
@@ -141,6 +175,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
             bits=self.bits,
             desc_act=self.desc_act,
             groupsize=self.groupsize,
+            kernel=self.kernel,
             quant_method=self.quant_method,
             sym=self.sym,
             sharded_infeatures=False,
@@ -183,13 +218,14 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
             bits=self.bits,
             desc_act=self.desc_act,
             groupsize=self.groupsize,
+            kernel=self.kernel,
             quant_method=self.quant_method,
             sym=self.sym,
             sharded_infeatures=False,
         )
 
     def get_weights_row(self, weights: Weights, prefix: str):
-        log_once(logger.info, "Using GPTQ-Marlin kernels")
+        self._log_kernel()
         try:
             qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
         except RuntimeError:
@@ -225,6 +261,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
             bits=self.bits,
             desc_act=self.desc_act,
             groupsize=self.groupsize,
+            kernel=self.kernel,
             quant_method=self.quant_method,
             sym=self.sym,
             sharded_infeatures=sharded_in_features,
@@ -255,8 +292,10 @@ class GPTQMarlinWeight(Weight):
     qzeros: torch.Tensor
     scales: torch.Tensor
     g_idx: torch.Tensor
+    kernel: GPTQMarlinKernel
     perm: torch.Tensor
     bits: int
+    groupsize: int
     is_full_k: bool
 
     def __post_init__(self):
@@ -284,6 +323,7 @@ def repack_gptq_for_marlin(
     quant_method: str,
     sym: bool,
     sharded_infeatures: bool,
+    kernel: GPTQMarlinKernel,
 ) -> GPTQMarlinWeight:
     """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels."""
     _check_marlin_kernels()
@@ -340,17 +380,23 @@ def repack_gptq_for_marlin(
         out_features,
         bits,
     )
-
+    elif kernel == GPTQMarlinKernel.MACHETE:
+        repacked = qweight
+        repacked.data = marlin_kernels.machete_prepack_B(
+            # Convert to column-major.
+            repacked.data.t().contiguous().t(),
+            bits,
+            qzeros is not None,
+        )
     else:
         repacked = marlin_kernels.gptq_marlin_repack(
             qweight, perm, in_features, out_features, bits
         )
+        scales = permute_scales(scales)
 
     if qzeros is None:
         qzeros = torch.empty(0, dtype=torch.int, device=qweight.device)
 
-    scales = permute_scales(scales)
-
     is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures)
 
     return GPTQMarlinWeight(
@@ -358,8 +404,10 @@ def repack_gptq_for_marlin(
         qzeros=qzeros,
         scales=scales,
         g_idx=g_idx,
+        kernel=kernel,
         perm=perm,
         bits=bits,
+        groupsize=groupsize,
         is_full_k=is_full_k,
     )
 
@@ -385,7 +433,10 @@ class GPTQMarlinLinear(nn.Module):
         out_features = weight.scales.shape[1]
         _check_valid_shape(in_features=in_features, out_features=out_features)
 
+        self.kernel = weight.kernel
+
         self.bits = weight.bits
+        self.groupsize = weight.groupsize
         self.is_full_k = weight.is_full_k
 
         self.qweight = weight.qweight
@@ -406,22 +457,31 @@
         assert marlin_kernels is not None
 
         A_flat = A.view(-1, A.shape[-1])
-        C = marlin_kernels.gptq_marlin_gemm(
-            A_flat,
-            self.qweight,
-            self.scales,
-            self.qzeros,
-            self.g_idx,
-            self.perm,
-            self.workspace,
-            self.bits,
-            A_flat.shape[0],
-            self.scales.shape[1],
-            A_flat.shape[1],
-            self.is_full_k,
-            self.qzeros.numel() > 0,
-            True,
-        )
+        if self.kernel == GPTQMarlinKernel.MACHETE:
+            C = marlin_kernels.machete_gemm(
+                A=A_flat,
+                B=self.qweight,
+                num_bits=self.bits,
+                scales=self.scales,
+                zeros=self.qzeros if self.qzeros.numel() > 0 else None,
+            )
+        else:
+            C = marlin_kernels.gptq_marlin_gemm(
+                A_flat,
+                self.qweight,
+                self.scales,
+                self.qzeros,
+                self.g_idx,
+                self.perm,
+                self.workspace,
+                self.bits,
+                A_flat.shape[0],
+                self.scales.shape[1],
+                A_flat.shape[1],
+                self.is_full_k,
+                self.qzeros.numel() > 0,
+                True,
+            )
         C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
 
         if self.bias is not None:
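
Note: the dispatch rule that `_can_use_machete` implements can be exercised in isolation. The sketch below is not part of the patch; `select_kernel` is a hypothetical stand-in for the branch in `GPTQMarlinWeightsLoader.__init__`, with the module-level `has_sm_9_0` flag passed in as a parameter so the rule is testable on any machine:

    MACHETE_GROUP_SIZES = [-1, 128]

    def select_kernel(desc_act: bool, group_size: int, num_bits: int, has_sm_9_0: bool) -> str:
        # Machete requires SM 9.0 (Hopper), 4- or 8-bit weights, a group size
        # it supports, and no act-order reordering unless group_size == -1.
        if (
            (not desc_act or group_size == -1)
            and has_sm_9_0
            and group_size in MACHETE_GROUP_SIZES
            and num_bits in [4, 8]
        ):
            return "machete"
        return "gptq-marlin"

    assert select_kernel(False, 128, 4, has_sm_9_0=True) == "machete"
    assert select_kernel(True, 128, 4, has_sm_9_0=True) == "gptq-marlin"    # act-order with groups
    assert select_kernel(False, 64, 4, has_sm_9_0=True) == "gptq-marlin"    # 64 not a Machete group size
    assert select_kernel(False, 128, 4, has_sm_9_0=False) == "gptq-marlin"  # pre-Hopper GPU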
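
The `.t().contiguous().t()` idiom in the `machete_prepack_B` call is worth spelling out: it leaves the logical shape and values of the tensor unchanged but rewrites its storage in column-major (Fortran) order, which is the layout the prepack step consumes here. A standalone illustration in plain PyTorch, independent of marlin-kernels:

    import torch

    x = torch.arange(6, dtype=torch.int32).reshape(2, 3)  # row-major, strides (3, 1)
    col_major = x.t().contiguous().t()

    assert col_major.shape == x.shape     # same logical shape
    assert torch.equal(col_major, x)      # same values
    assert col_major.stride() == (1, 2)   # column-major strides
    assert not col_major.is_contiguous()  # no longer row-major contiguous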