Support AWQ quantization with bias (#2117)
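When the AWQ quantizer was used with a layer that has a bias, the bias tensor was not passed through to the quantized linear layer. The caller passed the boolean `bias is not None` instead of the tensor itself, so the value `True` (i.e. `1.0`) was added to the output of the linear transformation. Pass the bias tensor through unchanged; it is `None` when the layer has no bias. Fixes #2106.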
This commit is contained in:
parent
04e1af94d7
commit
14980df2df
|
@ -1,6 +1,7 @@
|
||||||
# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
|
# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
from typing import Optional
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import awq_inference_engine # with CUDA kernels
|
import awq_inference_engine # with CUDA kernels
|
||||||
|
@ -17,7 +18,9 @@ import awq_inference_engine # with CUDA kernels
|
||||||
|
|
||||||
|
|
||||||
class WQLinear(nn.Module):
|
class WQLinear(nn.Module):
|
||||||
def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias):
|
def __init__(
|
||||||
|
self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if w_bit not in [4]:
|
if w_bit not in [4]:
|
||||||
|
@ -35,10 +38,7 @@ class WQLinear(nn.Module):
|
||||||
self.qweight = qweight
|
self.qweight = qweight
|
||||||
self.qzeros = qzeros
|
self.qzeros = qzeros
|
||||||
self.scales = scales
|
self.scales = scales
|
||||||
if bias:
|
self.bias = bias
|
||||||
self.bias = bias
|
|
||||||
else:
|
|
||||||
self.bias = None
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
|
@ -217,7 +217,7 @@ def get_linear(weight, bias, quantize):
|
||||||
qweight=weight.qweight,
|
qweight=weight.qweight,
|
||||||
qzeros=weight.qzeros,
|
qzeros=weight.qzeros,
|
||||||
scales=weight.scales,
|
scales=weight.scales,
|
||||||
bias=bias is not None,
|
bias=bias,
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
|
|
Loading…
Reference in New Issue