fix(server): Bug fixes for GPTQ_BITS environment variable passthrough (#590)

# What does this PR do?

This fixes a typo and extends the GPTP_BITS environment variables
through to the second method which requires the same logic. Please let
me know if there's anything I've misunderstood in this change.

Thanks @Narsil for the original fix.
This commit is contained in:
ssmi153 2023-07-12 20:17:35 +08:00 committed by GitHub
parent 7f9072228a
commit 2c4bf88268
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 13 additions and 4 deletions

View File

@ -127,8 +127,8 @@ class Weights:
try:
import os
bits = int(os.getenv("GTPQ_BITS"))
groupsize = int(os.getenv("GTPQ_GROUPSIZE"))
bits = int(os.getenv("GPTQ_BITS"))
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
except Exception:
raise e
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
@ -149,8 +149,17 @@ class Weights:
scales = self.get_tensor(f"{prefix}.scales")
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
try:
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
except SafetensorError as e:
try:
import os
bits = int(os.getenv("GPTQ_BITS"))
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
except Exception:
raise e
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
else: