feat(server): Support for env value for GPTQ_BITS and GPTQ_GROUPSIZE. (#580)
# What does this PR do? Some models are already converted, and do not have those values in the file, this enables users to use them with less friction. Went for pure env based because adding flags would end up (imo) very tedious to maintain. There's a lot of sanitation to do: those flags would be errors if not used in conjuction with `--quantize gptq`. Then the flags need to exist in the launcher and the server passing them all throughout all function calls. This PR is intended as an easy escape hatch, not the defacto method to use gptq in TGI. Fixes #500
This commit is contained in:
parent
f0181436f4
commit
5bd2ab6583
|
@ -20,12 +20,12 @@ from text_generation_server.utils.layers import (
|
|||
FastLayerNorm,
|
||||
get_linear,
|
||||
)
|
||||
from safetensors import SafetensorError
|
||||
|
||||
|
||||
def load_multi_mqa(
|
||||
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
|
||||
):
|
||||
|
||||
if config.quantize == "gptq":
|
||||
return _load_multi_mqa_gptq(
|
||||
config, prefix, weights, bias, head_size, num_heads, hidden_size
|
||||
|
@ -74,8 +74,17 @@ def _load_multi_mqa_gptq(
|
|||
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
|
||||
|
||||
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
|
||||
bits = weights.get_tensor("gptq_bits").item()
|
||||
groupsize = weights.get_tensor("gptq_groupsize").item()
|
||||
try:
|
||||
bits = weights.get_tensor("gptq_bits").item()
|
||||
groupsize = weights.get_tensor("gptq_groupsize").item()
|
||||
except SafetensorError as e:
|
||||
try:
|
||||
import os
|
||||
|
||||
bits = int(os.getenv("GPTQ_BITS"))
|
||||
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
|
||||
except Exception:
|
||||
raise e
|
||||
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
|
||||
|
||||
|
@ -99,7 +108,6 @@ def _load_multi_mqa_gptq(
|
|||
def _load_multi_mqa(
|
||||
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
|
||||
):
|
||||
|
||||
if any("c_attn" in k for k in weights.routing.keys()):
|
||||
slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
|
||||
shape = slice_.get_shape()
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
from safetensors import safe_open
|
||||
from safetensors import safe_open, SafetensorError
|
||||
import torch
|
||||
|
||||
|
||||
|
@ -120,8 +120,17 @@ class Weights:
|
|||
torch.testing.assert_close(w2, w[0])
|
||||
g_idx = w[0]
|
||||
|
||||
bits = self.get_tensor("gptq_bits").item()
|
||||
groupsize = self.get_tensor("gptq_groupsize").item()
|
||||
try:
|
||||
bits = self.get_tensor("gptq_bits").item()
|
||||
groupsize = self.get_tensor("gptq_groupsize").item()
|
||||
except SafetensorError as e:
|
||||
try:
|
||||
import os
|
||||
|
||||
bits = int(os.getenv("GTPQ_BITS"))
|
||||
groupsize = int(os.getenv("GTPQ_GROUPSIZE"))
|
||||
except Exception:
|
||||
raise e
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
|
||||
else:
|
||||
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
||||
|
|
Loading…
Reference in New Issue