fix per-column quantization
This commit is contained in:
parent
edfbfdfb3f
commit
6bf7090ecd
|
@ -6,6 +6,8 @@ from torch.cuda.amp import custom_bwd, custom_fwd
|
|||
|
||||
import torch
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
from custom_kernels.exllama import make_q4, q4_matmul
|
||||
except Exception as e:
|
||||
|
@ -422,8 +424,9 @@ class Ex4bitLinear:
|
|||
self.groupsize = None
|
||||
if self.qzeros.shape[0] > 1:
|
||||
self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
|
||||
|
||||
assert groupsize == self.groupsize
|
||||
|
||||
if self.groupsize is not None:
|
||||
assert groupsize == self.groupsize
|
||||
|
||||
# Handle act-order matrix
|
||||
if self.g_idx is not None:
|
||||
|
|
|
@ -152,6 +152,8 @@ class Weights:
|
|||
except RuntimeError:
|
||||
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
|
||||
|
||||
bits, groupsize = self.get_gptq_qparams()
|
||||
|
||||
if use_triton_kernel:
|
||||
# The triton kernel reorders the scales/zero points instead of the weight/activation.
|
||||
# Thus, each rank needs the full qzeros/scales.
|
||||
|
@ -159,10 +161,14 @@ class Weights:
|
|||
scales = self.get_tensor(f"{prefix}.scales")
|
||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
else:
|
||||
# Exllama reorders the weights in advance and the activations on the fly, thus
|
||||
# the scales and zero-points do not need to be reordered
|
||||
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
||||
if groupsize >= 16:
|
||||
# Exllama reorders the weights in advance and the activations on the fly, thus
|
||||
# the scales and zero-points do not need to be reordered.
|
||||
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
||||
else:
|
||||
qzeros = self.get_tensor(f"{prefix}.qzeros")
|
||||
scales = self.get_tensor(f"{prefix}.scales")
|
||||
|
||||
# For tp > 1, at this point we know we do not use act-order
|
||||
if self.process_group.size() == 1:
|
||||
|
@ -170,8 +176,6 @@ class Weights:
|
|||
else:
|
||||
g_idx = None
|
||||
|
||||
bits, groupsize = self.get_gptq_qparams()
|
||||
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel)
|
||||
else:
|
||||
weight = self.get_sharded(f"{prefix}.weight", dim=1)
|
||||
|
|
Loading…
Reference in New Issue