465 lines
14 KiB
Python
465 lines
14 KiB
Python
from dataclasses import dataclass
|
|
from typing import List, Optional, Union
|
|
|
|
import numpy
|
|
import torch
|
|
import torch.nn as nn
|
|
from loguru import logger
|
|
from text_generation_server.layers.marlin.util import (
|
|
_check_marlin_kernels,
|
|
marlin_zero_points,
|
|
permute_scales,
|
|
unpack_cols,
|
|
)
|
|
from text_generation_server.utils.import_utils import SYSTEM
|
|
from text_generation_server.utils.log import log_once
|
|
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
|
|
|
# The Marlin extension is optional: when it cannot be imported the name is
# bound to None so the rest of the module can test for its presence.
try:
    import marlin_kernels
except ImportError:
    marlin_kernels = None

# Record whether the current CUDA device reports a major compute capability
# of at least 8. Any failure while querying the device (for example when no
# CUDA device is usable) is treated as "not available".
try:
    major, _minor = torch.cuda.get_device_capability()
except Exception:
    has_sm_8_0 = False
else:
    has_sm_8_0 = major >= 8


# Bit widths accepted by the GPTQ-Marlin repacking code in this module.
GPTQ_MARLIN_BITS = [4, 8]
# Group sizes accepted by the GPTQ-Marlin repacking code in this module.
GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
# Factor used to recover the number of input features from the first
# dimension of a repacked qweight tensor.
MARLIN_TILE_SIZE = 16
|
|
|
|
|
|
def can_use_gptq_marlin(
    *, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool
) -> bool:
    """
    Tell whether a quantized checkpoint can be served with GPTQ-Marlin kernels.

    All of the following must hold: the system is CUDA, the `marlin_kernels`
    extension was imported, the device passed the `has_sm_8_0` check, both
    `quantize` and `quant_method` are "awq" or "gptq", `bits` and `groupsize`
    are among the supported values, and the quantization is either symmetric
    or AWQ.
    """
    supported_methods = {"awq", "gptq"}

    # Runtime prerequisites: platform, extension and device capability.
    if SYSTEM != "cuda" or marlin_kernels is None or not has_sm_8_0:
        return False

    # Both the requested and the checkpoint's quantization must be supported.
    if quantize not in supported_methods or quant_method not in supported_methods:
        return False

    if bits not in GPTQ_MARLIN_BITS or groupsize not in GPTQ_MARLIN_GROUP_SIZES:
        return False

    # We only support asymmetric quantization for AWQ.
    return sym or quant_method == "awq"
