Machete WIP

Daniël de Kok 2024-10-14 07:59:09 +00:00
parent 43f39f6894
commit c9e0f36dbc
1 changed file with 80 additions and 21 deletions


@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from enum import Enum, auto
 from typing import List, Optional, Union

 import numpy
@@ -23,12 +24,15 @@ except ImportError:
 try:
     major, _minor = torch.cuda.get_device_capability()
     has_sm_8_0 = major >= 8
+    has_sm_9_0 = major >= 9
 except Exception:
     has_sm_8_0 = False
+    has_sm_9_0 = False


 GPTQ_MARLIN_BITS = [4, 8]
 GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
+MACHETE_GROUP_SIZES = [-1, 128]

 MARLIN_TILE_SIZE = 16
@@ -48,6 +51,24 @@ def can_use_gptq_marlin(
     )


+def _can_use_machete(*, desc_act: bool, group_size: int, num_bits: int) -> bool:
+    return (
+        (not desc_act or group_size == -1)
+        and has_sm_9_0
+        and group_size in MACHETE_GROUP_SIZES
+        and num_bits in [4, 8]
+    )
+
+
+class GPTQMarlinKernel(Enum):
+    """
+    Enum for selecting the GPTQ Marlin kernel.
+    """
+
+    GPTQ_MARLIN = auto()
+    MACHETE = auto()
+
+
 class GPTQMarlinWeightsLoader(WeightsLoader):
     """
     Loader for using GPTQ- and AWQ-quantized weights with Marlin kernels.
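A note on the gate above: Machete is only attempted on SM 9.0 (Hopper) hardware, and only for the group sizes and bit widths it supports; act-order (`desc_act`) is allowed only together with per-channel quantization (`group_size == -1`). A minimal sketch of how the predicate behaves, reusing `_can_use_machete` and the module-level flags defined in this file:

    # Each tuple is (desc_act, group_size, num_bits); on an SM 9.0 GPU the
    # first case selects Machete, the other two fall back to GPTQ-Marlin.
    cases = [
        (False, 128, 4),  # supported group size and bit width -> True
        (False, 64, 4),   # 64 not in MACHETE_GROUP_SIZES -> False
        (True, 128, 4),   # act-order with grouped quantization -> False
    ]
    for desc_act, group_size, num_bits in cases:
        print(_can_use_machete(desc_act=desc_act, group_size=group_size, num_bits=num_bits))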
@@ -69,9 +90,20 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
         self.quant_method = quant_method
         self.quantize = quantize
         self.sym = sym
+        self.kernel = (
+            GPTQMarlinKernel.MACHETE
+            if _can_use_machete(desc_act=desc_act, group_size=groupsize, num_bits=bits)
+            else GPTQMarlinKernel.GPTQ_MARLIN
+        )
+
+    def _log_kernel(self):
+        if self.kernel == GPTQMarlinKernel.MACHETE:
+            log_once(logger.info, "Using Machete kernels")
+        else:
+            log_once(logger.info, "Using GPTQ-Marlin kernels")

     def get_weights(self, weights: Weights, prefix: str):
-        log_once(logger.info, "Using GPTQ-Marlin kernels")
+        self._log_kernel()
         try:
             qweight = weights.get_tensor(f"{prefix}.qweight")
         except RuntimeError:
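The kernel is chosen once at loader construction and then threaded through every `repack_gptq_for_marlin` call below, so all shards of a model take the same path. A sketch of the selection, assuming the loader keeps the keyword-only constructor it has in this file:

    loader = GPTQMarlinWeightsLoader(
        bits=4,
        desc_act=False,
        groupsize=128,
        quant_method="gptq",
        quantize="gptq",
        sym=True,
    )
    # GPTQMarlinKernel.MACHETE on SM 9.0 hardware,
    # GPTQMarlinKernel.GPTQ_MARLIN everywhere else.
    print(loader.kernel)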
@@ -98,6 +130,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
             bits=self.bits,
             desc_act=self.desc_act,
             groupsize=self.groupsize,
+            kernel=self.kernel,
             quant_method=self.quant_method,
             sym=self.sym,
             sharded_infeatures=False,
@@ -141,6 +174,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
             bits=self.bits,
             desc_act=self.desc_act,
             groupsize=self.groupsize,
+            kernel=self.kernel,
             quant_method=self.quant_method,
             sym=self.sym,
             sharded_infeatures=False,
@@ -183,13 +217,14 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
             bits=self.bits,
             desc_act=self.desc_act,
             groupsize=self.groupsize,
+            kernel=self.kernel,
             quant_method=self.quant_method,
             sym=self.sym,
             sharded_infeatures=False,
         )

     def get_weights_row(self, weights: Weights, prefix: str):
-        log_once(logger.info, "Using GPTQ-Marlin kernels")
+        self._log_kernel()
         try:
             qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
         except RuntimeError:
@@ -225,6 +260,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
             bits=self.bits,
             desc_act=self.desc_act,
             groupsize=self.groupsize,
+            kernel=self.kernel,
             quant_method=self.quant_method,
             sym=self.sym,
             sharded_infeatures=sharded_in_features,
@@ -255,8 +291,10 @@ class GPTQMarlinWeight(Weight):
     qzeros: torch.Tensor
     scales: torch.Tensor
     g_idx: torch.Tensor
+    kernel: GPTQMarlinKernel
     perm: torch.Tensor
     bits: int
+    groupsize: int
     is_full_k: bool

     def __post_init__(self):
@@ -284,6 +322,7 @@ def repack_gptq_for_marlin(
     quant_method: str,
     sym: bool,
     sharded_infeatures: bool,
+    kernel: GPTQMarlinKernel,
 ) -> GPTQMarlinWeight:
     """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels."""
     _check_marlin_kernels()
@@ -340,17 +379,23 @@ def repack_gptq_for_marlin(
                 out_features,
                 bits,
             )
+    elif kernel == GPTQMarlinKernel.MACHETE:
+        repacked = qweight
+        repacked.data = marlin_kernels.machete_prepack_B(
+            # Convert to column-major.
+            repacked.data.t().contiguous().t(),
+            bits,
+            qzeros is not None,
+        )
     else:
         repacked = marlin_kernels.gptq_marlin_repack(
             qweight, perm, in_features, out_features, bits
         )
+        scales = permute_scales(scales)

     if qzeros is None:
         qzeros = torch.empty(0, dtype=torch.int, device=qweight.device)

-    scales = permute_scales(scales)
-
     is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures)

     return GPTQMarlinWeight(
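Two details of this hunk are worth noting. First, `scales = permute_scales(scales)` moves into the `gptq_marlin_repack` branch, so the Machete path keeps the scales in their original layout. Second, `repacked.data.t().contiguous().t()` changes no values: it yields the same tensor with column-major (Fortran-order) strides, which is presumably the layout `machete_prepack_B` expects. A small sketch of the stride effect:

    import torch

    w = torch.arange(6).reshape(2, 3)   # row-major, strides (3, 1)
    cm = w.t().contiguous().t()         # same shape and values
    assert torch.equal(w, cm)
    assert w.stride() == (3, 1)
    assert cm.stride() == (1, 2)        # column-major layout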
@@ -358,8 +403,10 @@ def repack_gptq_for_marlin(
         qzeros=qzeros,
         scales=scales,
         g_idx=g_idx,
+        kernel=kernel,
         perm=perm,
         bits=bits,
+        groupsize=groupsize,
         is_full_k=is_full_k,
     )
@@ -385,7 +432,10 @@ class GPTQMarlinLinear(nn.Module):
         out_features = weight.scales.shape[1]
         _check_valid_shape(in_features=in_features, out_features=out_features)

+        self.kernel = weight.kernel
         self.bits = weight.bits
+        self.groupsize = weight.groupsize
         self.is_full_k = weight.is_full_k

         self.qweight = weight.qweight
@@ -406,6 +456,15 @@ class GPTQMarlinLinear(nn.Module):
         assert marlin_kernels is not None

         A_flat = A.view(-1, A.shape[-1])
+        if self.kernel == GPTQMarlinKernel.MACHETE:
+            C = marlin_kernels.machete_gemm(
+                A=A_flat,
+                B=self.qweight,
+                num_bits=self.bits,
+                scales=self.scales,
+                zeros=self.qzeros if self.qzeros.numel() > 0 else None,
+            )
+        else:
-        C = marlin_kernels.gptq_marlin_gemm(
-            A_flat,
-            self.qweight,
+            C = marlin_kernels.gptq_marlin_gemm(
+                A_flat,
+                self.qweight,
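The forward path flattens the activations to 2-D before either GEMM and dispatches on the kernel chosen at load time; `machete_gemm` takes an explicit `zeros` argument, with the empty-tensor sentinel created during repacking mapped back to `None` for symmetric quantization. A sketch of the shape bookkeeping around the GEMM, where `unflatten_output` is a hypothetical helper shown only for illustration (the real reshape happens later in `forward`, outside this excerpt):

    import torch

    def unflatten_output(C_flat: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        # Restore the leading batch/sequence dims that A.view(-1, A.shape[-1])
        # collapsed before the GEMM.
        return C_flat.view(*A.shape[:-1], C_flat.shape[-1])

    A = torch.randn(2, 5, 64)          # (batch, seq, in_features)
    A_flat = A.view(-1, A.shape[-1])   # (10, 64), as in forward()
    C_flat = torch.randn(10, 128)      # stand-in for the GEMM result
    assert unflatten_output(C_flat, A).shape == (2, 5, 128)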