Handle GPTQ-Marlin loading in `GPTQMarlinWeightLoader` (#2300)
The `GPTWeightLoader` was structured like this in pseudocode: if marlin: Set up tensors in a way that GPTQ-Marlin expects else: Set up tensors in a way that ExLlama/GPTQ/AWQ expect However, the GPT-Marlin implementation details should really be in the `marlin` module. So move the former part out to a separate `GPTQMarlinWeightsLoader`.
This commit is contained in:
parent
2b19d671b4
commit
34f7dcfd80
|
@ -124,50 +124,7 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||||
self.sym = sym
|
self.sym = sym
|
||||||
|
|
||||||
def get_weights(self, weights: Weights, prefix: str):
|
def get_weights(self, weights: Weights, prefix: str):
|
||||||
from text_generation_server.layers.marlin import (
|
|
||||||
can_use_gptq_marlin,
|
|
||||||
repack_gptq_for_marlin,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._get_gptq_params(weights)
|
self._get_gptq_params(weights)
|
||||||
if can_use_gptq_marlin(
|
|
||||||
bits=self.bits,
|
|
||||||
groupsize=self.groupsize,
|
|
||||||
quant_method=self.quant_method,
|
|
||||||
quantize=self.quantize,
|
|
||||||
sym=self.sym,
|
|
||||||
):
|
|
||||||
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
|
||||||
try:
|
|
||||||
qweight = weights.get_tensor(f"{prefix}.qweight")
|
|
||||||
except RuntimeError:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.sym:
|
|
||||||
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
|
||||||
else:
|
|
||||||
qzeros = None
|
|
||||||
|
|
||||||
if self.quant_method == "awq":
|
|
||||||
g_idx = None
|
|
||||||
else:
|
|
||||||
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
|
||||||
scales = weights.get_tensor(f"{prefix}.scales")
|
|
||||||
|
|
||||||
return repack_gptq_for_marlin(
|
|
||||||
qweight=qweight,
|
|
||||||
scales=scales,
|
|
||||||
qzeros=qzeros,
|
|
||||||
g_idx=g_idx,
|
|
||||||
bits=self.bits,
|
|
||||||
desc_act=self.desc_act,
|
|
||||||
groupsize=self.groupsize,
|
|
||||||
quant_method=self.quant_method,
|
|
||||||
sym=self.sym,
|
|
||||||
sharded_infeatures=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
use_exllama = True
|
use_exllama = True
|
||||||
if self.bits != 4:
|
if self.bits != 4:
|
||||||
|
@ -248,11 +205,6 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||||
prefix: str,
|
prefix: str,
|
||||||
block_sizes: Union[int, List[int]],
|
block_sizes: Union[int, List[int]],
|
||||||
):
|
):
|
||||||
from text_generation_server.layers.marlin import (
|
|
||||||
can_use_gptq_marlin,
|
|
||||||
repack_gptq_for_marlin,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
qweight = weights.get_packed_sharded(
|
qweight = weights.get_packed_sharded(
|
||||||
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
||||||
|
@ -267,36 +219,6 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||||
scales = scales.to(dtype=weights.dtype)
|
scales = scales.to(dtype=weights.dtype)
|
||||||
|
|
||||||
self._get_gptq_params(weights)
|
self._get_gptq_params(weights)
|
||||||
if can_use_gptq_marlin(
|
|
||||||
bits=self.bits,
|
|
||||||
groupsize=self.groupsize,
|
|
||||||
quant_method=self.quant_method,
|
|
||||||
quantize=self.quantize,
|
|
||||||
sym=self.sym,
|
|
||||||
):
|
|
||||||
if not self.sym:
|
|
||||||
qzeros = weights.get_packed_sharded(
|
|
||||||
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
qzeros = None
|
|
||||||
|
|
||||||
if self.quant_method == "awq":
|
|
||||||
g_idx = None
|
|
||||||
else:
|
|
||||||
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
|
||||||
return repack_gptq_for_marlin(
|
|
||||||
qweight=qweight,
|
|
||||||
scales=scales,
|
|
||||||
qzeros=qzeros,
|
|
||||||
g_idx=g_idx,
|
|
||||||
bits=self.bits,
|
|
||||||
desc_act=self.desc_act,
|
|
||||||
groupsize=self.groupsize,
|
|
||||||
quant_method=self.quant_method,
|
|
||||||
sym=self.sym,
|
|
||||||
sharded_infeatures=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
qzeros = weights.get_packed_sharded(
|
qzeros = weights.get_packed_sharded(
|
||||||
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
||||||
|
@ -334,11 +256,6 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||||
from text_generation_server.layers.marlin import (
|
|
||||||
can_use_gptq_marlin,
|
|
||||||
repack_gptq_for_marlin,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
qweight = torch.cat(
|
qweight = torch.cat(
|
||||||
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
|
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
|
||||||
|
@ -353,41 +270,6 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||||
)
|
)
|
||||||
|
|
||||||
self._get_gptq_params(weights)
|
self._get_gptq_params(weights)
|
||||||
if can_use_gptq_marlin(
|
|
||||||
bits=self.bits,
|
|
||||||
groupsize=self.groupsize,
|
|
||||||
quant_method=self.quant_method,
|
|
||||||
quantize=self.quantize,
|
|
||||||
sym=self.sym,
|
|
||||||
):
|
|
||||||
|
|
||||||
if not self.sym:
|
|
||||||
qzeros = torch.cat(
|
|
||||||
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
qzeros = None
|
|
||||||
|
|
||||||
if self.quant_method == "awq":
|
|
||||||
g_idx = None
|
|
||||||
else:
|
|
||||||
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
|
|
||||||
for w2 in w[1:]:
|
|
||||||
torch.testing.assert_close(w2, w[0])
|
|
||||||
g_idx = w[0]
|
|
||||||
|
|
||||||
return repack_gptq_for_marlin(
|
|
||||||
qweight=qweight,
|
|
||||||
scales=scales,
|
|
||||||
qzeros=qzeros,
|
|
||||||
g_idx=g_idx,
|
|
||||||
bits=self.bits,
|
|
||||||
desc_act=self.desc_act,
|
|
||||||
groupsize=self.groupsize,
|
|
||||||
quant_method=self.quant_method,
|
|
||||||
sym=self.sym,
|
|
||||||
sharded_infeatures=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
qzeros = torch.cat(
|
qzeros = torch.cat(
|
||||||
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
||||||
|
@ -441,59 +323,7 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_weights_row(self, weights: Weights, prefix: str):
|
def get_weights_row(self, weights: Weights, prefix: str):
|
||||||
from text_generation_server.layers.marlin import (
|
|
||||||
can_use_gptq_marlin,
|
|
||||||
repack_gptq_for_marlin,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._get_gptq_params(weights)
|
self._get_gptq_params(weights)
|
||||||
if can_use_gptq_marlin(
|
|
||||||
bits=self.bits,
|
|
||||||
groupsize=self.groupsize,
|
|
||||||
quant_method=self.quant_method,
|
|
||||||
quantize=self.quantize,
|
|
||||||
sym=self.sym,
|
|
||||||
):
|
|
||||||
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
|
||||||
try:
|
|
||||||
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
|
|
||||||
except RuntimeError:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.sym:
|
|
||||||
if self.desc_act or self.groupsize == -1:
|
|
||||||
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
|
||||||
else:
|
|
||||||
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
|
|
||||||
else:
|
|
||||||
qzeros = None
|
|
||||||
|
|
||||||
if self.quant_method == "awq":
|
|
||||||
g_idx = None
|
|
||||||
else:
|
|
||||||
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
|
|
||||||
|
|
||||||
if self.desc_act or self.groupsize == -1:
|
|
||||||
scales = weights.get_tensor(f"{prefix}.scales")
|
|
||||||
else:
|
|
||||||
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
|
|
||||||
|
|
||||||
sharded_in_features = weights.process_group.size() > 1
|
|
||||||
|
|
||||||
return repack_gptq_for_marlin(
|
|
||||||
qweight=qweight,
|
|
||||||
scales=scales,
|
|
||||||
qzeros=qzeros,
|
|
||||||
g_idx=g_idx,
|
|
||||||
bits=self.bits,
|
|
||||||
desc_act=self.desc_act,
|
|
||||||
groupsize=self.groupsize,
|
|
||||||
quant_method=self.quant_method,
|
|
||||||
sym=self.sym,
|
|
||||||
sharded_infeatures=sharded_in_features,
|
|
||||||
)
|
|
||||||
|
|
||||||
use_exllama = True
|
use_exllama = True
|
||||||
if self.bits != 4:
|
if self.bits != 4:
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear
|
from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear
|
||||||
from text_generation_server.layers.marlin.gptq import (
|
from text_generation_server.layers.marlin.gptq import (
|
||||||
GPTQMarlinLinear,
|
GPTQMarlinWeightsLoader,
|
||||||
GPTQMarlinWeight,
|
|
||||||
can_use_gptq_marlin,
|
can_use_gptq_marlin,
|
||||||
repack_gptq_for_marlin,
|
repack_gptq_for_marlin,
|
||||||
)
|
)
|
||||||
|
@ -9,8 +8,7 @@ from text_generation_server.layers.marlin.marlin import MarlinWeightsLoader
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"GPTQMarlinFP8Linear",
|
"GPTQMarlinFP8Linear",
|
||||||
"GPTQMarlinLinear",
|
"GPTQMarlinWeightsLoader",
|
||||||
"GPTQMarlinWeight",
|
|
||||||
"MarlinWeightsLoader",
|
"MarlinWeightsLoader",
|
||||||
"can_use_gptq_marlin",
|
"can_use_gptq_marlin",
|
||||||
"repack_gptq_for_marlin",
|
"repack_gptq_for_marlin",
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
import torch
|
import torch
|
||||||
|
@ -13,7 +13,7 @@ from text_generation_server.layers.marlin.util import (
|
||||||
)
|
)
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.utils.log import log_once
|
from text_generation_server.utils.log import log_once
|
||||||
from text_generation_server.utils.weights import Weight
|
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import marlin_kernels
|
import marlin_kernels
|
||||||
|
@ -48,6 +48,204 @@ def can_use_gptq_marlin(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GPTQMarlinWeightsLoader(WeightsLoader):
|
||||||
|
"""
|
||||||
|
Loader for using GPTQ- and AWQ-quantized weights with Marlin kernels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
bits: int,
|
||||||
|
desc_act: bool,
|
||||||
|
groupsize: int,
|
||||||
|
quant_method: str,
|
||||||
|
quantize: str,
|
||||||
|
sym: bool,
|
||||||
|
):
|
||||||
|
self.bits = bits
|
||||||
|
self.desc_act = desc_act
|
||||||
|
self.groupsize = groupsize
|
||||||
|
self.quant_method = quant_method
|
||||||
|
self.quantize = quantize
|
||||||
|
self.sym = sym
|
||||||
|
|
||||||
|
def get_weights(self, weights: Weights, prefix: str):
|
||||||
|
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
||||||
|
try:
|
||||||
|
qweight = weights.get_tensor(f"{prefix}.qweight")
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.sym:
|
||||||
|
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
||||||
|
else:
|
||||||
|
qzeros = None
|
||||||
|
|
||||||
|
if self.quant_method == "awq":
|
||||||
|
g_idx = None
|
||||||
|
else:
|
||||||
|
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
||||||
|
scales = weights.get_tensor(f"{prefix}.scales")
|
||||||
|
|
||||||
|
return repack_gptq_for_marlin(
|
||||||
|
qweight=qweight,
|
||||||
|
scales=scales,
|
||||||
|
qzeros=qzeros,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=self.bits,
|
||||||
|
desc_act=self.desc_act,
|
||||||
|
groupsize=self.groupsize,
|
||||||
|
quant_method=self.quant_method,
|
||||||
|
sym=self.sym,
|
||||||
|
sharded_infeatures=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_weights_col_packed(
|
||||||
|
self,
|
||||||
|
weights: Weights,
|
||||||
|
prefix: str,
|
||||||
|
block_sizes: Union[int, List[int]],
|
||||||
|
):
|
||||||
|
|
||||||
|
try:
|
||||||
|
qweight = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
|
||||||
|
)
|
||||||
|
scales = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.scales", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
scales = scales.to(dtype=weights.dtype)
|
||||||
|
|
||||||
|
if not self.sym:
|
||||||
|
qzeros = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
qzeros = None
|
||||||
|
|
||||||
|
if self.quant_method == "awq":
|
||||||
|
g_idx = None
|
||||||
|
else:
|
||||||
|
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
||||||
|
return repack_gptq_for_marlin(
|
||||||
|
qweight=qweight,
|
||||||
|
scales=scales,
|
||||||
|
qzeros=qzeros,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=self.bits,
|
||||||
|
desc_act=self.desc_act,
|
||||||
|
groupsize=self.groupsize,
|
||||||
|
quant_method=self.quant_method,
|
||||||
|
sym=self.sym,
|
||||||
|
sharded_infeatures=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||||
|
try:
|
||||||
|
qweight = torch.cat(
|
||||||
|
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
|
||||||
|
)
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
|
||||||
|
)
|
||||||
|
|
||||||
|
scales = torch.cat(
|
||||||
|
[weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.sym:
|
||||||
|
qzeros = torch.cat(
|
||||||
|
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
qzeros = None
|
||||||
|
|
||||||
|
if self.quant_method == "awq":
|
||||||
|
g_idx = None
|
||||||
|
else:
|
||||||
|
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||||
|
for w2 in w[1:]:
|
||||||
|
torch.testing.assert_close(w2, w[0])
|
||||||
|
g_idx = w[0]
|
||||||
|
|
||||||
|
return repack_gptq_for_marlin(
|
||||||
|
qweight=qweight,
|
||||||
|
scales=scales,
|
||||||
|
qzeros=qzeros,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=self.bits,
|
||||||
|
desc_act=self.desc_act,
|
||||||
|
groupsize=self.groupsize,
|
||||||
|
quant_method=self.quant_method,
|
||||||
|
sym=self.sym,
|
||||||
|
sharded_infeatures=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_weights_row(self, weights: Weights, prefix: str):
|
||||||
|
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
||||||
|
try:
|
||||||
|
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.sym:
|
||||||
|
if self.desc_act or self.groupsize == -1:
|
||||||
|
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
||||||
|
else:
|
||||||
|
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||||
|
else:
|
||||||
|
qzeros = None
|
||||||
|
|
||||||
|
if self.quant_method == "awq":
|
||||||
|
g_idx = None
|
||||||
|
else:
|
||||||
|
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||||
|
|
||||||
|
if self.desc_act or self.groupsize == -1:
|
||||||
|
scales = weights.get_tensor(f"{prefix}.scales")
|
||||||
|
else:
|
||||||
|
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
|
||||||
|
|
||||||
|
sharded_in_features = weights.process_group.size() > 1
|
||||||
|
|
||||||
|
return repack_gptq_for_marlin(
|
||||||
|
qweight=qweight,
|
||||||
|
scales=scales,
|
||||||
|
qzeros=qzeros,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=self.bits,
|
||||||
|
desc_act=self.desc_act,
|
||||||
|
groupsize=self.groupsize,
|
||||||
|
quant_method=self.quant_method,
|
||||||
|
sym=self.sym,
|
||||||
|
sharded_infeatures=sharded_in_features,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_gptq_params(self, weights: Weights):
|
||||||
|
if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
|
||||||
|
self.bits = weights.get_tensor("gptq_bits").item()
|
||||||
|
self.groupsize = weights.get_tensor("gptq_groupsize").item()
|
||||||
|
self.desc_act = False
|
||||||
|
# `server quantize` used asymmetric quantization unconditionally
|
||||||
|
# before the `gptq_sym` setting tensor was added.
|
||||||
|
self.sym = (
|
||||||
|
weights.get_tensor("gptq_sym").item()
|
||||||
|
if weights._has_tensor("gptq_sym")
|
||||||
|
else False
|
||||||
|
)
|
||||||
|
self.quant_method = "gptq"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GPTQMarlinWeight(Weight):
|
class GPTQMarlinWeight(Weight):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -4,6 +4,7 @@ from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
|
from text_generation_server.layers.marlin.gptq import can_use_gptq_marlin
|
||||||
from text_generation_server.utils.weights import (
|
from text_generation_server.utils.weights import (
|
||||||
DefaultWeightsLoader,
|
DefaultWeightsLoader,
|
||||||
WeightsLoader,
|
WeightsLoader,
|
||||||
|
@ -128,6 +129,24 @@ def get_loader(
|
||||||
f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
|
f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if can_use_gptq_marlin(
|
||||||
|
bits=quantizer_config.bits,
|
||||||
|
groupsize=quantizer_config.groupsize,
|
||||||
|
quant_method=quantizer_config.quant_method,
|
||||||
|
quantize=quantize,
|
||||||
|
sym=quantizer_config.sym,
|
||||||
|
):
|
||||||
|
from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader
|
||||||
|
|
||||||
|
return GPTQMarlinWeightsLoader(
|
||||||
|
bits=quantizer_config.bits,
|
||||||
|
desc_act=quantizer_config.desc_act,
|
||||||
|
groupsize=quantizer_config.groupsize,
|
||||||
|
quant_method=quantizer_config.quant_method,
|
||||||
|
quantize=quantize,
|
||||||
|
sym=quantizer_config.sym,
|
||||||
|
)
|
||||||
|
else:
|
||||||
return GPTQWeightsLoader(
|
return GPTQWeightsLoader(
|
||||||
bits=quantizer_config.bits,
|
bits=quantizer_config.bits,
|
||||||
desc_act=quantizer_config.desc_act,
|
desc_act=quantizer_config.desc_act,
|
||||||
|
|
Loading…
Reference in New Issue