fix: prefer recent gptq changes
commit b84bb19ece
parent 7e773b0f20
@@ -1,12 +1,76 @@
 import os
-from typing import List, Union
+from dataclasses import dataclass
+from typing import List, Optional, Union
 
 import torch
 from loguru import logger
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.log import log_once
-from text_generation_server.utils.weights import Weights, WeightsLoader
-from text_generation_server.layers.gptq.gptq_types import GPTQWeight
+from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
 
 
+@dataclass
+class GPTQWeight(Weight):
+    qweight: torch.Tensor
+    qzeros: torch.Tensor
+    scales: torch.Tensor
+    g_idx: Optional[torch.Tensor]
+    bits: int
+    groupsize: int
+    use_awq_kernel: bool
+    use_exllama: bool
+
+    def __post_init__(self):
+        if self.scales.dtype == torch.float:
+            self.scales = self.scales.half()
+
+    @property
+    def device(self) -> torch.device:
+        return self.qweight.device
+
+    def get_linear(self, bias: torch.Tensor):
+        if self.use_awq_kernel:
+            if SYSTEM == "rocm":
+                raise NotImplementedError(
+                    "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
+                    "to use Exllama/GPTQ kernels for AWQ inference."
+                )
+            try:
+                from text_generation_server.layers.awq.quantize.qmodule import WQLinear
+
+                return WQLinear(
+                    w_bit=self.bits,
+                    group_size=self.groupsize,
+                    qweight=self.qweight,
+                    qzeros=self.qzeros,
+                    scales=self.scales,
+                    bias=bias,
+                )
+            except ImportError:
+                raise NotImplementedError(
+                    "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `--quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
+                )
+        elif self.use_exllama:
+            try:
+                from text_generation_server.layers.gptq import ExllamaQuantLinear
+            except ImportError:
+                raise NotImplementedError(
+                    "Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
+                )
+
+            return ExllamaQuantLinear(self, bias)
+        else:
+            from text_generation_server.layers.gptq.quant_linear import QuantLinear
+
+            return QuantLinear(
+                self.qweight,
+                self.qzeros,
+                self.scales,
+                self.g_idx,
+                bias,
+                self.bits,
+                self.groupsize,
+            )
+
+
 class GPTQWeightsLoader(WeightsLoader):
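Below is a minimal usage sketch (not part of the commit) of the GPTQWeight dataclass this change inlines, showing how it can be built from packed tensors and turned into a linear layer via get_linear. It assumes a text-generation-inference server environment with the GPTQ/Triton kernels installed and that GPTQWeight is importable from text_generation_server.layers.gptq; the shapes follow the usual 4-bit GPTQ packing (eight 4-bit values per int32) and the values are dummies.

import torch

from text_generation_server.layers.gptq import GPTQWeight  # assumed import path

in_features, out_features, bits, groupsize = 4096, 4096, 4, 128

weight = GPTQWeight(
    # packed 4-bit weights: eight 4-bit values per int32 along the input dimension
    qweight=torch.zeros((in_features // 8, out_features), dtype=torch.int32),
    # packed zero points, one row per quantization group
    qzeros=torch.zeros((in_features // groupsize, out_features // 8), dtype=torch.int32),
    # per-group scales; float32 here to exercise the __post_init__ downcast
    scales=torch.ones((in_features // groupsize, out_features), dtype=torch.float32),
    # quantization group index of each input feature
    g_idx=torch.arange(in_features, dtype=torch.int32) // groupsize,
    bits=bits,
    groupsize=groupsize,
    use_awq_kernel=False,  # True would dispatch to the AWQ WQLinear kernel (unavailable on ROCm)
    use_exllama=False,     # True would dispatch to ExllamaQuantLinear
)

# __post_init__ casts float32 scales down to float16
assert weight.scales.dtype == torch.float16

# With both kernel flags False, get_linear falls back to the plain QuantLinear wrapper.
linear = weight.get_linear(bias=None)

Flipping use_awq_kernel or use_exllama selects the other two branches of get_linear shown in the diff above.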