Machete WIP
commit c9e0f36dbc (parent 43f39f6894)
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from enum import Enum, auto
from typing import List, Optional, Union

import numpy
@@ -23,12 +24,14 @@ except ImportError:
try:
    major, _minor = torch.cuda.get_device_capability()
    has_sm_8_0 = major >= 8
    has_sm_9_0 = major >= 9
except Exception:
    has_sm_8_0 = False
    # Without this fallback, has_sm_9_0 is undefined when no CUDA device is present.
    has_sm_9_0 = False


GPTQ_MARLIN_BITS = [4, 8]
GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
MACHETE_GROUP_SIZES = [-1, 128]
MARLIN_TILE_SIZE = 16
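Machete is gated on SM 9.0 (Hopper-class) hardware, while the Marlin path only needs SM 8.0 (Ampere). A quick, purely illustrative way to see what a given machine reports, assuming a CUDA build of torch:

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    # e.g. (8, 0) on an A100, (9, 0) on an H100
    print(f"SM {major}.{minor}: marlin={major >= 8}, machete={major >= 9}")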
@@ -48,6 +51,24 @@ def can_use_gptq_marlin(
    )


def _can_use_machete(*, desc_act: bool, group_size: int, num_bits: int) -> bool:
    return (
        (not desc_act or group_size == -1)
        and has_sm_9_0
        and group_size in MACHETE_GROUP_SIZES
        and num_bits in [4, 8]
    )


class GPTQMarlinKernel(Enum):
    """
    Enum for selecting the GPTQ Marlin kernel.
    """

    GPTQ_MARLIN = auto()
    MACHETE = auto()


class GPTQMarlinWeightsLoader(WeightsLoader):
    """
    Loader for using GPTQ- and AWQ-quantized weights with Marlin kernels.
@@ -69,9 +90,20 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
        self.quant_method = quant_method
        self.quantize = quantize
        self.sym = sym
        self.kernel = (
            GPTQMarlinKernel.MACHETE
            if _can_use_machete(desc_act=desc_act, group_size=groupsize, num_bits=bits)
            else GPTQMarlinKernel.GPTQ_MARLIN
        )

    def _log_kernel(self):
        if self.kernel == GPTQMarlinKernel.MACHETE:
            log_once(logger.info, "Using Machete kernels")
        else:
            log_once(logger.info, "Using GPTQ-Marlin kernels")

    def get_weights(self, weights: Weights, prefix: str):
        self._log_kernel()
        try:
            qweight = weights.get_tensor(f"{prefix}.qweight")
        except RuntimeError:
@@ -98,6 +130,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
            bits=self.bits,
            desc_act=self.desc_act,
            groupsize=self.groupsize,
            kernel=self.kernel,
            quant_method=self.quant_method,
            sym=self.sym,
            sharded_infeatures=False,
@@ -141,6 +174,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
            bits=self.bits,
            desc_act=self.desc_act,
            groupsize=self.groupsize,
            kernel=self.kernel,
            quant_method=self.quant_method,
            sym=self.sym,
            sharded_infeatures=False,
@@ -183,13 +217,14 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
            bits=self.bits,
            desc_act=self.desc_act,
            groupsize=self.groupsize,
            kernel=self.kernel,
            quant_method=self.quant_method,
            sym=self.sym,
            sharded_infeatures=False,
        )

    def get_weights_row(self, weights: Weights, prefix: str):
        self._log_kernel()
        try:
            qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
        except RuntimeError:
@@ -225,6 +260,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
            bits=self.bits,
            desc_act=self.desc_act,
            groupsize=self.groupsize,
            kernel=self.kernel,
            quant_method=self.quant_method,
            sym=self.sym,
            sharded_infeatures=sharded_in_features,
@@ -255,8 +291,10 @@ class GPTQMarlinWeight(Weight):
    qzeros: torch.Tensor
    scales: torch.Tensor
    g_idx: torch.Tensor
    kernel: GPTQMarlinKernel
    perm: torch.Tensor
    bits: int
    groupsize: int
    is_full_k: bool

    def __post_init__(self):
@@ -284,6 +322,7 @@ def repack_gptq_for_marlin(
    quant_method: str,
    sym: bool,
    sharded_infeatures: bool,
    kernel: GPTQMarlinKernel,
) -> GPTQMarlinWeight:
    """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels."""
    _check_marlin_kernels()
@@ -340,17 +379,23 @@ def repack_gptq_for_marlin(
            out_features,
            bits,
        )
    elif kernel == GPTQMarlinKernel.MACHETE:
        repacked = qweight
        repacked.data = marlin_kernels.machete_prepack_B(
            # Convert to column-major.
            repacked.data.t().contiguous().t(),
            bits,
            qzeros is not None,
        )
    else:
        repacked = marlin_kernels.gptq_marlin_repack(
            qweight, perm, in_features, out_features, bits
        )

    if qzeros is None:
        qzeros = torch.empty(0, dtype=torch.int, device=qweight.device)

    scales = permute_scales(scales)

    is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures)

    return GPTQMarlinWeight(
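The repacked.data.t().contiguous().t() idiom is what hands machete_prepack_B a column-major buffer: transpose, materialize contiguously, transpose back. The values are unchanged; only the strides flip. A standalone sketch of the effect:

import torch

x = torch.arange(6).reshape(2, 3)
col_major = x.t().contiguous().t()
assert torch.equal(x, col_major)       # same logical values
print(x.stride(), col_major.stride())  # (3, 1) vs. (1, 2): row- vs. column-major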
@@ -358,8 +403,10 @@ def repack_gptq_for_marlin(
        qzeros=qzeros,
        scales=scales,
        g_idx=g_idx,
        kernel=kernel,
        perm=perm,
        bits=bits,
        groupsize=groupsize,
        is_full_k=is_full_k,
    )
@@ -385,7 +432,10 @@ class GPTQMarlinLinear(nn.Module):
        out_features = weight.scales.shape[1]
        _check_valid_shape(in_features=in_features, out_features=out_features)

        self.kernel = weight.kernel

        self.bits = weight.bits
        self.groupsize = weight.groupsize
        self.is_full_k = weight.is_full_k

        self.qweight = weight.qweight
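Reading out_features off the scales tensor works because GPTQ stores one scale per (group, output channel), so the trailing dimension of scales indexes output columns. With hypothetical sizes:

in_features, out_features, groupsize = 4096, 11008, 128
scales_shape = (in_features // groupsize, out_features)  # (32, 11008)
# hence out_features = weight.scales.shape[1] above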
@@ -406,22 +456,31 @@ class GPTQMarlinLinear(nn.Module):
        assert marlin_kernels is not None

        A_flat = A.view(-1, A.shape[-1])
        if self.kernel == GPTQMarlinKernel.MACHETE:
            C = marlin_kernels.machete_gemm(
                A=A_flat,
                B=self.qweight,
                num_bits=self.bits,
                scales=self.scales,
                zeros=self.qzeros if self.qzeros.numel() > 0 else None,
            )
        else:
            C = marlin_kernels.gptq_marlin_gemm(
                A_flat,
                self.qweight,
                self.scales,
                self.qzeros,
                self.g_idx,
                self.perm,
                self.workspace,
                self.bits,
                A_flat.shape[0],
                self.scales.shape[1],
                A_flat.shape[1],
                self.is_full_k,
                self.qzeros.numel() > 0,
                True,
            )
        C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))

        if self.bias is not None:
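Both GEMM paths take a 2-D activation, so forward flattens any leading batch/sequence dimensions and restores them after the matmul. The shape bookkeeping in isolation, with a plain matmul standing in for the quantized kernels:

import torch

def forward_shapes(A: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
    A_flat = A.view(-1, A.shape[-1])          # (batch, seq, in) -> (batch * seq, in)
    C = A_flat @ W                            # stand-in for machete_gemm / gptq_marlin_gemm
    return C.reshape(A.shape[:-1] + (W.shape[1],))  # back to (batch, seq, out)

out = forward_shapes(torch.randn(2, 5, 16), torch.randn(16, 32))
assert out.shape == (2, 5, 32)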