fix(server): Bug fixes for GPTQ_BITS environment variable passthrough (#590)
# What does this PR do? This fixes a typo and extends the GPTP_BITS environment variables through to the second method which requires the same logic. Please let me know if there's anything I've misunderstood in this change. Thanks @Narsil for the original fix.
This commit is contained in:
parent
7f9072228a
commit
2c4bf88268
|
@ -127,8 +127,8 @@ class Weights:
|
||||||
try:
|
try:
|
||||||
import os
|
import os
|
||||||
|
|
||||||
bits = int(os.getenv("GTPQ_BITS"))
|
bits = int(os.getenv("GPTQ_BITS"))
|
||||||
groupsize = int(os.getenv("GTPQ_GROUPSIZE"))
|
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
|
||||||
except Exception:
|
except Exception:
|
||||||
raise e
|
raise e
|
||||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
|
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
|
||||||
|
@ -149,8 +149,17 @@ class Weights:
|
||||||
scales = self.get_tensor(f"{prefix}.scales")
|
scales = self.get_tensor(f"{prefix}.scales")
|
||||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||||
|
|
||||||
bits = self.get_tensor("gptq_bits").item()
|
try:
|
||||||
groupsize = self.get_tensor("gptq_groupsize").item()
|
bits = self.get_tensor("gptq_bits").item()
|
||||||
|
groupsize = self.get_tensor("gptq_groupsize").item()
|
||||||
|
except SafetensorError as e:
|
||||||
|
try:
|
||||||
|
import os
|
||||||
|
|
||||||
|
bits = int(os.getenv("GPTQ_BITS"))
|
||||||
|
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
|
||||||
|
except Exception:
|
||||||
|
raise e
|
||||||
|
|
||||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
|
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue