From 5bd2ab65839af4f54ba59e60ef3dd262004de456 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 12 Jul 2023 10:00:02 +0200
Subject: [PATCH] feat(server): Support for env value for GPTQ_BITS and
 GPTQ_GROUPSIZE. (#580)

# What does this PR do?

Some models are already converted and do not have those values in the
file; this enables users to use them with less friction.

Went for a purely env-based approach because adding flags would end up
(imo) very tedious to maintain. There's a lot of sanitization to do:
those flags would be errors if not used in conjunction with
`--quantize gptq`, and the flags would need to exist in the launcher
and the server, passed throughout all function calls.

This PR is intended as an easy escape hatch, not the de facto method to
use GPTQ in TGI.

Fixes #500
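
For a checkpoint that was converted without the `gptq_bits`/`gptq_groupsize`
tensors, the values can now come from the environment at launch time. A
minimal usage sketch (the model id is a placeholder, and 4 bits with
groupsize 128 are just common GPTQ settings, not values from this PR):

```shell
# Escape hatch for checkpoints missing the gptq_bits/gptq_groupsize tensors.
# Use the bits/groupsize your model was actually quantized with.
GPTQ_BITS=4 GPTQ_GROUPSIZE=128 \
    text-generation-launcher --model-id <gptq-model> --quantize gptq
```

The env vars are only read when the tensors are missing from the safetensors
file; checkpoints that carry them are unaffected.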
---
 .../custom_modeling/flash_santacoder_modeling.py | 16 ++++++++++++----
 server/text_generation_server/utils/weights.py   | 15 ++++++++++++---
 2 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 5b1c6e2..a19623a 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -20,12 +20,12 @@ from text_generation_server.utils.layers import (
     FastLayerNorm,
     get_linear,
 )
+from safetensors import SafetensorError
 
 
 def load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
-
     if config.quantize == "gptq":
         return _load_multi_mqa_gptq(
             config, prefix, weights, bias, head_size, num_heads, hidden_size
@@ -74,8 +74,17 @@ def _load_multi_mqa_gptq(
         qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
 
         g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
-        bits = weights.get_tensor("gptq_bits").item()
-        groupsize = weights.get_tensor("gptq_groupsize").item()
+        try:
+            bits = weights.get_tensor("gptq_bits").item()
+            groupsize = weights.get_tensor("gptq_groupsize").item()
+        except SafetensorError as e:
+            try:
+                import os
+
+                bits = int(os.getenv("GPTQ_BITS"))
+                groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
+            except Exception:
+                raise e
 
         weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
 
@@ -99,7 +108,6 @@ def _load_multi_mqa_gptq(
 def _load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
-
     if any("c_attn" in k for k in weights.routing.keys()):
         slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
         shape = slice_.get_shape()
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 83d9df6..39f6686 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -1,6 +1,6 @@
 from pathlib import Path
 from typing import List, Dict, Optional
-from safetensors import safe_open
+from safetensors import safe_open, SafetensorError
 import torch
 
 
@@ -120,8 +120,17 @@ class Weights:
                 torch.testing.assert_close(w2, w[0])
             g_idx = w[0]
 
-            bits = self.get_tensor("gptq_bits").item()
-            groupsize = self.get_tensor("gptq_groupsize").item()
+            try:
+                bits = self.get_tensor("gptq_bits").item()
+                groupsize = self.get_tensor("gptq_groupsize").item()
+            except SafetensorError as e:
+                try:
+                    import os
+
+                    bits = int(os.getenv("GPTQ_BITS"))
+                    groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
+                except Exception:
+                    raise e
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]