diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 4bae8cc0..2f330d9c 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -212,7 +212,9 @@ class Weights: g_idx = None bits, groupsize = self._get_gptq_params() - weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) + from text_generation_server.utils.layers import HAS_EXLLAMA + use_exllama = bits==4 and HAS_EXLLAMA and quantize == "gptq" + weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) else: w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] weight = torch.cat(w, dim=dim)