Handle GPTQ-Marlin loading in `GPTQMarlinWeightLoader` (#2300)
The `GPTWeightLoader` was structured like this in pseudocode: if marlin: Set up tensors in a way that GPTQ-Marlin expects else: Set up tensors in a way that ExLlama/GPTQ/AWQ expect However, the GPT-Marlin implementation details should really be in the `marlin` module. So move the former part out to a separate `GPTQMarlinWeightsLoader`.
This commit is contained in:
parent
2b19d671b4
commit
34f7dcfd80
|
@ -124,50 +124,7 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||
self.sym = sym
|
||||
|
||||
def get_weights(self, weights: Weights, prefix: str):
|
||||
from text_generation_server.layers.marlin import (
|
||||
can_use_gptq_marlin,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
self._get_gptq_params(weights)
|
||||
if can_use_gptq_marlin(
|
||||
bits=self.bits,
|
||||
groupsize=self.groupsize,
|
||||
quant_method=self.quant_method,
|
||||
quantize=self.quantize,
|
||||
sym=self.sym,
|
||||
):
|
||||
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
||||
try:
|
||||
qweight = weights.get_tensor(f"{prefix}.qweight")
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
if not self.sym:
|
||||
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
||||
else:
|
||||
qzeros = None
|
||||
|
||||
if self.quant_method == "awq":
|
||||
g_idx = None
|
||||
else:
|
||||
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
||||
scales = weights.get_tensor(f"{prefix}.scales")
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
qzeros=qzeros,
|
||||
g_idx=g_idx,
|
||||
bits=self.bits,
|
||||
desc_act=self.desc_act,
|
||||
groupsize=self.groupsize,
|
||||
quant_method=self.quant_method,
|
||||
sym=self.sym,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
|
||||
use_exllama = True
|
||||
if self.bits != 4:
|
||||
|
@ -248,11 +205,6 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
from text_generation_server.layers.marlin import (
|
||||
can_use_gptq_marlin,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
try:
|
||||
qweight = weights.get_packed_sharded(
|
||||
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
||||
|
@ -267,36 +219,6 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||
scales = scales.to(dtype=weights.dtype)
|
||||
|
||||
self._get_gptq_params(weights)
|
||||
if can_use_gptq_marlin(
|
||||
bits=self.bits,
|
||||
groupsize=self.groupsize,
|
||||
quant_method=self.quant_method,
|
||||
quantize=self.quantize,
|
||||
sym=self.sym,
|
||||
):
|
||||
if not self.sym:
|
||||
qzeros = weights.get_packed_sharded(
|
||||
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
else:
|
||||
qzeros = None
|
||||
|
||||
if self.quant_method == "awq":
|
||||
g_idx = None
|
||||
else:
|
||||
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
qzeros=qzeros,
|
||||
g_idx=g_idx,
|
||||
bits=self.bits,
|
||||
desc_act=self.desc_act,
|
||||
groupsize=self.groupsize,
|
||||
quant_method=self.quant_method,
|
||||
sym=self.sym,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
|
||||
qzeros = weights.get_packed_sharded(
|
||||
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
||||
|
@ -334,11 +256,6 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||
)
|
||||
|
||||
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||
from text_generation_server.layers.marlin import (
|
||||
can_use_gptq_marlin,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
try:
|
||||
qweight = torch.cat(
|
||||
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
|
||||
|
@ -353,41 +270,6 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||
)
|
||||
|
||||
self._get_gptq_params(weights)
|
||||
if can_use_gptq_marlin(
|
||||
bits=self.bits,
|
||||
groupsize=self.groupsize,
|
||||
quant_method=self.quant_method,
|
||||
quantize=self.quantize,
|
||||
sym=self.sym,
|
||||
):
|
||||
|
||||
if not self.sym:
|
||||
qzeros = torch.cat(
|
||||
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
else:
|
||||
qzeros = None
|
||||
|
||||
if self.quant_method == "awq":
|
||||
g_idx = None
|
||||
else:
|
||||
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||
for w2 in w[1:]:
|
||||
torch.testing.assert_close(w2, w[0])
|
||||
g_idx = w[0]
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
qzeros=qzeros,
|
||||
g_idx=g_idx,
|
||||
bits=self.bits,
|
||||
desc_act=self.desc_act,
|
||||
groupsize=self.groupsize,
|
||||
quant_method=self.quant_method,
|
||||
sym=self.sym,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
|
||||
qzeros = torch.cat(
|
||||
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
||||
|
@ -441,59 +323,7 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||
)
|
||||
|
||||
def get_weights_row(self, weights: Weights, prefix: str):
|
||||
from text_generation_server.layers.marlin import (
|
||||
can_use_gptq_marlin,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
self._get_gptq_params(weights)
|
||||
if can_use_gptq_marlin(
|
||||
bits=self.bits,
|
||||
groupsize=self.groupsize,
|
||||
quant_method=self.quant_method,
|
||||
quantize=self.quantize,
|
||||
sym=self.sym,
|
||||
):
|
||||
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
||||
try:
|
||||
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
if not self.sym:
|
||||
if self.desc_act or self.groupsize == -1:
|
||||
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
||||
else:
|
||||
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||
else:
|
||||
qzeros = None
|
||||
|
||||
if self.quant_method == "awq":
|
||||
g_idx = None
|
||||
else:
|
||||
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
|
||||
if self.desc_act or self.groupsize == -1:
|
||||
scales = weights.get_tensor(f"{prefix}.scales")
|
||||
else:
|
||||
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
|
||||
|
||||
sharded_in_features = weights.process_group.size() > 1
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
qzeros=qzeros,
|
||||
g_idx=g_idx,
|
||||
bits=self.bits,
|
||||
desc_act=self.desc_act,
|
||||
groupsize=self.groupsize,
|
||||
quant_method=self.quant_method,
|
||||
sym=self.sym,
|
||||
sharded_infeatures=sharded_in_features,
|
||||
)
|
||||
|
||||
use_exllama = True
|
||||
if self.bits != 4:
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear
|
||||
from text_generation_server.layers.marlin.gptq import (
|
||||
GPTQMarlinLinear,
|
||||
GPTQMarlinWeight,
|
||||
GPTQMarlinWeightsLoader,
|
||||
can_use_gptq_marlin,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
@ -9,8 +8,7 @@ from text_generation_server.layers.marlin.marlin import MarlinWeightsLoader
|
|||
|
||||
__all__ = [
|
||||
"GPTQMarlinFP8Linear",
|
||||
"GPTQMarlinLinear",
|
||||
"GPTQMarlinWeight",
|
||||
"GPTQMarlinWeightsLoader",
|
||||
"MarlinWeightsLoader",
|
||||
"can_use_gptq_marlin",
|
||||
"repack_gptq_for_marlin",
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
|
@ -13,7 +13,7 @@ from text_generation_server.layers.marlin.util import (
|
|||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.log import log_once
|
||||
from text_generation_server.utils.weights import Weight
|
||||
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
||||
|
||||
try:
|
||||
import marlin_kernels
|
||||
|
@ -48,6 +48,204 @@ def can_use_gptq_marlin(
|
|||
)
|
||||
|
||||
|
||||
class GPTQMarlinWeightsLoader(WeightsLoader):
|
||||
"""
|
||||
Loader for using GPTQ- and AWQ-quantized weights with Marlin kernels.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
bits: int,
|
||||
desc_act: bool,
|
||||
groupsize: int,
|
||||
quant_method: str,
|
||||
quantize: str,
|
||||
sym: bool,
|
||||
):
|
||||
self.bits = bits
|
||||
self.desc_act = desc_act
|
||||
self.groupsize = groupsize
|
||||
self.quant_method = quant_method
|
||||
self.quantize = quantize
|
||||
self.sym = sym
|
||||
|
||||
def get_weights(self, weights: Weights, prefix: str):
|
||||
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
||||
try:
|
||||
qweight = weights.get_tensor(f"{prefix}.qweight")
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
if not self.sym:
|
||||
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
||||
else:
|
||||
qzeros = None
|
||||
|
||||
if self.quant_method == "awq":
|
||||
g_idx = None
|
||||
else:
|
||||
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
||||
scales = weights.get_tensor(f"{prefix}.scales")
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
qzeros=qzeros,
|
||||
g_idx=g_idx,
|
||||
bits=self.bits,
|
||||
desc_act=self.desc_act,
|
||||
groupsize=self.groupsize,
|
||||
quant_method=self.quant_method,
|
||||
sym=self.sym,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
|
||||
def get_weights_col_packed(
|
||||
self,
|
||||
weights: Weights,
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
|
||||
try:
|
||||
qweight = weights.get_packed_sharded(
|
||||
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
|
||||
)
|
||||
scales = weights.get_packed_sharded(
|
||||
f"{prefix}.scales", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
scales = scales.to(dtype=weights.dtype)
|
||||
|
||||
if not self.sym:
|
||||
qzeros = weights.get_packed_sharded(
|
||||
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
else:
|
||||
qzeros = None
|
||||
|
||||
if self.quant_method == "awq":
|
||||
g_idx = None
|
||||
else:
|
||||
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
qzeros=qzeros,
|
||||
g_idx=g_idx,
|
||||
bits=self.bits,
|
||||
desc_act=self.desc_act,
|
||||
groupsize=self.groupsize,
|
||||
quant_method=self.quant_method,
|
||||
sym=self.sym,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
|
||||
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||
try:
|
||||
qweight = torch.cat(
|
||||
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
scales = torch.cat(
|
||||
[weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
if not self.sym:
|
||||
qzeros = torch.cat(
|
||||
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
else:
|
||||
qzeros = None
|
||||
|
||||
if self.quant_method == "awq":
|
||||
g_idx = None
|
||||
else:
|
||||
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||
for w2 in w[1:]:
|
||||
torch.testing.assert_close(w2, w[0])
|
||||
g_idx = w[0]
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
qzeros=qzeros,
|
||||
g_idx=g_idx,
|
||||
bits=self.bits,
|
||||
desc_act=self.desc_act,
|
||||
groupsize=self.groupsize,
|
||||
quant_method=self.quant_method,
|
||||
sym=self.sym,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
|
||||
def get_weights_row(self, weights: Weights, prefix: str):
|
||||
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
||||
try:
|
||||
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
if not self.sym:
|
||||
if self.desc_act or self.groupsize == -1:
|
||||
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
||||
else:
|
||||
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||
else:
|
||||
qzeros = None
|
||||
|
||||
if self.quant_method == "awq":
|
||||
g_idx = None
|
||||
else:
|
||||
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
|
||||
if self.desc_act or self.groupsize == -1:
|
||||
scales = weights.get_tensor(f"{prefix}.scales")
|
||||
else:
|
||||
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
|
||||
|
||||
sharded_in_features = weights.process_group.size() > 1
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
qzeros=qzeros,
|
||||
g_idx=g_idx,
|
||||
bits=self.bits,
|
||||
desc_act=self.desc_act,
|
||||
groupsize=self.groupsize,
|
||||
quant_method=self.quant_method,
|
||||
sym=self.sym,
|
||||
sharded_infeatures=sharded_in_features,
|
||||
)
|
||||
|
||||
def _get_gptq_params(self, weights: Weights):
|
||||
if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
|
||||
self.bits = weights.get_tensor("gptq_bits").item()
|
||||
self.groupsize = weights.get_tensor("gptq_groupsize").item()
|
||||
self.desc_act = False
|
||||
# `server quantize` used asymmetric quantization unconditionally
|
||||
# before the `gptq_sym` setting tensor was added.
|
||||
self.sym = (
|
||||
weights.get_tensor("gptq_sym").item()
|
||||
if weights._has_tensor("gptq_sym")
|
||||
else False
|
||||
)
|
||||
self.quant_method = "gptq"
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPTQMarlinWeight(Weight):
|
||||
"""
|
||||
|
|
|
@ -4,6 +4,7 @@ from dataclasses import dataclass
|
|||
from typing import Optional
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from text_generation_server.layers.marlin.gptq import can_use_gptq_marlin
|
||||
from text_generation_server.utils.weights import (
|
||||
DefaultWeightsLoader,
|
||||
WeightsLoader,
|
||||
|
@ -128,6 +129,24 @@ def get_loader(
|
|||
f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
|
||||
)
|
||||
|
||||
if can_use_gptq_marlin(
|
||||
bits=quantizer_config.bits,
|
||||
groupsize=quantizer_config.groupsize,
|
||||
quant_method=quantizer_config.quant_method,
|
||||
quantize=quantize,
|
||||
sym=quantizer_config.sym,
|
||||
):
|
||||
from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader
|
||||
|
||||
return GPTQMarlinWeightsLoader(
|
||||
bits=quantizer_config.bits,
|
||||
desc_act=quantizer_config.desc_act,
|
||||
groupsize=quantizer_config.groupsize,
|
||||
quant_method=quantizer_config.quant_method,
|
||||
quantize=quantize,
|
||||
sym=quantizer_config.sym,
|
||||
)
|
||||
else:
|
||||
return GPTQWeightsLoader(
|
||||
bits=quantizer_config.bits,
|
||||
desc_act=quantizer_config.desc_act,
|
||||
|
|
Loading…
Reference in New Issue