from dataclasses import dataclass
from typing import List, Optional, Union

import torch
import torch.nn as nn

from text_generation_server.layers.marlin.util import _check_marlin_kernels
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

try:
    import marlin_kernels
except ImportError:
    marlin_kernels = None


class MarlinWeightsLoader(WeightsLoader):
    """Loader for Marlin-quantized weights."""

    def __init__(self, *, bits: int, is_marlin_24: bool):
        self.bits = bits
        self.is_marlin_24 = is_marlin_24

    def get_weights(self, weights: "Weights", prefix: str):
        """
        Get weights at the given prefix without applying tensor parallelism.
        """
        if self.is_marlin_24:
            try:
                B = weights.get_tensor(f"{prefix}.B_24")
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
                )

            B_meta = weights.get_tensor(f"{prefix}.B_meta")
            s = weights.get_tensor(f"{prefix}.s")
            weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
        else:
            try:
                B = weights.get_tensor(f"{prefix}.B")
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `marlin` weight, make sure the model is already quantized."
                )

            s = weights.get_tensor(f"{prefix}.s")
            weight = MarlinWeight(B=B, s=s)

        return weight

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        if self.is_marlin_24:
            B = weights.get_packed_sharded(
                f"{prefix}.B_24", dim=1, block_sizes=block_sizes
            )
            B_meta = weights.get_packed_sharded(
                f"{prefix}.B_meta", dim=1, block_sizes=block_sizes
            )
            s = weights.get_packed_sharded(
                f"{prefix}.s", dim=1, block_sizes=block_sizes
            )

            weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
        else:
            B = weights.get_packed_sharded(
                f"{prefix}.B", dim=1, block_sizes=block_sizes
            )
            s = weights.get_packed_sharded(
                f"{prefix}.s", dim=1, block_sizes=block_sizes
            )
            weight = MarlinWeight(B=B, s=s)

        return weight

    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
        if self.is_marlin_24:
            try:
                B = torch.cat(
                    [weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
                )
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `marlin` weight, make sure the model is already quantized"
                )

            B_meta = torch.cat(
                [weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1
            )

            s = torch.cat(
                [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
            )

            weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
        else:
            try:
                B = torch.cat(
                    [weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
                )
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `marlin` weight, make sure the model is already quantized"
                )

            s = torch.cat(
                [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
            )

            weight = MarlinWeight(B=B, s=s)

        return weight

    def get_weights_row(self, weights: Weights, prefix: str):
        if self.is_marlin_24:
            try:
                B = weights.get_sharded(f"{prefix}.B_24", dim=0)
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
                )

            B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0)
            num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
            if num_groups == 1:
                # The number of groups is 1 when groupsize == -1. Share
                # scales between all shards in this case.
s = weights.get_tensor(f"{prefix}.s") else: s = weights.get_sharded(f"{prefix}.s", dim=0) weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) else: try: B = weights.get_sharded(f"{prefix}.B", dim=0) except RuntimeError: raise RuntimeError( "Cannot load `marlin` weight, make sure the model is already quantized." ) num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0] if num_groups == 1: # The number of groups is 1 when groupsize == -1. share # scales between all shards in this case. s = weights.get_tensor(f"{prefix}.s") else: s = weights.get_sharded(f"{prefix}.s", dim=0) weight = MarlinWeight(B=B, s=s) return weight @dataclass class MarlinWeight(Weight): """ Marlin weights. Attributes: B (torch.Tensor): int4-quantized weights packed into int32. s (torch.Tensor): bfloat16/float16 scales. """ B: torch.Tensor s: torch.Tensor def __post_init__(self): assert self.B.dtype == torch.int32 assert self.s.dtype in [torch.float16, torch.bfloat16] def get_linear(self, bias: torch.Tensor): return MarlinLinear(weight=self, bias=bias) class MarlinLinear(nn.Module): def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]): super().__init__() _check_marlin_kernels() assert marlin_kernels is not None in_features = weight.B.shape[0] * MARLIN_TILE_SIZE out_features = weight.s.shape[1] assert ( in_features % 128 == 0 ), f"Number of input features ({in_features}) not divisable by 128" assert ( out_features % 256 == 0 ), f"Number of output features ({out_features}) not divisable by 256" groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0] assert groupsize in { -1, 128, }, f"Group size must be -1 or 128, was {groupsize}" self.B = weight.B self.s = weight.s if bias is not None: self.bias = bias else: self.bias = None self.workspace = torch.zeros( out_features // 64 * 16, dtype=torch.int, device=weight.B.device ) def forward(self, A: torch.Tensor) -> torch.Tensor: assert marlin_kernels is not None C = marlin_kernels.marlin_gemm( A.view(-1, A.shape[-1]), self.B, self.s, self.workspace, A.shape[0], self.s.shape[1], A.shape[1], ) C = C.reshape(A.shape[:-1] + (self.s.shape[1],)) if self.bias is not None: C += self.bias return C GPTQ_MARLIN_24_MIN_THREAD_N = 128 GPTQ_MARLIN_24_MIN_THREAD_K = 128 GPTQ_MARLIN_24_MAX_PARALLEL = 64 GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8] GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] MARLIN_TILE_SIZE = 16 @dataclass class GPTQMarlin24Weight: """ GPTQ-Marlin 2:4 weights. Attributes: B (torch.Tensor): int4-quantized weights packed into int32. B_meta (torch.Tensor): metadata for 2:4 sparsity. s (torch.Tensor): float16 scales. bits: quantized weight size. 
""" B: torch.Tensor B_meta: torch.Tensor s: torch.Tensor bits: int def __post_init__(self): assert self.B.dtype == torch.int32 assert self.B_meta.dtype == torch.int16 assert self.s.dtype == torch.float16 def get_linear(self, bias: torch.Tensor): return GPTQMarlin24Linear( weight=self, bias=bias, ) class GPTQMarlin24Linear(nn.Module): def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]): super().__init__() _check_marlin_kernels() assert marlin_kernels is not None if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS: supported_bits = ", ".join( str(b) for b in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS ) raise RuntimeError( f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}" ) in_features = weight.B.shape[0] * MARLIN_TILE_SIZE * 2 out_features = weight.s.shape[1] groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0] if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES: supported_sizes = ", ".join( str(b) for b in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES ) raise RuntimeError( f"Group size {groupsize} is not supported, must be one of: {supported_sizes}" ) self.bits = weight.bits weights_per_int32 = 32 // self.bits assert ( out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0 ), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_N} threads" assert ( out_features % weights_per_int32 == 0 ), f"Number of output features ({out_features}) not divisable by weights per int32 ({weights_per_int32})" assert ( in_features % GPTQ_MARLIN_24_MIN_THREAD_K == 0 ), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_K} threads" if groupsize != -1 and in_features % groupsize != 0: raise ValueError( f"Number of input features ({in_features}) not divisable by group size ({groupsize})" ) self.B = weight.B self.B_meta = weight.B_meta self.s = weight.s if bias is not None: self.bias = bias else: self.bias = None self.workspace = torch.zeros( (out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL, dtype=torch.int, device=weight.B.device, ) def forward(self, A: torch.Tensor) -> torch.Tensor: assert marlin_kernels is not None C = marlin_kernels.gptq_marlin_24_gemm( A.view(-1, A.shape[-1]), self.B, self.B_meta, self.s, self.workspace, self.bits, A.shape[0], self.s.shape[1], A.shape[1], ) C = C.reshape(A.shape[:-1] + (self.s.shape[1],)) if self.bias is not None: C += self.bias return C