fix per-column quantization

Felix Marty 2023-07-19 17:55:41 +00:00
parent edfbfdfb3f
commit 6bf7090ecd
2 changed files with 15 additions and 8 deletions

@@ -6,6 +6,8 @@ from torch.cuda.amp import custom_bwd, custom_fwd
import torch
from loguru import logger
try:
    from custom_kernels.exllama import make_q4, q4_matmul
except Exception as e:
@@ -422,8 +424,9 @@ class Ex4bitLinear:
        self.groupsize = None
        if self.qzeros.shape[0] > 1:
            self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
-       assert groupsize == self.groupsize
+       if self.groupsize is not None:
+           assert groupsize == self.groupsize
        # Handle act-order matrix
        if self.g_idx is not None:
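The guarded assert above is what lets per-column quantized checkpoints load: when qzeros has a single row there is no per-group structure, so self.groupsize stays None and there is nothing to compare against the configured group size. Below is a minimal sketch of the shape arithmetic, assuming the usual GPTQ packing of eight 4-bit values per int32 row; the helper name infer_groupsize is made up for illustration.

import torch

def infer_groupsize(qweight: torch.Tensor, qzeros: torch.Tensor):
    # qweight is packed along the input dimension: 8 x 4-bit values per int32.
    in_features = qweight.shape[0] * 8
    if qzeros.shape[0] == 1:
        # A single row of zero points means per-column quantization: one group
        # spans the whole input dimension, so no group size can be inferred.
        return None
    # Otherwise each row of qzeros/scales covers one group of input features.
    return in_features // qzeros.shape[0]

For a per-column checkpoint this returns None, which is exactly the case the new "if self.groupsize is not None" guard skips.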

@@ -152,6 +152,8 @@ class Weights:
            except RuntimeError:
                raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
+           bits, groupsize = self.get_gptq_qparams()
            if use_triton_kernel:
                # The triton kernel reorders the scales/zero points instead of the weight/activation.
                # Thus, each rank needs the full qzeros/scales.
@@ -159,10 +161,14 @@
                scales = self.get_tensor(f"{prefix}.scales")
                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
            else:
-               # Exllama reorders the weights in advance and the activations on the fly, thus
-               # the scales and zero-points do not need to be reordered
-               qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
-               scales = self.get_sharded(f"{prefix}.scales", dim=0)
+               if groupsize >= 16:
+                   # Exllama reorders the weights in advance and the activations on the fly, thus
+                   # the scales and zero-points do not need to be reordered.
+                   qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
+                   scales = self.get_sharded(f"{prefix}.scales", dim=0)
+               else:
+                   qzeros = self.get_tensor(f"{prefix}.qzeros")
+                   scales = self.get_tensor(f"{prefix}.scales")
                # For tp > 1, at this point we know we do not use act-order
                if self.process_group.size() == 1:
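Taken together, the use_triton_kernel branch and the new groupsize check decide, for each tensor-parallel rank, whether the quantization metadata can be split or must be replicated. The following is a hedged sketch of that decision, with weights standing in for the Weights instance and the function name load_gptq_metadata invented for illustration.

def load_gptq_metadata(weights, prefix: str, groupsize: int, use_triton_kernel: bool):
    if use_triton_kernel:
        # The triton kernel reorders scales/zero points instead of the
        # weight/activation, so every rank needs the full copies.
        qzeros = weights.get_tensor(f"{prefix}.qzeros")
        scales = weights.get_tensor(f"{prefix}.scales")
    elif groupsize >= 16:
        # Exllama reorders the weights in advance and the activations on the
        # fly, so the per-group scales/zero points can be sharded along dim 0.
        qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
        scales = weights.get_sharded(f"{prefix}.scales", dim=0)
    else:
        # Per-column quantization: there is effectively a single group, so the
        # tensors are tiny and are simply replicated on every rank.
        qzeros = weights.get_tensor(f"{prefix}.qzeros")
        scales = weights.get_tensor(f"{prefix}.scales")
    return qzeros, scales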
@@ -170,8 +176,6 @@
                else:
                    g_idx = None
-           bits, groupsize = self.get_gptq_qparams()
            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel)
        else:
            weight = self.get_sharded(f"{prefix}.weight", dim=1)
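Moving bits, groupsize = self.get_gptq_qparams() above the branch (and dropping the later call) is what makes the groupsize >= 16 check possible: the group size is now known before qzeros and scales are loaded. To illustrate why per-column checkpoints end up in the replicate branch, here is a rough sketch of the packed shapes one would expect for a 4-bit GPTQ layer; the helper and the example dimensions are made up, and -1 follows the common GPTQ convention for "no grouping".

def gptq_packed_shapes(in_features: int, out_features: int, group_size: int):
    # group_size == -1 is commonly used to mean per-column quantization,
    # i.e. a single group spanning the whole input dimension.
    n_groups = 1 if group_size in (-1, in_features) else in_features // group_size
    return {
        "qweight": (in_features // 8, out_features),  # 8 x 4-bit values per int32
        "qzeros": (n_groups, out_features // 8),
        "scales": (n_groups, out_features),
    }

print(gptq_packed_shapes(4096, 4096, 128))  # grouped: 32 rows of qzeros/scales, worth sharding
print(gptq_packed_shapes(4096, 4096, -1))   # per-column: a single row, cheap to replicate on every rank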