Machete WIP

Daniël de Kok 2024-10-14 07:59:09 +00:00
parent 43f39f6894
commit c9e0f36dbc
1 changed file with 80 additions and 21 deletions


@@ -1,4 +1,5 @@
from dataclasses import dataclass
from enum import Enum, auto
from typing import List, Optional, Union

import numpy
@@ -23,12 +24,14 @@ except ImportError:
try:
    major, _minor = torch.cuda.get_device_capability()
    has_sm_8_0 = major >= 8
    has_sm_9_0 = major >= 9
except Exception:
    has_sm_8_0 = False
    has_sm_9_0 = False

GPTQ_MARLIN_BITS = [4, 8]

GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
MACHETE_GROUP_SIZES = [-1, 128]

MARLIN_TILE_SIZE = 16
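# Note on the gating above: the Marlin kernels only need SM 8.0 (Ampere),
# while Machete targets SM 9.0 (Hopper). Machete also supports fewer group
# sizes than GPTQ-Marlin: only channelwise (-1) and 128 are listed here.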
@@ -48,6 +51,24 @@ def can_use_gptq_marlin(
    )


def _can_use_machete(*, desc_act: bool, group_size: int, num_bits: int) -> bool:
    return (
        (not desc_act or group_size == -1)
        and has_sm_9_0
        and group_size in MACHETE_GROUP_SIZES
        and num_bits in [4, 8]
    )
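# A quick sanity sketch of the dispatch predicate above; the values assume a
# hypothetical SM 9.0 GPU (has_sm_9_0 == True):
#
#   _can_use_machete(desc_act=False, group_size=128, num_bits=4)  # True
#   _can_use_machete(desc_act=True, group_size=128, num_bits=4)   # False: act-order with groups
#   _can_use_machete(desc_act=False, group_size=64, num_bits=4)   # False: unsupported group size
#   _can_use_machete(desc_act=False, group_size=-1, num_bits=8)   # True: channelwise 8-bit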


class GPTQMarlinKernel(Enum):
    """
    Enum for selecting the GPTQ Marlin kernel.
    """

    GPTQ_MARLIN = auto()
    MACHETE = auto()


class GPTQMarlinWeightsLoader(WeightsLoader):
    """
    Loader for using GPTQ- and AWQ-quantized weights with Marlin or Machete kernels.
@@ -69,9 +90,20 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
        self.quant_method = quant_method
        self.quantize = quantize
        self.sym = sym

        self.kernel = (
            GPTQMarlinKernel.MACHETE
            if _can_use_machete(desc_act=desc_act, group_size=groupsize, num_bits=bits)
            else GPTQMarlinKernel.GPTQ_MARLIN
        )
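        # The kernel is picked once at construction time; any configuration
        # Machete cannot handle (pre-SM-9.0 hardware, act-order with groups,
        # an unsupported group size or bit width) falls back to GPTQ-Marlin.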

    def _log_kernel(self):
        if self.kernel == GPTQMarlinKernel.MACHETE:
            log_once(logger.info, "Using Machete kernels")
        else:
            log_once(logger.info, "Using GPTQ-Marlin kernels")
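    # log_once deduplicates, so calling _log_kernel from every per-layer
    # weight getter still reports the kernel choice a single time.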

    def get_weights(self, weights: Weights, prefix: str):
        self._log_kernel()
        try:
            qweight = weights.get_tensor(f"{prefix}.qweight")
        except RuntimeError:
@@ -98,6 +130,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
            bits=self.bits,
            desc_act=self.desc_act,
            groupsize=self.groupsize,
            kernel=self.kernel,
            quant_method=self.quant_method,
            sym=self.sym,
            sharded_infeatures=False,
@@ -141,6 +174,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
            bits=self.bits,
            desc_act=self.desc_act,
            groupsize=self.groupsize,
            kernel=self.kernel,
            quant_method=self.quant_method,
            sym=self.sym,
            sharded_infeatures=False,
@@ -183,13 +217,14 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
            bits=self.bits,
            desc_act=self.desc_act,
            groupsize=self.groupsize,
            kernel=self.kernel,
            quant_method=self.quant_method,
            sym=self.sym,
            sharded_infeatures=False,
        )
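    # The column getters above all pass sharded_infeatures=False: for
    # column-parallel layers only the output features are sharded, so the
    # reduction dimension stays whole on every rank.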

    def get_weights_row(self, weights: Weights, prefix: str):
        self._log_kernel()
        try:
            qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
        except RuntimeError:
@@ -225,6 +260,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
            bits=self.bits,
            desc_act=self.desc_act,
            groupsize=self.groupsize,
            kernel=self.kernel,
            quant_method=self.quant_method,
            sym=self.sym,
            sharded_infeatures=sharded_in_features,
@@ -255,8 +291,10 @@ class GPTQMarlinWeight(Weight):
    qzeros: torch.Tensor
    scales: torch.Tensor
    g_idx: torch.Tensor
    kernel: GPTQMarlinKernel
    perm: torch.Tensor
    bits: int
    groupsize: int
    is_full_k: bool
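    # kernel and groupsize are recorded on the repacked weight so that
    # GPTQMarlinLinear can dispatch to the matching GEMM later without
    # re-deriving the decision.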

    def __post_init__(self):
@@ -284,6 +322,7 @@ def repack_gptq_for_marlin(
    quant_method: str,
    sym: bool,
    sharded_infeatures: bool,
    kernel: GPTQMarlinKernel,
) -> GPTQMarlinWeight:
    """Convert GPTQ weights to a layout compatible with GPTQ-Marlin or Machete kernels."""
    _check_marlin_kernels()
@@ -340,17 +379,23 @@ def repack_gptq_for_marlin(
            out_features,
            bits,
        )
    elif kernel == GPTQMarlinKernel.MACHETE:
        repacked = qweight
        repacked.data = marlin_kernels.machete_prepack_B(
            # Convert to column-major.
            repacked.data.t().contiguous().t(),
            bits,
            qzeros is not None,
        )
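        # The double transpose is a standard trick: t().contiguous().t()
        # keeps the logical shape but re-lays the storage out column-major,
        # which is the layout the prepack routine consumes. For example:
        #
        #   x = torch.arange(6).reshape(2, 3)  # row-major, strides (3, 1)
        #   x_cm = x.t().contiguous().t()      # same values, strides (1, 2)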
    else:
        repacked = marlin_kernels.gptq_marlin_repack(
            qweight, perm, in_features, out_features, bits
        )
        # Only the Marlin path uses Marlin's permuted scale layout.
        scales = permute_scales(scales)

    if qzeros is None:
        qzeros = torch.empty(0, dtype=torch.int, device=qweight.device)

    is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures)
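    # An empty qzeros tensor serves as the "symmetric, no zero points"
    # sentinel downstream. is_full_k is False only when act-order grouping
    # meets sharded in-features, i.e. when no rank sees the whole reduction
    # dimension.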

    return GPTQMarlinWeight(
@@ -358,8 +403,10 @@ def repack_gptq_for_marlin(
        qzeros=qzeros,
        scales=scales,
        g_idx=g_idx,
        kernel=kernel,
        perm=perm,
        bits=bits,
        groupsize=groupsize,
        is_full_k=is_full_k,
    )
@@ -385,7 +432,10 @@ class GPTQMarlinLinear(nn.Module):
        out_features = weight.scales.shape[1]
        _check_valid_shape(in_features=in_features, out_features=out_features)

        self.kernel = weight.kernel
        self.bits = weight.bits
        self.groupsize = weight.groupsize
        self.is_full_k = weight.is_full_k
        self.qweight = weight.qweight
@@ -406,22 +456,31 @@ class GPTQMarlinLinear(nn.Module):
        assert marlin_kernels is not None

        A_flat = A.view(-1, A.shape[-1])
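        # Dispatch on the kernel recorded at repack time. The Machete GEMM
        # takes the prepacked matrix, scales and optional zero points
        # directly; it has no use for the Marlin-specific g_idx/perm
        # reordering or the workspace scratch buffer.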
        if self.kernel == GPTQMarlinKernel.MACHETE:
            C = marlin_kernels.machete_gemm(
                A=A_flat,
                B=self.qweight,
                num_bits=self.bits,
                scales=self.scales,
                zeros=self.qzeros if self.qzeros.numel() > 0 else None,
            )
        else:
            C = marlin_kernels.gptq_marlin_gemm(
                A_flat,
                self.qweight,
                self.scales,
                self.qzeros,
                self.g_idx,
                self.perm,
                self.workspace,
                self.bits,
                A_flat.shape[0],
                self.scales.shape[1],
                A_flat.shape[1],
                self.is_full_k,
                self.qzeros.numel() > 0,
                True,
            )
        C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))

        if self.bias is not None: