Machete WIP

parent 43f39f6894
commit c9e0f36dbc

@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from enum import Enum, auto
 from typing import List, Optional, Union

 import numpy
@@ -23,12 +24,14 @@ except ImportError:
 try:
     major, _minor = torch.cuda.get_device_capability()
     has_sm_8_0 = major >= 8
+    has_sm_9_0 = major >= 9
 except Exception:
     has_sm_8_0 = False


 GPTQ_MARLIN_BITS = [4, 8]
 GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
+MACHETE_GROUP_SIZES = [-1, 128]
 MARLIN_TILE_SIZE = 16

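Note (not part of the diff): a minimal standalone sketch of the capability probe this hunk extends. The hunk only sets has_sm_9_0 in the try branch; the no-CUDA fallback below is an assumption, not something this WIP commit adds.

import torch

# Probe the device once at import time; fall back to "no fast kernels"
# when CUDA or a GPU is unavailable.
try:
    major, _minor = torch.cuda.get_device_capability()
    has_sm_8_0 = major >= 8  # Ampere or newer: Marlin kernels
    has_sm_9_0 = major >= 9  # Hopper or newer: Machete kernels
except Exception:
    has_sm_8_0 = False
    has_sm_9_0 = False  # assumed fallback, absent from the WIP hunk above
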
@@ -48,6 +51,24 @@ def can_use_gptq_marlin(
     )


+def _can_use_machete(*, desc_act: bool, group_size: int, num_bits: int) -> bool:
+    return (
+        (not desc_act or group_size == -1)
+        and has_sm_9_0
+        and group_size in MACHETE_GROUP_SIZES
+        and num_bits in [4, 8]
+    )
+
+
+class GPTQMarlinKernel(Enum):
+    """
+    Enum for selecting the GPTQ Marlin kernel.
+    """
+
+    GPTQ_MARLIN = auto()
+    MACHETE = auto()
+
+
 class GPTQMarlinWeightsLoader(WeightsLoader):
     """
     Loader for using GPTQ- and AWQ-quantized weights with Marlin kernels.
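Note (not part of the diff): how the new predicate evaluates for a few hypothetical configurations, assuming it is called on a GPU where has_sm_9_0 is True.

# Evaluated against the _can_use_machete definition added above:
_can_use_machete(desc_act=False, group_size=128, num_bits=4)  # True: Machete eligible
_can_use_machete(desc_act=False, group_size=-1, num_bits=8)   # True: channelwise, 8-bit
_can_use_machete(desc_act=True, group_size=128, num_bits=4)   # False: act-order with groups
_can_use_machete(desc_act=False, group_size=32, num_bits=4)   # False: 32 not in MACHETE_GROUP_SIZES
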
@@ -69,9 +90,20 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
         self.quant_method = quant_method
         self.quantize = quantize
         self.sym = sym
+        self.kernel = (
+            GPTQMarlinKernel.MACHETE
+            if _can_use_machete(desc_act=desc_act, group_size=groupsize, num_bits=bits)
+            else GPTQMarlinKernel.GPTQ_MARLIN
+        )
+
+    def _log_kernel(self):
+        if self.kernel == GPTQMarlinKernel.MACHETE:
+            log_once(logger.info, "Using Machete kernels")
+        else:
+            log_once(logger.info, "Using GPTQ-Marlin kernels")

     def get_weights(self, weights: Weights, prefix: str):
-        log_once(logger.info, "Using GPTQ-Marlin kernels")
+        self._log_kernel()
         try:
             qweight = weights.get_tensor(f"{prefix}.qweight")
         except RuntimeError:
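Note (not part of the diff): the constructor now resolves the kernel once and _log_kernel reports it lazily on first use. A condensed sketch of that selection, detached from the class; select_kernel is a hypothetical helper, the enum and predicate are the ones added above.

def select_kernel(*, desc_act: bool, groupsize: int, bits: int) -> GPTQMarlinKernel:
    # Prefer Machete when the config and GPU allow it; otherwise keep GPTQ-Marlin.
    if _can_use_machete(desc_act=desc_act, group_size=groupsize, num_bits=bits):
        return GPTQMarlinKernel.MACHETE
    return GPTQMarlinKernel.GPTQ_MARLIN
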
@@ -98,6 +130,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
             bits=self.bits,
             desc_act=self.desc_act,
             groupsize=self.groupsize,
+            kernel=self.kernel,
             quant_method=self.quant_method,
             sym=self.sym,
             sharded_infeatures=False,
@@ -141,6 +174,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
             bits=self.bits,
             desc_act=self.desc_act,
             groupsize=self.groupsize,
+            kernel=self.kernel,
             quant_method=self.quant_method,
             sym=self.sym,
             sharded_infeatures=False,
@@ -183,13 +217,14 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
             bits=self.bits,
             desc_act=self.desc_act,
             groupsize=self.groupsize,
+            kernel=self.kernel,
             quant_method=self.quant_method,
             sym=self.sym,
             sharded_infeatures=False,
         )

     def get_weights_row(self, weights: Weights, prefix: str):
-        log_once(logger.info, "Using GPTQ-Marlin kernels")
+        self._log_kernel()
         try:
             qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
         except RuntimeError:
@@ -225,6 +260,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
             bits=self.bits,
             desc_act=self.desc_act,
             groupsize=self.groupsize,
+            kernel=self.kernel,
             quant_method=self.quant_method,
             sym=self.sym,
             sharded_infeatures=sharded_in_features,
@@ -255,8 +291,10 @@ class GPTQMarlinWeight(Weight):
     qzeros: torch.Tensor
     scales: torch.Tensor
     g_idx: torch.Tensor
+    kernel: GPTQMarlinKernel
     perm: torch.Tensor
     bits: int
+    groupsize: int
     is_full_k: bool

     def __post_init__(self):
@@ -284,6 +322,7 @@ def repack_gptq_for_marlin(
     quant_method: str,
     sym: bool,
     sharded_infeatures: bool,
+    kernel: GPTQMarlinKernel,
 ) -> GPTQMarlinWeight:
     """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels."""
     _check_marlin_kernels()
@@ -340,17 +379,23 @@ def repack_gptq_for_marlin(
                 out_features,
                 bits,
             )
+    elif kernel == GPTQMarlinKernel.MACHETE:
+        repacked = qweight
+        repacked.data = marlin_kernels.machete_prepack_B(
+            # Convert to column-major.
+            repacked.data.t().contiguous().t(),
+            bits,
+            qzeros is not None,
+        )
     else:
         repacked = marlin_kernels.gptq_marlin_repack(
             qweight, perm, in_features, out_features, bits
         )
+        scales = permute_scales(scales)

     if qzeros is None:
         qzeros = torch.empty(0, dtype=torch.int, device=qweight.device)

-    scales = permute_scales(scales)
-
     is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures)

     return GPTQMarlinWeight(
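Note (not part of the diff): the .t().contiguous().t() chain above is a common PyTorch idiom for producing a column-major (Fortran-ordered) copy of a 2-D tensor without changing its logical shape. A minimal illustration:

import torch

x = torch.arange(6).reshape(2, 3)    # row-major: strides (3, 1)
col_major = x.t().contiguous().t()   # same shape and values, strides (1, 2)

assert torch.equal(col_major, x)
assert col_major.stride() == (1, 2)  # columns are now contiguous in memory
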
@@ -358,8 +403,10 @@ def repack_gptq_for_marlin(
         qzeros=qzeros,
         scales=scales,
         g_idx=g_idx,
+        kernel=kernel,
         perm=perm,
         bits=bits,
+        groupsize=groupsize,
         is_full_k=is_full_k,
     )
@@ -385,7 +432,10 @@ class GPTQMarlinLinear(nn.Module):
         out_features = weight.scales.shape[1]
         _check_valid_shape(in_features=in_features, out_features=out_features)

+        self.kernel = weight.kernel
+
         self.bits = weight.bits
+        self.groupsize = weight.groupsize
         self.is_full_k = weight.is_full_k

         self.qweight = weight.qweight
@@ -406,6 +456,15 @@
         assert marlin_kernels is not None

         A_flat = A.view(-1, A.shape[-1])
-        C = marlin_kernels.gptq_marlin_gemm(
-            A_flat,
-            self.qweight,
+        if self.kernel == GPTQMarlinKernel.MACHETE:
+            C = marlin_kernels.machete_gemm(
+                A=A_flat,
+                B=self.qweight,
+                num_bits=self.bits,
+                scales=self.scales,
+                zeros=self.qzeros if self.qzeros.numel() > 0 else None,
+            )
+        else:
+            C = marlin_kernels.gptq_marlin_gemm(
+                A_flat,
+                self.qweight,
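Note (not part of the diff): both branches consume A_flat, i.e. the activations with all leading dimensions collapsed so the kernel sees a plain 2-D GEMM. A generic sketch of that flatten-and-restore pattern; the plain matmul stands in for machete_gemm / gptq_marlin_gemm, which are not reproduced here.

import torch

def flattened_matmul(A: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
    # Collapse (batch, seq, ..., in_features) into (tokens, in_features),
    # run a 2-D GEMM, then restore the leading dimensions.
    A_flat = A.view(-1, A.shape[-1])
    C = A_flat @ W
    return C.view(*A.shape[:-1], W.shape[-1])

A = torch.randn(2, 5, 16)  # (batch, seq, in_features)
W = torch.randn(16, 32)    # (in_features, out_features)
assert flattened_matmul(A, W).shape == (2, 5, 32)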