from dataclasses import dataclass
from typing import Optional, Tuple, List

import torch
import torch.nn as nn

from text_generation_server.utils.import_utils import SYSTEM

try:
    import marlin_kernels
except ImportError:
    marlin_kernels = None

try:
    major, _minor = torch.cuda.get_device_capability()
    has_sm_8_0 = major >= 8
except Exception:
    has_sm_8_0 = False

GPTQ_MARLIN_BITS = [4, 8]
GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
# The Marlin kernels operate on 16-row tiles of the weight matrix; the first
# dimension of a packed weight is measured in these tiles, so
# in_features == qweight.shape[0] * MARLIN_TILE_SIZE.
MARLIN_TILE_SIZE = 16


def _check_marlin_kernels():
    if not (SYSTEM == "cuda" and has_sm_8_0):
        raise NotImplementedError(
            "Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later."
        )

    if marlin_kernels is None:
        raise NotImplementedError(
            "marlin is not installed, install it with: pip install server/marlin"
        )


def _check_valid_shape(in_features: int, out_features: int):
    if (in_features % 128 != 0 or out_features % 64 != 0) and (
        in_features % 64 != 0 or out_features % 128 != 0
    ):
        raise ValueError(
            f"The GPTQ Marlin kernel does not have a valid thread configuration for weight matrix with shape ({out_features}, {in_features})."
            " The shape elements must be divisible by (128, 64) or (64, 128)."
        )


# https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54
def _get_perms() -> Tuple[List[int], List[int]]:
    # Permutation applied to the scales of group-quantized weights.
    scale_perm = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    # Permutation applied to the scales of per-channel (groupsize == -1) weights.
    scale_perm_single = []
    for i in range(4):
        scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return scale_perm, scale_perm_single


_scale_perm, _scale_perm_single = _get_perms()


def permute_scales(scales: torch.Tensor):
    out_features = scales.shape[1]
    if scales.shape[0] == 1:
        scales = scales.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
    else:
        scales = scales.reshape((-1, len(_scale_perm)))[:, _scale_perm]
    return scales.reshape((-1, out_features)).contiguous()
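
# Shape sketch (illustrative only, not part of the original module): for grouped
# quantization the scales arrive as (in_features // groupsize, out_features), for
# per-channel quantization (groupsize == -1) as (1, out_features). The permutation
# only reorders elements, so the output shape matches the input shape, e.g.:
#
#   scales = torch.rand(32, 4096, dtype=torch.float16)  # 32 groups, 4096 out features
#   assert permute_scales(scales).shape == scales.shape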


@dataclass
class GPTQMarlinWeight:
    """
    Repacked GPTQ Marlin weights.

    Attributes:
        qweight (torch.Tensor): GPTQ weights repacked into the Marlin layout.
        scales (torch.Tensor): float16 quantization scales, permuted for Marlin.
        g_idx (torch.Tensor): int32 group indices (empty unless act-order is used).
        perm (torch.Tensor): int32 row permutation for act-order (empty otherwise).
        bits (int): number of quantization bits (4 or 8).
        is_full_k (bool): True unless act-order is used and the input features are sharded.
    """

    qweight: torch.Tensor
    scales: torch.Tensor
    g_idx: torch.Tensor
    perm: torch.Tensor
    bits: int
    is_full_k: bool

    def __post_init__(self):
        assert self.qweight.dtype == torch.int32
        assert self.scales.dtype == torch.float16
        assert self.g_idx.dtype == torch.int32
        assert self.perm.dtype == torch.int32


def repack_gptq_for_marlin(
    *,
    qweight: torch.Tensor,
    scales: torch.Tensor,
    g_idx: torch.Tensor,
    bits: int,
    desc_act: bool,
    groupsize: int,
    sym: bool,
    sharded_infeatures: bool,
) -> GPTQMarlinWeight:
    """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels."""
    _check_marlin_kernels()
    assert marlin_kernels is not None

    if bits not in GPTQ_MARLIN_BITS:
        supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
        raise RuntimeError(
            f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}"
        )

    if groupsize not in GPTQ_MARLIN_GROUP_SIZES:
        supported_sizes = ", ".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES)
        raise RuntimeError(
            f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}"
        )
    if not sym:
        raise RuntimeError(
            "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported."
        )

    weights_per_int = 32 // bits
    in_features = qweight.shape[0] * weights_per_int
    out_features = qweight.shape[1]

    if in_features % groupsize != 0:
        raise ValueError(
            f"Number of input features ({in_features}) not divisible by group size ({groupsize})"
        )

    if desc_act and groupsize != -1:
        perm = torch.argsort(g_idx).to(torch.int)
        g_idx = g_idx[perm]
    else:
        perm = torch.empty(0, dtype=torch.int, device=qweight.device)
        g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)

    repacked = marlin_kernels.gptq_marlin_repack(
        qweight, perm, in_features, out_features, bits
    )

    scales = permute_scales(scales)

    is_full_k = not (desc_act and sharded_infeatures)

    return GPTQMarlinWeight(
        qweight=repacked,
        scales=scales,
        g_idx=g_idx,
        perm=perm,
        bits=bits,
        is_full_k=is_full_k,
    )
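
# Hypothetical usage sketch (tensor names below are placeholders, not part of this
# module): repack the tensors of an existing GPTQ linear layer and wrap them in the
# GPTQ-Marlin layer defined further down in this file.
#
#   marlin_weight = repack_gptq_for_marlin(
#       qweight=gptq_qweight,     # int32-packed GPTQ weights
#       scales=gptq_scales,       # float16 scales
#       g_idx=gptq_g_idx,         # int32 group indices
#       bits=4,
#       desc_act=False,
#       groupsize=128,
#       sym=True,
#       sharded_infeatures=False,
#   )
#   layer = GPTQMarlinLinear(weight=marlin_weight, bias=None)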


class GPTQMarlinLinear(nn.Module):
    """
    Linear layer for GPTQ weights that were converted for the GPTQ-Marlin
    kernels.
    """

    def __init__(
        self,
        *,
        weight: GPTQMarlinWeight,
        bias: Optional[torch.Tensor],
    ):
        super().__init__()

        _check_marlin_kernels()
        assert marlin_kernels is not None

        in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE
        out_features = weight.scales.shape[1]
        _check_valid_shape(in_features=in_features, out_features=out_features)

        self.bits = weight.bits
        self.is_full_k = weight.is_full_k

        self.register_buffer("qweight", weight.qweight)
        self.register_buffer("scales", weight.scales)
        self.register_buffer("g_idx", weight.g_idx)
        self.register_buffer("perm", weight.perm)
        if bias is not None:
            self.register_buffer("bias", bias)
        else:
            self.bias = None

        self.workspace = torch.zeros(
            out_features // 64 * 16, dtype=torch.int, device=weight.qweight.device
        )

    def forward(self, A: torch.Tensor) -> torch.Tensor:
        assert marlin_kernels is not None

        A_flat = A.view(-1, A.shape[-1])
        C = marlin_kernels.gptq_marlin_gemm(
            A_flat,
            self.qweight,
            self.scales,
            self.g_idx,
            self.perm,
            self.workspace,
            self.bits,
            A_flat.shape[0],
            self.scales.shape[1],
            A_flat.shape[1],
            self.is_full_k,
        )
        C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))

        if self.bias is not None:
            C += self.bias

        return C
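
# Note (illustrative, assumed call pattern): the forward pass flattens all leading
# dimensions before the kernel call and restores them afterwards, so both 2D and 3D
# activations work, e.g.:
#
#   y = layer(torch.rand(4, 128, in_features, dtype=torch.float16, device="cuda"))
#   # y.shape == (4, 128, out_features)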


@dataclass
class MarlinWeight:
    """
    Marlin weights.

    Attributes:
        B (torch.Tensor): int4-quantized weights packed into int32.
        s (torch.Tensor): float16 scales.
    """

    B: torch.Tensor
    s: torch.Tensor

    def __post_init__(self):
        assert self.B.dtype == torch.int32
        assert self.s.dtype == torch.float16


class MarlinLinear(nn.Module):
    def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]):
        super().__init__()

        _check_marlin_kernels()
        assert marlin_kernels is not None

        in_features = weight.B.shape[0] * MARLIN_TILE_SIZE
        out_features = weight.s.shape[1]
        assert (
            in_features % 128 == 0
        ), f"Number of input features ({in_features}) not divisible by 128"
        assert (
            out_features % 256 == 0
        ), f"Number of output features ({out_features}) not divisible by 256"

        groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
        assert groupsize in {
            -1,
            128,
        }, f"Group size must be -1 or 128, was {groupsize}"

        self.register_buffer("B", weight.B)
        self.register_buffer("s", weight.s)
        if bias is not None:
            self.register_buffer("bias", bias)
        else:
            self.bias = None

        self.workspace = torch.zeros(
            out_features // 64 * 16, dtype=torch.int, device=weight.B.device
        )

    def forward(self, A: torch.Tensor) -> torch.Tensor:
        assert marlin_kernels is not None

        # Flatten leading dimensions so the kernel always sees a 2D activation
        # matrix, mirroring GPTQMarlinLinear.forward.
        A_flat = A.view(-1, A.shape[-1])
        C = marlin_kernels.marlin_gemm(
            A_flat,
            self.B,
            self.s,
            self.workspace,
            A_flat.shape[0],
            self.s.shape[1],
            A_flat.shape[1],
        )
        C = C.reshape(A.shape[:-1] + (self.s.shape[1],))

        if self.bias is not None:
            C += self.bias

        return C
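
# Hypothetical usage sketch (tensor names are placeholders): wrap pre-packed
# Marlin weights and run a matmul through the layer.
#
#   weight = MarlinWeight(B=packed_qweight, s=packed_scales)  # int32 / float16
#   layer = MarlinLinear(weight=weight, bias=None)
#   y = layer(x)  # x: (tokens, in_features) float16 activations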