feat(server): Support for env value for GPTQ_BITS and GPTQ_GROUPSIZE. (#580)

# What does this PR do?

Some models are already converted and do not have those values in the file; this change lets users run them with less friction.

Went for a purely env-based approach because adding flags would (imo) be very tedious to maintain. There's a lot of validation to do: those flags would be errors if not used in conjunction with `--quantize gptq`, and they would have to exist in both the launcher and the server, passed through every function call along the way.

This PR is intended as an easy escape hatch, not the de facto method for using GPTQ in TGI.
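Usage is just setting the two variables in the server's environment before the weights are loaded. A minimal sketch (illustrative values; they must match how the checkpoint was actually quantized, 4 bits with groupsize 128 being a common GPTQ setup):

```python
import os

# Illustrative values only; they must match the checkpoint's quantization.
# These are read as a fallback when the checkpoint does not carry the
# `gptq_bits` / `gptq_groupsize` tensors.
os.environ["GPTQ_BITS"] = "4"
os.environ["GPTQ_GROUPSIZE"] = "128"
```

The shell equivalent is prefixing the launcher command with `GPTQ_BITS=4 GPTQ_GROUPSIZE=128`.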

Fixes #500
Nicolas Patry 2023-07-12 10:00:02 +02:00 committed by GitHub
parent f0181436f4
commit 5bd2ab6583
2 changed files with 24 additions and 7 deletions

server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py

@@ -20,12 +20,12 @@ from text_generation_server.utils.layers import (
     FastLayerNorm,
     get_linear,
 )
+from safetensors import SafetensorError


 def load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
-
     if config.quantize == "gptq":
         return _load_multi_mqa_gptq(
             config, prefix, weights, bias, head_size, num_heads, hidden_size
@@ -74,8 +74,17 @@ def _load_multi_mqa_gptq(
         qzeros = torch.cat([q_tensor, kv_tensor], dim=1)

         g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
-        bits = weights.get_tensor("gptq_bits").item()
-        groupsize = weights.get_tensor("gptq_groupsize").item()
+        try:
+            bits = weights.get_tensor("gptq_bits").item()
+            groupsize = weights.get_tensor("gptq_groupsize").item()
+        except SafetensorError as e:
+            try:
+                import os
+
+                bits = int(os.getenv("GPTQ_BITS"))
+                groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
+            except Exception:
+                raise e

         weight = (qweight, qzeros, scales, g_idx, bits, groupsize)

@@ -99,7 +108,6 @@ def _load_multi_mqa_gptq(
 def _load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
-
     if any("c_attn" in k for k in weights.routing.keys()):
         slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
         shape = slice_.get_shape()
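The hunks above only consult the environment when reading the metadata tensors raises a `SafetensorError`. For comparison, a conversion script can persist the values in the checkpoint itself so the fallback never triggers; a sketch under the assumption that the parameters are stored as small integer tensors (not part of this PR):

```python
import torch
from safetensors.torch import save_file

tensors = {
    # ...qweight / qzeros / scales / g_idx from the quantization step...
    # Ship the GPTQ parameters inside the checkpoint so that
    # `weights.get_tensor("gptq_bits").item()` succeeds directly.
    "gptq_bits": torch.tensor([4]),
    "gptq_groupsize": torch.tensor([128]),
}
save_file(tensors, "model.safetensors")
```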

server/text_generation_server/utils/weights.py

@@ -1,6 +1,6 @@
 from pathlib import Path
 from typing import List, Dict, Optional
-from safetensors import safe_open
+from safetensors import safe_open, SafetensorError
 import torch
@@ -120,8 +120,17 @@ class Weights:
             torch.testing.assert_close(w2, w[0])
             g_idx = w[0]

-            bits = self.get_tensor("gptq_bits").item()
-            groupsize = self.get_tensor("gptq_groupsize").item()
+            try:
+                bits = self.get_tensor("gptq_bits").item()
+                groupsize = self.get_tensor("gptq_groupsize").item()
+            except SafetensorError as e:
+                try:
+                    import os
+
+                    bits = int(os.getenv("GPTQ_BITS"))
+                    groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
+                except Exception:
+                    raise e
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
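One deliberate detail in both hunks: when the env fallback also fails, the code re-raises the original `SafetensorError` (`raise e`) rather than the secondary exception, so users see the underlying "tensor not found" error instead of a confusing `int(None)` failure. A standalone sketch of those semantics (illustrative names, not TGI's API):

```python
import os


def resolve_gptq_params(read_tensor):
    """File metadata first, env vars second, original error otherwise."""
    try:
        return read_tensor("gptq_bits"), read_tensor("gptq_groupsize")
    except KeyError as e:  # stand-in for SafetensorError
        try:
            return int(os.getenv("GPTQ_BITS")), int(os.getenv("GPTQ_GROUPSIZE"))
        except Exception:
            # Neither source worked: surface the original lookup failure.
            raise e


os.environ["GPTQ_BITS"] = "4"
os.environ["GPTQ_GROUPSIZE"] = "128"
assert resolve_gptq_params({}.__getitem__) == (4, 128)
```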