|
|
|
|
|
|
class GPTQMarlinWeightsLoader(WeightsLoader):
    """
    Loader for using GPTQ- and AWQ-quantized weights with Marlin kernels.

    Each `get_*` method reads the quantized tensors (`qweight`, `scales`,
    and, depending on the settings, `qzeros` and `g_idx`) for the requested
    prefix(es) and passes them to `repack_gptq_for_marlin`.

    `qzeros` is only read for asymmetric quantization (`sym` is false) and
    `g_idx` is only read when `quant_method` is not "awq".
    """

    def __init__(
        self,
        *,
        bits: int,
        desc_act: bool,
        groupsize: int,
        quant_method: str,
        quantize: str,
        sym: bool,
    ):
        """Store the quantization settings used by the loading methods."""
        self.bits = bits
        self.desc_act = desc_act
        self.groupsize = groupsize
        self.quant_method = quant_method
        self.quantize = quantize
        self.sym = sym

    def get_weights(self, weights: Weights, prefix: str):
        """
        Load the unsharded tensors stored under `prefix` and repack them.

        Raises:
            RuntimeError: if `{prefix}.qweight` cannot be loaded; the
                original error is attached as the cause.
        """
        log_once(logger.info, "Using GPTQ-Marlin kernels")
        try:
            qweight = weights.get_tensor(f"{prefix}.qweight")
        except RuntimeError as err:
            # Chain explicitly so the underlying loading failure stays visible.
            raise RuntimeError(
                f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
            ) from err

        if not self.sym:
            qzeros = weights.get_tensor(f"{prefix}.qzeros")
        else:
            qzeros = None

        if self.quant_method == "awq":
            g_idx = None
        else:
            g_idx = weights.get_tensor(f"{prefix}.g_idx")
        scales = weights.get_tensor(f"{prefix}.scales")

        return repack_gptq_for_marlin(
            qweight=qweight,
            scales=scales,
            qzeros=qzeros,
            g_idx=g_idx,
            bits=self.bits,
            desc_act=self.desc_act,
            groupsize=self.groupsize,
            quant_method=self.quant_method,
            sym=self.sym,
            sharded_infeatures=False,
        )

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        """
        Load tensors that pack several blocks along dim 1 and repack them.

        `qweight`, `scales` and (for asymmetric quantization) `qzeros` are
        read with `get_packed_sharded(..., dim=1, block_sizes=block_sizes)`;
        `g_idx` is read whole. The scales are cast to `weights.dtype`.

        Raises:
            RuntimeError: if `{prefix}.qweight` cannot be loaded; the
                original error is attached as the cause.
        """
        try:
            qweight = weights.get_packed_sharded(
                f"{prefix}.qweight", dim=1, block_sizes=block_sizes
            )
        except RuntimeError as err:
            raise RuntimeError(
                f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
            ) from err
        scales = weights.get_packed_sharded(
            f"{prefix}.scales", dim=1, block_sizes=block_sizes
        )
        scales = scales.to(dtype=weights.dtype)

        if not self.sym:
            qzeros = weights.get_packed_sharded(
                f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
            )
        else:
            qzeros = None

        if self.quant_method == "awq":
            g_idx = None
        else:
            g_idx = weights.get_tensor(f"{prefix}.g_idx")
        return repack_gptq_for_marlin(
            qweight=qweight,
            scales=scales,
            qzeros=qzeros,
            g_idx=g_idx,
            bits=self.bits,
            desc_act=self.desc_act,
            groupsize=self.groupsize,
            quant_method=self.quant_method,
            sym=self.sym,
            sharded_infeatures=False,
        )

    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
        """
        Load several tensors sharded along dim 1, concatenate and repack them.

        `dim` is accepted but not consulted: sharding and concatenation
        always use dim 1. For non-AWQ checkpoints, the `g_idx` tensors of all
        prefixes must be equal (checked with `torch.testing.assert_close`)
        and the first one is used.

        Raises:
            RuntimeError: if any `qweight` cannot be loaded or concatenated;
                the original error is attached as the cause.
        """
        try:
            qweight = torch.cat(
                [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
            )
        except RuntimeError as err:
            raise RuntimeError(
                f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
            ) from err

        scales = torch.cat(
            [weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
        )

        if not self.sym:
            qzeros = torch.cat(
                [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
            )
        else:
            qzeros = None

        if self.quant_method == "awq":
            g_idx = None
        else:
            w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
            for w2 in w[1:]:
                torch.testing.assert_close(w2, w[0])
            g_idx = w[0]

        return repack_gptq_for_marlin(
            qweight=qweight,
            scales=scales,
            qzeros=qzeros,
            g_idx=g_idx,
            bits=self.bits,
            desc_act=self.desc_act,
            groupsize=self.groupsize,
            quant_method=self.quant_method,
            sym=self.sym,
            sharded_infeatures=False,
        )

    def get_weights_row(self, weights: Weights, prefix: str):
        """
        Load tensors sharded along dim 0 and repack them.

        `qweight` and `g_idx` are always sharded along dim 0. `scales` and
        `qzeros` are read whole when `desc_act` is set or `groupsize` is -1,
        and sharded along dim 0 otherwise. `sharded_infeatures` is passed as
        true when the process group has more than one member.

        Raises:
            RuntimeError: if `{prefix}.qweight` cannot be loaded; the
                original error is attached as the cause.
        """
        log_once(logger.info, "Using GPTQ-Marlin kernels")
        try:
            qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
        except RuntimeError as err:
            raise RuntimeError(
                f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
            ) from err

        if not self.sym:
            if self.desc_act or self.groupsize == -1:
                qzeros = weights.get_tensor(f"{prefix}.qzeros")
            else:
                qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
        else:
            qzeros = None

        if self.quant_method == "awq":
            g_idx = None
        else:
            g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)

        if self.desc_act or self.groupsize == -1:
            scales = weights.get_tensor(f"{prefix}.scales")
        else:
            scales = weights.get_sharded(f"{prefix}.scales", dim=0)

        sharded_in_features = weights.process_group.size() > 1

        return repack_gptq_for_marlin(
            qweight=qweight,
            scales=scales,
            qzeros=qzeros,
            g_idx=g_idx,
            bits=self.bits,
            desc_act=self.desc_act,
            groupsize=self.groupsize,
            quant_method=self.quant_method,
            sym=self.sym,
            sharded_infeatures=sharded_in_features,
        )

    def _get_gptq_params(self, weights: Weights):
        """
        Override the loader settings from tensors stored in the checkpoint.

        Only acts when both `gptq_bits` and `gptq_groupsize` are present. In
        that case `bits` and `groupsize` are read from them, `desc_act` is
        set to False, `quant_method` to "gptq", and `sym` is read from
        `gptq_sym` when present (False otherwise).
        """
        if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
            self.bits = weights.get_tensor("gptq_bits").item()
            self.groupsize = weights.get_tensor("gptq_groupsize").item()
            self.desc_act = False
            # `server quantize` used asymmetric quantization unconditionally
            # before the `gptq_sym` setting tensor was added.
            self.sym = (
                weights.get_tensor("gptq_sym").item()
                if weights._has_tensor("gptq_sym")
                else False
            )
            self.quant_method = "gptq"
|
|
|
|
|
|
@dataclass
class GPTQMarlinWeight(Weight):
    """
    Bundle of tensors and settings produced by Marlin repacking.

    Construction checks the dtypes of `qweight`, `scales`, `g_idx` and
    `perm`; `get_linear` wraps the bundle in a `GPTQMarlinLinear`.
    """

    # Repacked quantized weight; must be int32.
    qweight: torch.Tensor
    # Zero points (no dtype check is performed here).
    qzeros: torch.Tensor
    # Scales; must be float16.
    scales: torch.Tensor
    # Group indices; must be int32.
    g_idx: torch.Tensor
    # Permutation; must be int32.
    perm: torch.Tensor
    # Quantization bit width.
    bits: int
    is_full_k: bool

    def __post_init__(self):
        """Assert that the tensors carry the expected dtypes."""
        int32, fp16 = torch.int32, torch.float16
        assert self.qweight.dtype == int32
        assert self.scales.dtype == fp16
        assert self.g_idx.dtype == int32
        assert self.perm.dtype == int32

    def get_linear(self, bias: torch.Tensor):
        """Build the linear layer that runs on these weights."""
        return GPTQMarlinLinear(weight=self, bias=bias)
|
|
|
|
|
|
def repack_gptq_for_marlin(
    *,
    qweight: torch.Tensor,
    qzeros: Optional[torch.Tensor],
    scales: torch.Tensor,
    g_idx: Optional[torch.Tensor],
    bits: int,
    desc_act: bool,
    groupsize: int,
    quant_method: str,
    sym: bool,
    sharded_infeatures: bool,
) -> GPTQMarlinWeight:
    """
    Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels.

    Args:
        qweight: packed quantized weight. For "awq" the packing is along
            dim 1, otherwise along dim 0.
        qzeros: packed zero points, or None. Only converted for "awq".
        scales: quantization scales; permuted with `permute_scales`.
        g_idx: group indices, or None. Only used when `desc_act` is set and
            `groupsize` is not -1.
        bits: quantization bit width, must be in `GPTQ_MARLIN_BITS`.
        desc_act: whether activation reordering is enabled.
        groupsize: quantization group size, must be in
            `GPTQ_MARLIN_GROUP_SIZES`.
        quant_method: "awq" selects the AWQ repacking path, anything else
            the GPTQ path.
        sym: whether quantization is symmetric.
        sharded_infeatures: whether the input features are sharded across
            processes.

    Returns:
        The repacked weight bundle.

    Raises:
        RuntimeError: for unsupported `bits`, `groupsize`, or asymmetric
            non-AWQ quantization.
        ValueError: if the number of input features is not divisible by the
            group size.
    """
    _check_marlin_kernels()
    assert marlin_kernels is not None

    if bits not in GPTQ_MARLIN_BITS:
        supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
        raise RuntimeError(
            f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}"
        )

    if groupsize not in GPTQ_MARLIN_GROUP_SIZES:
        supported_sizes = ", ".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES)
        raise RuntimeError(
            f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}"
        )
    if not (sym or quant_method == "awq"):
        raise RuntimeError(
            "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported."
        )

    log_once(logger.info, f"Converting {quant_method} model to Marlin packing format.")

    weights_per_int = 32 // bits
    in_features = qweight.shape[0]
    out_features = qweight.shape[1]

    # AWQ uses column packing, GPTQ uses row packing
    if quant_method == "awq":
        out_features *= weights_per_int
    else:
        in_features *= weights_per_int

    # A group size of -1 never fails this check (`x % -1` is always 0 in
    # Python); skipping it explicitly makes that intent visible.
    if groupsize != -1 and in_features % groupsize != 0:
        raise ValueError(
            f"Number of input features ({in_features}) not divisible by group size ({groupsize})"
        )

    if g_idx is not None and desc_act and groupsize != -1:
        perm = torch.argsort(g_idx).to(torch.int)
        g_idx = g_idx[perm]
    else:
        # The kernels receive empty tensors when no reordering is needed.
        perm = torch.empty(0, dtype=torch.int, device=qweight.device)
        g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)

    if quant_method == "awq":
        repacked = marlin_kernels.awq_marlin_repack(
            qweight, in_features, out_features, bits
        )
        if qzeros is not None:
            # By GPTQ/AWQ convention a group size of -1 means one group that
            # covers every input feature. Floor-dividing by -1 would produce
            # a negative group count, so handle it explicitly.
            num_groups = 1 if groupsize == -1 else in_features // groupsize
            qzeros = awq_to_marlin_zero_points(
                qzeros,
                num_groups,
                out_features,
                bits,
            )

    else:
        repacked = marlin_kernels.gptq_marlin_repack(
            qweight, perm, in_features, out_features, bits
        )

    if qzeros is None:
        qzeros = torch.empty(0, dtype=torch.int, device=qweight.device)

    scales = permute_scales(scales)

    is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures)

    return GPTQMarlinWeight(
        qweight=repacked,
        qzeros=qzeros,
        scales=scales,
        g_idx=g_idx,
        perm=perm,
        bits=bits,
        is_full_k=is_full_k,
    )
|
|
|
|
|
|
class GPTQMarlinLinear(nn.Module):
    """
    Linear layer backed by the GPTQ-Marlin kernels.

    It holds weights that were already converted to the Marlin layout and
    multiplies inputs with them through `marlin_kernels.gptq_marlin_gemm`.
    """

    def __init__(
        self,
        *,
        weight: GPTQMarlinWeight,
        bias: Optional[torch.Tensor],
    ):
        """
        Keep references to the repacked tensors and allocate the workspace.

        Raises:
            ValueError: if the derived (in_features, out_features) pair is
                rejected by `_check_valid_shape`.
        """
        super().__init__()

        _check_marlin_kernels()
        assert marlin_kernels is not None

        # Derive the logical shape from the repacked tensors and validate it.
        out_features = weight.scales.shape[1]
        in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE
        _check_valid_shape(in_features=in_features, out_features=out_features)

        self.bits = weight.bits
        self.is_full_k = weight.is_full_k

        self.qweight = weight.qweight
        self.qzeros = weight.qzeros
        self.scales = weight.scales
        self.g_idx = weight.g_idx
        self.perm = weight.perm
        # May be None, in which case no bias is added in `forward`.
        self.bias = bias

        # Integer scratch buffer handed to the kernel on every call.
        workspace_size = out_features // 64 * 16
        self.workspace = torch.zeros(
            workspace_size, dtype=torch.int, device=weight.qweight.device
        )

    def forward(self, A: torch.Tensor) -> torch.Tensor:
        """
        Multiply `A` with the quantized weight and add the bias, if any.

        Leading dimensions of `A` are flattened for the kernel call and
        restored on the result, whose last dimension is `scales.shape[1]`.
        """
        assert marlin_kernels is not None

        out_features = self.scales.shape[1]
        flat = A.view(-1, A.shape[-1])
        num_rows, in_features = flat.shape

        C = marlin_kernels.gptq_marlin_gemm(
            flat,
            self.qweight,
            self.scales,
            self.qzeros,
            self.g_idx,
            self.perm,
            self.workspace,
            self.bits,
            num_rows,
            out_features,
            in_features,
            self.is_full_k,
            # True only when non-empty zero points were supplied.
            self.qzeros.numel() > 0,
            # NOTE(review): meaning of this trailing flag is defined by the
            # kernel signature, which is not visible here.
            True,
        )
        C = C.reshape(A.shape[:-1] + (out_features,))

        if self.bias is not None:
            C += self.bias

        return C
|
|
|
|
|
|
def awq_to_marlin_zero_points(
    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    """
    Convert packed AWQ zero points to the Marlin zero-point layout.

    AWQ zero-points are quantized and packed on the column dim. In addition,
    the values are permuted based on the dequantizer. Both are undone here,
    then the Marlin permutation is applied and the values are packed back.

    Args:
        q_zp_packed: packed AWQ zero points.
        size_k: first dimension passed to `unpack_cols` and
            `marlin_zero_points`.
        size_n: second dimension (number of unpacked columns).
        num_bits: quantization bit width; must be 4 or 8.

    Returns:
        The zero points as returned by `marlin_zero_points`.

    Raises:
        ValueError: if `num_bits` is neither 4 nor 8.
    """
    # Validate the bit width first so an unsupported value fails with a
    # precise error before any unpacking work is attempted.
    if num_bits == 4:
        interleave = [0, 2, 4, 6, 1, 3, 5, 7]
    elif num_bits == 8:
        interleave = [0, 2, 1, 3]
    else:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    # Undo interleaving (use argsort(..) to get inverse perm)
    undo_interleave = numpy.argsort(numpy.array(interleave))

    q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)

    q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
    q_zp = q_zp.reshape((-1, size_n)).contiguous()

    marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
    return marlin_zp
|
|
|
|
|
|
def _check_valid_shape(in_features: int, out_features: int):
    """
    Reject weight shapes that match neither accepted divisibility pattern.

    A shape is accepted when `in_features` is a multiple of 128 and
    `out_features` a multiple of 64, or when `in_features` is a multiple of
    64 and `out_features` a multiple of 128.

    Raises:
        ValueError: if neither pattern matches.
    """
    fits_128_by_64 = in_features % 128 == 0 and out_features % 64 == 0
    fits_64_by_128 = in_features % 64 == 0 and out_features % 128 == 0
    if fits_128_by_64 or fits_64_by_128:
        return

    raise ValueError(
        f"The GPTQ Marlin kernel does not have a valid thread configuration for weight matrix with shape ({out_features}, {in_features})."
        " The shape elements must be divisible by (128, 64) or (64, 128)."
    )
|