Hotfix GPTQ.

This commit is contained in:
Nicolas Patry 2024-06-03 09:32:12 +00:00
parent 9add5d0af5
commit 9a59ebcec3
2 changed files with 13 additions and 1 deletion

View File

@@ -196,6 +196,8 @@ def get_linear(weight, bias, quantize):
weight.groupsize, weight.groupsize,
) )
elif quantize == "awq": elif quantize == "awq":
from text_generation_server.layers.gptq import GPTQWeight
if not isinstance(weight, GPTQWeight): if not isinstance(weight, GPTQWeight):
raise NotImplementedError( raise NotImplementedError(
f"The passed weight is not `awq` compatible, loader needs to be updated." f"The passed weight is not `awq` compatible, loader needs to be updated."

View File

@@ -154,6 +154,8 @@ class Weights:
already alternating Q,K,V within the main tensor already alternating Q,K,V within the main tensor
""" """
if quantize in ["gptq", "awq"]: if quantize in ["gptq", "awq"]:
from text_generation_server.layers.gptq import GPTQWeight
try: try:
qweight = self._get_qweight(f"{prefix}.qweight") qweight = self._get_qweight(f"{prefix}.qweight")
except RuntimeError: except RuntimeError:
@@ -331,6 +333,8 @@ class Weights:
def get_multi_weights_row(self, prefix: str, quantize: str): def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize == "exl2": if quantize == "exl2":
from text_generation_server.layers.exl2 import Exl2Weight
try: try:
q_weight = self.get_tensor(f"{prefix}.q_weight") q_weight = self.get_tensor(f"{prefix}.q_weight")
except RuntimeError: except RuntimeError:
@@ -390,7 +394,11 @@ class Weights:
# it would require to reorder input activations that are split unto several GPUs # it would require to reorder input activations that are split unto several GPUs
use_exllama = False use_exllama = False
from text_generation_server.layers.gptq import HAS_EXLLAMA, CAN_EXLLAMA from text_generation_server.layers.gptq import (
HAS_EXLLAMA,
CAN_EXLLAMA,
GPTQWeight,
)
if use_exllama: if use_exllama:
if not HAS_EXLLAMA: if not HAS_EXLLAMA:
@@ -442,6 +450,8 @@ class Weights:
use_exllama=use_exllama, use_exllama=use_exllama,
) )
elif quantize == "awq": elif quantize == "awq":
from text_generation_server.layers.gptq import GPTQWeight
bits, groupsize, _, _ = self._get_gptq_params() bits, groupsize, _, _ = self._get_gptq_params()
try: try: