From 2c4bf882689fda969c84dfc86c4158032de10e1b Mon Sep 17 00:00:00 2001 From: ssmi153 <129111316+ssmi153@users.noreply.github.com> Date: Wed, 12 Jul 2023 20:17:35 +0800 Subject: [PATCH] fix(server): Bug fixes for GPTQ_BITS environment variable passthrough (#590) # What does this PR do? This fixes a typo and extends the GPTP_BITS environment variables through to the second method which requires the same logic. Please let me know if there's anything I've misunderstood in this change. Thanks @Narsil for the original fix. --- server/text_generation_server/utils/weights.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 39f66862..4f300fe7 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -127,8 +127,8 @@ class Weights: try: import os - bits = int(os.getenv("GTPQ_BITS")) - groupsize = int(os.getenv("GTPQ_GROUPSIZE")) + bits = int(os.getenv("GPTQ_BITS")) + groupsize = int(os.getenv("GPTQ_GROUPSIZE")) except Exception: raise e weight = (qweight, qzeros, scales, g_idx, bits, groupsize) @@ -149,8 +149,17 @@ class Weights: scales = self.get_tensor(f"{prefix}.scales") g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) - bits = self.get_tensor("gptq_bits").item() - groupsize = self.get_tensor("gptq_groupsize").item() + try: + bits = self.get_tensor("gptq_bits").item() + groupsize = self.get_tensor("gptq_groupsize").item() + except SafetensorError as e: + try: + import os + + bits = int(os.getenv("GPTQ_BITS")) + groupsize = int(os.getenv("GPTQ_GROUPSIZE")) + except Exception: + raise e weight = (qweight, qzeros, scales, g_idx, bits, groupsize) else: