From 14980df2dfad7e3b8e6e5686f0b35afc8d4e8812 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 25 Jun 2024 21:09:00 +0200 Subject: [PATCH] Support AWQ quantization with bias (#2117) When the AWQ quantizer was used with a layer that uses a bias, the bias tensor was not correctly passed/used. Instead, the value `true`/`1.0` was added to the linear transformation. Correctly pass through the bias when it is not `None`. Fixes #2106. --- .../layers/awq/quantize/qmodule.py | 10 +++++----- server/text_generation_server/layers/linear.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/server/text_generation_server/layers/awq/quantize/qmodule.py b/server/text_generation_server/layers/awq/quantize/qmodule.py index ca8caf50..c859db1b 100644 --- a/server/text_generation_server/layers/awq/quantize/qmodule.py +++ b/server/text_generation_server/layers/awq/quantize/qmodule.py @@ -1,6 +1,7 @@ # Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py import math +from typing import Optional import torch import torch.nn as nn import awq_inference_engine # with CUDA kernels @@ -17,7 +18,9 @@ import awq_inference_engine # with CUDA kernels class WQLinear(nn.Module): - def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias): + def __init__( + self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor] + ): super().__init__() if w_bit not in [4]: @@ -35,10 +38,7 @@ class WQLinear(nn.Module): self.qweight = qweight self.qzeros = qzeros self.scales = scales - if bias: - self.bias = bias - else: - self.bias = None + self.bias = bias @torch.no_grad() def forward(self, x): diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index d40b192f..207383a5 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -217,7 +217,7 @@ def get_linear(weight, bias, quantize): qweight=weight.qweight, qzeros=weight.qzeros, scales=weight.scales, - bias=bias is not None, + bias=bias, ) except ImportError: raise NotImplementedError(