fix(server): Bug fixes for GPTQ_BITS environment variable passthrough (#590)
# What does this PR do? This fixes a typo in the environment variable names (`GTPQ_BITS`/`GTPQ_GROUPSIZE` → `GPTQ_BITS`/`GPTQ_GROUPSIZE`) and extends the `GPTQ_BITS`/`GPTQ_GROUPSIZE` environment variable fallback to the second method, which requires the same logic. Please let me know if there's anything I've misunderstood in this change. Thanks @Narsil for the original fix.
This commit is contained in:
parent
7f9072228a
commit
2c4bf88268
|
@ -127,8 +127,8 @@ class Weights:
|
|||
try:
|
||||
import os
|
||||
|
||||
bits = int(os.getenv("GTPQ_BITS"))
|
||||
groupsize = int(os.getenv("GTPQ_GROUPSIZE"))
|
||||
bits = int(os.getenv("GPTQ_BITS"))
|
||||
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
|
||||
except Exception:
|
||||
raise e
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
|
||||
|
@ -149,8 +149,17 @@ class Weights:
|
|||
scales = self.get_tensor(f"{prefix}.scales")
|
||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
|
||||
try:
|
||||
bits = self.get_tensor("gptq_bits").item()
|
||||
groupsize = self.get_tensor("gptq_groupsize").item()
|
||||
except SafetensorError as e:
|
||||
try:
|
||||
import os
|
||||
|
||||
bits = int(os.getenv("GPTQ_BITS"))
|
||||
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
|
||||
except Exception:
|
||||
raise e
|
||||
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue