Add AWQ quantization inference support (#1019)
# Add AWQ quantization inference support
Fixes
https://github.com/huggingface/text-generation-inference/issues/781
This PR (partially) adds support for AWQ quantization for inference.
More information on AWQ [here](https://arxiv.org/abs/2306.00978). In
general, AWQ is faster and more accurate than GPTQ, which is currently
supported by TGI.
This PR installs 4-bit GEMM custom CUDA kernels released by AWQ authors
(in `requirements.txt`, just one line change).
Quick way to test this PR would be bring up TGI as follows:
```
text-generation-server download-weights abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq
text-generation-launcher \
--huggingface-hub-cache ~/.cache/huggingface/hub/ \
--model-id abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq \
--trust-remote-code --port 8080 \
--max-input-length 2048 --max-total-tokens 4096 --max-batch-prefill-tokens 4096 \
--quantize awq
```
Please note:
* This PR was tested with FlashAttention v2 and vLLM.
* This PR adds support for AWQ inference, not quantizing the models.
That needs to be done outside of TGI, instructions
[here](f084f40bd9
).
* This PR only adds support for `FlashLlama` models for now.
* Multi-GPU setup has not been tested.
* No integration tests have been added so far, will add later if
maintainers are interested in this change.
* This PR can be tested on any of the models released
[here](https://huggingface.co/abhinavkulkarni?sort_models=downloads#models).
Please refer to the linked issue for benchmarks for
[abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq](https://huggingface.co/abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq)
vs
[TheBloke/Llama-2-7b-Chat-GPTQ](https://huggingface.co/TheBloke/Llama-2-7b-Chat-GPTQ).
Please note, AWQ has released faster (and in case of Llama, fused)
kernels for 4-bit GEMM, currently at the top of the `main` branch at
https://github.com/mit-han-lab/llm-awq, but this PR uses an older commit
that has been tested to work. We can switch to latest commit later on.
## Who can review?
@OlivierDehaene OR @Narsil
---------
Co-authored-by: Abhinav Kulkarni <abhinav@concentric.ai>
This commit is contained in:
parent
123749a3c9
commit
c35f39cf83
|
@ -4,7 +4,7 @@ Text Generation Inference improves the model in several aspects.
|
|||
|
||||
## Quantization
|
||||
|
||||
TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes) and [GPT-Q](https://arxiv.org/abs/2210.17323) quantization. To speed up inference with quantization, simply set `quantize` flag to `bitsandbytes` or `gptq` depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq). To get more information about quantization, please refer to (./conceptual/quantization.md)
|
||||
TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes), [GPT-Q](https://arxiv.org/abs/2210.17323) and [AWQ](https://arxiv.org/abs/2306.00978) quantization. To speed up inference with quantization, simply set `quantize` flag to `bitsandbytes`, `gptq` or `awq` depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq) when using AWQ quantization, you need to point to one of the models [here](https://huggingface.co/models?search=awq). To get more information about quantization, please refer to (./conceptual/quantization.md)
|
||||
|
||||
|
||||
## RoPE Scaling
|
||||
|
|
|
@ -25,6 +25,7 @@ enum Quantization {
|
|||
BitsandbytesNF4,
|
||||
BitsandbytesFP4,
|
||||
Gptq,
|
||||
Awq,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Quantization {
|
||||
|
@ -43,6 +44,9 @@ impl std::fmt::Display for Quantization {
|
|||
Quantization::Gptq => {
|
||||
write!(f, "gptq")
|
||||
}
|
||||
Quantization::Awq => {
|
||||
write!(f, "awq")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -73,3 +73,5 @@ win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and
|
|||
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
xxhash==3.3.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
yarl==1.9.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
# Custom 4-bit GEMM AWQ kernels
|
||||
git+https://github.com/mit-han-lab/llm-awq.git@f084f40bd996f3cf3a0633c1ad7d9d476c318aaa#subdirectory=awq/kernels
|
||||
|
|
|
@ -17,6 +17,7 @@ class Quantization(str, Enum):
|
|||
bitsandbytes_nf4 = "bitsandbytes-nf4"
|
||||
bitsandbytes_fp4 = "bitsandbytes-fp4"
|
||||
gptq = "gptq"
|
||||
awq = "awq"
|
||||
|
||||
|
||||
class Dtype(str, Enum):
|
||||
|
|
|
@ -268,6 +268,10 @@ def get_model(
|
|||
raise ValueError(
|
||||
"gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||
)
|
||||
if quantize == "awq":
|
||||
raise ValueError(
|
||||
"awq quantization is not supported for AutoModel"
|
||||
)
|
||||
elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
|
||||
raise ValueError(
|
||||
"4bit quantization is not supported for AutoModel"
|
||||
|
|
|
@ -62,7 +62,7 @@ class FlashLlama(FlashCausalLM):
|
|||
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
if config.quantize == "gptq":
|
||||
if config.quantize in ["gptq", "awq"]:
|
||||
weights._set_gptq_params(model_id)
|
||||
|
||||
model = FlashLlamaForCausalLM(config, weights)
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import awq_inference_engine # with CUDA kernels
|
||||
|
||||
|
||||
class ScaledActivation(nn.Module):
|
||||
def __init__(self, module, scales):
|
||||
super().__init__()
|
||||
self.act = module
|
||||
self.scales = nn.Parameter(scales.data)
|
||||
|
||||
def forward(self, x):
|
||||
return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
|
||||
|
||||
|
||||
class WQLinear(nn.Module):
|
||||
def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias):
|
||||
super().__init__()
|
||||
|
||||
if w_bit not in [4]:
|
||||
raise NotImplementedError("Only 4-bit are supported for now.")
|
||||
|
||||
self.in_features = qweight.shape[0]
|
||||
self.out_features = qweight.shape[1] * 32 // w_bit
|
||||
|
||||
self.w_bit = w_bit
|
||||
self.group_size = group_size if group_size != -1 else self.in_features
|
||||
# quick sanity check (make sure aligment)
|
||||
assert self.in_features % self.group_size == 0
|
||||
assert self.out_features % (32 // self.w_bit) == 0
|
||||
|
||||
self.register_buffer('qweight', qweight)
|
||||
self.register_buffer('qzeros', qzeros)
|
||||
self.register_buffer('scales', scales)
|
||||
if bias:
|
||||
self.register_buffer('bias', bias)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x):
|
||||
out_shape = x.shape[:-1] + (self.out_features, )
|
||||
out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)
|
||||
out = out + self.bias if self.bias is not None else out
|
||||
return out.reshape(out_shape)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
|
||||
self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size
|
||||
)
|
|
@ -17,6 +17,7 @@ except ImportError:
|
|||
from accelerate import init_empty_weights
|
||||
|
||||
from text_generation_server.utils.gptq.quant_linear import QuantLinear
|
||||
from text_generation_server.utils.awq.quantize.qmodule import WQLinear
|
||||
|
||||
try:
|
||||
major, _minor = torch.cuda.get_device_capability()
|
||||
|
@ -248,6 +249,14 @@ def get_linear(weight, bias, quantize):
|
|||
bits,
|
||||
groupsize,
|
||||
)
|
||||
elif quantize == "awq":
|
||||
try:
|
||||
qweight, qzeros, scales, _, bits, groupsize, _ = weight
|
||||
except Exception:
|
||||
raise NotImplementedError(
|
||||
f"The passed weight is not `awq` compatible, loader needs to be updated."
|
||||
)
|
||||
linear = WQLinear(w_bit=bits, group_size=groupsize, qweight=qweight, qzeros=qzeros, scales=scales, bias=bias is not None)
|
||||
else:
|
||||
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
|
||||
return linear
|
||||
|
@ -283,8 +292,8 @@ class TensorParallelHead(SuperLayer):
|
|||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
should_gather = False
|
||||
|
||||
# GPTQ doesn't quantize heads (nor embeddings)
|
||||
if config.quantize == "gptq":
|
||||
# GPTQ and AWQ don't quantize heads (nor embeddings)
|
||||
if config.quantize in ["gptq", "awq"]:
|
||||
quantize = None
|
||||
else:
|
||||
quantize = config.quantize
|
||||
|
|
|
@ -135,18 +135,26 @@ class Weights:
|
|||
Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
|
||||
already alternating Q,K,V within the main tensor
|
||||
"""
|
||||
if quantize == "gptq":
|
||||
if quantize in ["gptq", "awq"]:
|
||||
try:
|
||||
qweight = self._get_qweight(f"{prefix}.qweight")
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||
)
|
||||
if quantize == "gptq":
|
||||
raise RuntimeError(
|
||||
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Cannot load `awq` weight, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
qzeros = self._get_qweight(f"{prefix}.qzeros")
|
||||
scales = self._get_qweight(f"{prefix}.scales")
|
||||
scales = scales.to(dtype=self.dtype)
|
||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||
try:
|
||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||
except RuntimeError:
|
||||
g_idx = None
|
||||
|
||||
bits, groupsize = self._get_gptq_params()
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
|
||||
|
@ -171,15 +179,20 @@ class Weights:
|
|||
return weight
|
||||
|
||||
def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
|
||||
if quantize == "gptq":
|
||||
if quantize in ["gptq", "awq"]:
|
||||
try:
|
||||
qweight = torch.cat(
|
||||
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||
)
|
||||
if quantize == "gptq":
|
||||
raise RuntimeError(
|
||||
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Cannot load `awq` weight, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
qzeros = torch.cat(
|
||||
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
||||
|
@ -187,10 +200,14 @@ class Weights:
|
|||
scales = torch.cat(
|
||||
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||
for w2 in w[1:]:
|
||||
torch.testing.assert_close(w2, w[0])
|
||||
g_idx = w[0]
|
||||
|
||||
try:
|
||||
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||
for w2 in w[1:]:
|
||||
torch.testing.assert_close(w2, w[0])
|
||||
g_idx = w[0]
|
||||
except RuntimeError:
|
||||
g_idx = None
|
||||
|
||||
bits, groupsize = self._get_gptq_params()
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
|
||||
|
@ -216,7 +233,7 @@ class Weights:
|
|||
return tensor
|
||||
|
||||
def get_multi_weights_row(self, prefix: str, quantize: str):
|
||||
if quantize == "gptq":
|
||||
if quantize in "gptq":
|
||||
use_exllama = True
|
||||
bits, groupsize = self._get_gptq_params()
|
||||
|
||||
|
@ -282,6 +299,20 @@ class Weights:
|
|||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
|
||||
elif quantize == "awq":
|
||||
bits, groupsize = self._get_gptq_params()
|
||||
|
||||
try:
|
||||
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `awq` weight, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
qzeros = self.get_tensor(f"{prefix}.qzeros")
|
||||
scales = self.get_tensor(f"{prefix}.scales")
|
||||
|
||||
weight = (qweight, qzeros, scales, None, bits, groupsize, None)
|
||||
else:
|
||||
weight = self.get_sharded(f"{prefix}.weight", dim=1)
|
||||
return weight
|
||||
|
|
Loading…
Reference in New Issue