from dataclasses import dataclass
from typing import Optional, Tuple, List

import torch
import torch.nn as nn

from text_generation_server.utils.import_utils import SYSTEM

# The Marlin kernels are an optional dependency; their absence is reported
# lazily by `_check_marlin_kernels` rather than at import time.
try:
    import marlin_kernels
except ImportError:
    marlin_kernels = None

# Probe the CUDA compute capability once at import time. Any failure (e.g. no
# CUDA device available) is deliberately treated as "capability not available".
try:
    major, _minor = torch.cuda.get_device_capability()
    has_sm_8_0 = major >= 8
except Exception:
    has_sm_8_0 = False


# Bit widths and group sizes accepted by `repack_gptq_for_marlin`.
GPTQ_MARLIN_BITS = [4, 8]
GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]

# Factor relating the first dimension of a repacked weight tensor to the
# number of input features (see the `__init__` methods of the linear layers).
MARLIN_TILE_SIZE = 16


def _check_marlin_kernels():
    """Raise if Marlin kernels cannot be used in the current environment.

    Raises:
        NotImplementedError: if the system is not CUDA with compute
            capability major version >= 8, or if the `marlin_kernels`
            module could not be imported.
    """
    if not (SYSTEM == "cuda" and has_sm_8_0):
        raise NotImplementedError(
            "Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later."
        )

    if marlin_kernels is None:
        raise NotImplementedError(
            "marlin is not installed, install it with: pip install server/marlin"
        )


def _check_valid_shape(in_features: int, out_features: int):
    """Validate that a weight shape is supported by the GPTQ Marlin kernel.

    The shape is accepted when `(in_features, out_features)` is divisible by
    `(128, 64)` or by `(64, 128)`.

    Raises:
        ValueError: if neither divisibility condition holds.
    """
    if (in_features % 128 != 0 or out_features % 64 != 0) and (
        in_features % 64 != 0 or out_features % 128 != 0
    ):
        raise ValueError(
            f"The GPTQ Marlin kernel does not have a valid thread configuration for weight matrix with shape ({out_features}, {in_features})."
            " The shape elements must be divisible by (128, 64) or (64, 128)."
        )


# https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54
def _get_perms() -> Tuple[List[int], List[int]]:
    """Build the two index permutations applied to scales.

    Returns:
        A pair `(scale_perm, scale_perm_single)`: a permutation of
        `range(64)` and a permutation of `range(32)` respectively.
    """
    scale_perm = [i + 8 * j for i in range(8) for j in range(8)]
    scale_perm_single = [
        2 * i + j for i in range(4) for j in [0, 1, 8, 9, 16, 17, 24, 25]
    ]
    return scale_perm, scale_perm_single


_scale_perm, _scale_perm_single = _get_perms()


def permute_scales(scales: torch.Tensor):
    """Reorder the columns of `scales` with the Marlin scale permutations.

    A tensor with a single row (`scales.shape[0] == 1`) is permuted in chunks
    of `len(_scale_perm_single)` elements; any other tensor is permuted in
    chunks of `len(_scale_perm)` elements.

    Returns:
        A contiguous tensor reshaped to `(-1, scales.shape[1])`.
    """
    out_features = scales.shape[1]
    if scales.shape[0] == 1:
        scales = scales.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
    else:
        scales = scales.reshape((-1, len(_scale_perm)))[:, _scale_perm]
    return scales.reshape((-1, out_features)).contiguous()


@dataclass
class GPTQMarlinWeight:
    """
    Repacked GPTQ Marlin weights.

    Attributes:
        qweight (torch.Tensor): repacked quantized weights (int32).
        scales (torch.Tensor): permuted scales (float16).
        g_idx (torch.Tensor): group indices (int32); empty when activation
            reordering is not used.
        perm (torch.Tensor): sort permutation of `g_idx` (int32); empty when
            activation reordering is not used.
        bits (int): quantization bit width.
        is_full_k (bool): flag forwarded to the GEMM kernel; False only when
            both `desc_act` and `sharded_infeatures` were set at repack time.
    """

    qweight: torch.Tensor
    scales: torch.Tensor
    g_idx: torch.Tensor
    perm: torch.Tensor
    bits: int
    is_full_k: bool

    def __post_init__(self):
        assert self.qweight.dtype == torch.int32
        assert self.scales.dtype == torch.float16
        assert self.g_idx.dtype == torch.int32
        assert self.perm.dtype == torch.int32


def repack_gptq_for_marlin(
    *,
    qweight: torch.Tensor,
    scales: torch.Tensor,
    g_idx: torch.Tensor,
    bits: int,
    desc_act: bool,
    groupsize: int,
    sym: bool,
    sharded_infeatures: bool,
) -> GPTQMarlinWeight:
    """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels.

    Args:
        qweight: packed GPTQ weights; `32 // bits` values are packed per
            element along the first dimension.
        scales: quantization scales, reordered with `permute_scales`.
        g_idx: group index per input feature; only used when `desc_act` is
            set and `groupsize != -1`.
        bits: quantization bit width, must be in `GPTQ_MARLIN_BITS`.
        desc_act: whether activation reordering is enabled.
        groupsize: quantization group size, must be in
            `GPTQ_MARLIN_GROUP_SIZES`.
        sym: whether quantization is symmetric; must be True.
        sharded_infeatures: whether the weight is sharded along the input
            features.

    Returns:
        The repacked weight bundle.

    Raises:
        NotImplementedError: if Marlin kernels are unavailable.
        RuntimeError: for unsupported `bits`, `groupsize` or asymmetric
            quantization.
        ValueError: if the number of input features is not divisible by
            `groupsize`.
    """
    _check_marlin_kernels()
    assert marlin_kernels is not None

    if bits not in GPTQ_MARLIN_BITS:
        supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
        raise RuntimeError(
            f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}"
        )

    if groupsize not in GPTQ_MARLIN_GROUP_SIZES:
        supported_sizes = ", ".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES)
        raise RuntimeError(
            f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}"
        )
    if not sym:
        raise RuntimeError(
            "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported."
        )

    weights_per_int = 32 // bits
    in_features = qweight.shape[0] * weights_per_int
    out_features = qweight.shape[1]

    # For groupsize == -1 this check never fires: `x % -1` is 0 in Python.
    if in_features % groupsize != 0:
        raise ValueError(
            f"Number of input features ({in_features}) not divisible by group size ({groupsize})"
        )

    if desc_act and groupsize != -1:
        # Sort the group indices and keep the permutation that sorted them.
        perm = torch.argsort(g_idx).to(torch.int)
        g_idx = g_idx[perm]
    else:
        # No activation reordering: pass empty placeholders to the kernels.
        perm = torch.empty(0, dtype=torch.int, device=qweight.device)
        g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)

    repacked = marlin_kernels.gptq_marlin_repack(
        qweight, perm, in_features, out_features, bits
    )

    scales = permute_scales(scales)

    is_full_k = not (desc_act and sharded_infeatures)

    return GPTQMarlinWeight(
        qweight=repacked,
        scales=scales,
        g_idx=g_idx,
        perm=perm,
        bits=bits,
        is_full_k=is_full_k,
    )


class GPTQMarlinLinear(nn.Module):
    """
    Linear layer for GPTQ weights that were converted for the GPTQ-Marlin
    kernels.
    """

    def __init__(
        self,
        *,
        weight: GPTQMarlinWeight,
        bias: Optional[torch.Tensor],
    ):
        """Register the repacked weights and allocate the kernel workspace.

        Raises:
            NotImplementedError: if Marlin kernels are unavailable.
            ValueError: if the weight shape is rejected by
                `_check_valid_shape`.
        """
        super().__init__()

        _check_marlin_kernels()
        assert marlin_kernels is not None

        in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE
        out_features = weight.scales.shape[1]
        _check_valid_shape(in_features=in_features, out_features=out_features)

        self.bits = weight.bits
        self.is_full_k = weight.is_full_k

        self.register_buffer("qweight", weight.qweight)
        self.register_buffer("scales", weight.scales)
        self.register_buffer("g_idx", weight.g_idx)
        self.register_buffer("perm", weight.perm)
        if bias is not None:
            self.register_buffer("bias", bias)
        else:
            self.bias = None

        # Zero-initialized int buffer handed to the kernel on every call;
        # its size is derived from the number of output features.
        self.workspace = torch.zeros(
            out_features // 64 * 16, dtype=torch.int, device=weight.qweight.device
        )

    def forward(self, A: torch.Tensor) -> torch.Tensor:
        """Apply the quantized linear layer to `A`.

        All leading dimensions of `A` are flattened into one before calling
        the kernel and restored on the result, whose last dimension is
        `self.scales.shape[1]`.
        """
        assert marlin_kernels is not None

        A_flat = A.view(-1, A.shape[-1])
        C = marlin_kernels.gptq_marlin_gemm(
            A_flat,
            self.qweight,
            self.scales,
            self.g_idx,
            self.perm,
            self.workspace,
            self.bits,
            A_flat.shape[0],
            self.scales.shape[1],
            A_flat.shape[1],
            self.is_full_k,
        )
        C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))

        if self.bias is not None:
            C += self.bias

        return C


@dataclass
class MarlinWeight:
    """
    Marlin weights.

    Attributes:
        B (torch.Tensor): int4-quantized weights packed into int32.
        s (torch.Tensor): float16 scales.
    """

    B: torch.Tensor
    s: torch.Tensor

    def __post_init__(self):
        assert self.B.dtype == torch.int32
        assert self.s.dtype == torch.float16


class MarlinLinear(nn.Module):
    """Linear layer for weights stored in the Marlin layout (`MarlinWeight`)."""

    def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]):
        """Validate the weight shape, register buffers, allocate the workspace.

        Raises:
            NotImplementedError: if Marlin kernels are unavailable.
            AssertionError: if the input features are not divisible by 128,
                the output features are not divisible by 256, or the derived
                group size is neither -1 nor 128.
        """
        super().__init__()

        _check_marlin_kernels()
        assert marlin_kernels is not None

        in_features = weight.B.shape[0] * MARLIN_TILE_SIZE
        out_features = weight.s.shape[1]
        assert (
            in_features % 128 == 0
        ), f"Number of input features ({in_features}) not divisable by 128"
        assert (
            out_features % 256 == 0
        ), f"Number of output features ({out_features}) not divisable by 256"

        # A single row of scales means one group spanning all input features.
        groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
        assert groupsize in {
            -1,
            128,
        }, f"Group size must be -1 or 128, was {groupsize}"

        self.register_buffer("B", weight.B)
        self.register_buffer("s", weight.s)
        if bias is not None:
            self.register_buffer("bias", bias)
        else:
            self.bias = None

        # Zero-initialized int buffer handed to the kernel on every call;
        # its size is derived from the number of output features.
        self.workspace = torch.zeros(
            out_features // 64 * 16, dtype=torch.int, device=weight.B.device
        )

    def forward(self, A: torch.Tensor) -> torch.Tensor:
        """Apply the quantized linear layer to `A`.

        All leading dimensions of `A` are flattened into one before calling
        the kernel and restored on the result, whose last dimension is
        `self.s.shape[1]`.
        """
        assert marlin_kernels is not None

        # The size arguments must describe the matrix actually handed to the
        # kernel, i.e. the flattened view, not the (possibly >2-D) input.
        A_flat = A.view(-1, A.shape[-1])
        C = marlin_kernels.marlin_gemm(
            A_flat,
            self.B,
            self.s,
            self.workspace,
            A_flat.shape[0],
            self.s.shape[1],
            A_flat.shape[1],
        )
        C = C.reshape(A.shape[:-1] + (self.s.shape[1],))

        if self.bias is not None:
            C += self.bias

        return C