CI job. Gpt awq 4 (#2665)

* add gptq and awq int4 support in intel platform

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix ci failure

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* set kv cache dtype

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* refine the code according to the review command

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* Simplifying conditionals + reverting integration tests values.

* Unused import

* Fix redundant import.

* Revert change after rebase.

* Upgrading the tests (TP>1 fix changes to use different kernels.)

* Update server/text_generation_server/layers/gptq/__init__.py

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Nicolas Patry 2024-10-18 17:55:53 +02:00 committed by GitHub
parent 8ec57558cd
commit 153ff3740b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 226 additions and 38 deletions

View File

@ -11,57 +11,57 @@
}, },
{ {
"id": 3226, "id": 3226,
"logprob": -8.9453125, "logprob": -9.0234375,
"text": " ge" "text": " ge"
}, },
{ {
"id": 21017, "id": 21017,
"logprob": -8.8515625, "logprob": -9.0859375,
"text": "ometric" "text": "ometric"
}, },
{ {
"id": 81, "id": 81,
"logprob": -0.21875, "logprob": -0.25585938,
"text": "_" "text": "_"
}, },
{ {
"id": 6009, "id": 6009,
"logprob": -1.2773438, "logprob": -2.1972656,
"text": "mean" "text": "mean"
}, },
{ {
"id": 26, "id": 26,
"logprob": -0.25195312, "logprob": -0.2998047,
"text": "(" "text": "("
}, },
{ {
"id": 62, "id": 62,
"logprob": -4.8203125, "logprob": -5.6445312,
"text": "L" "text": "L"
}, },
{ {
"id": 44, "id": 44,
"logprob": -3.7734375, "logprob": -3.0839844,
"text": ":" "text": ":"
}, },
{ {
"id": 1682, "id": 1682,
"logprob": -0.8310547, "logprob": -0.6748047,
"text": " List" "text": " List"
}, },
{ {
"id": 77, "id": 77,
"logprob": -0.22766113, "logprob": -0.3864746,
"text": "[" "text": "["
}, },
{ {
"id": 1808, "id": 1808,
"logprob": -0.46240234, "logprob": -0.9355469,
"text": "float" "text": "float"
}, },
{ {
"id": 10794, "id": 10794,
"logprob": -3.0234375, "logprob": -2.5371094,
"text": "]):" "text": "]):"
} }
], ],
@ -69,7 +69,7 @@
"tokens": [ "tokens": [
{ {
"id": 284, "id": 284,
"logprob": -0.04626465, "logprob": -1.1679688,
"special": false, "special": false,
"text": "\n " "text": "\n "
}, },

View File

@ -11,57 +11,57 @@
}, },
{ {
"id": 3226, "id": 3226,
"logprob": -8.9453125, "logprob": -9.015625,
"text": " ge" "text": " ge"
}, },
{ {
"id": 21017, "id": 21017,
"logprob": -8.859375, "logprob": -9.0859375,
"text": "ometric" "text": "ometric"
}, },
{ {
"id": 81, "id": 81,
"logprob": -0.21984863, "logprob": -0.25585938,
"text": "_" "text": "_"
}, },
{ {
"id": 6009, "id": 6009,
"logprob": -1.2861328, "logprob": -2.2304688,
"text": "mean" "text": "mean"
}, },
{ {
"id": 26, "id": 26,
"logprob": -0.25219727, "logprob": -0.29760742,
"text": "(" "text": "("
}, },
{ {
"id": 62, "id": 62,
"logprob": -4.8007812, "logprob": -5.6796875,
"text": "L" "text": "L"
}, },
{ {
"id": 44, "id": 44,
"logprob": -3.7949219, "logprob": -3.0742188,
"text": ":" "text": ":"
}, },
{ {
"id": 1682, "id": 1682,
"logprob": -0.8046875, "logprob": -0.67626953,
"text": " List" "text": " List"
}, },
{ {
"id": 77, "id": 77,
"logprob": -0.22424316, "logprob": -0.38842773,
"text": "[" "text": "["
}, },
{ {
"id": 1808, "id": 1808,
"logprob": -0.46191406, "logprob": -0.9165039,
"text": "float" "text": "float"
}, },
{ {
"id": 10794, "id": 10794,
"logprob": -3.0253906, "logprob": -2.5527344,
"text": "]):" "text": "]):"
} }
], ],
@ -69,7 +69,7 @@
"tokens": [ "tokens": [
{ {
"id": 284, "id": 284,
"logprob": 0.0, "logprob": -0.048583984,
"special": false, "special": false,
"text": "\n " "text": "\n "
}, },

View File

@ -0,0 +1,8 @@
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "ipex":
from .ipex import WQLinear
elif SYSTEM == "cuda":
from .cuda import WQLinear
__all__ = ["WQLinear"]

View File

@ -0,0 +1,48 @@
from typing import Optional
import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex
class WQLinear(nn.Module):
def __init__(
self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.in_features = qweight.shape[0]
self.out_features = qweight.shape[1] * 32 // w_bit
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else self.in_features
# quick sanity check (make sure aligment)
assert self.in_features % self.group_size == 0
assert self.out_features % (32 // self.w_bit) == 0
self.qweight = qweight
self.qzeros = qzeros
self.scales = scales
self.bias = bias
self.woq_linear = (
ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
self.qweight,
self.scales,
self.qzeros,
self.in_features,
self.out_features,
bias=self.bias,
group_size=self.group_size,
quant_method=ipex.llm.quantization.QuantMethod.AWQ_GEMM,
dtype=ipex.llm.quantization.QuantDtype.INT4,
)
)
@torch.no_grad()
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features,)
out = self.woq_linear(x.reshape(-1, x.shape[-1]))
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)

View File

@ -8,6 +8,11 @@ from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
if SYSTEM == "ipex":
from .ipex import QuantLinear
elif SYSTEM == "cuda":
from .cuda import QuantLinear
@dataclass @dataclass
class GPTQWeight(Weight): class GPTQWeight(Weight):
@ -36,7 +41,7 @@ class GPTQWeight(Weight):
"to use Exllama/GPTQ kernels for AWQ inference." "to use Exllama/GPTQ kernels for AWQ inference."
) )
try: try:
from text_generation_server.layers.awq.quantize.qmodule import WQLinear from text_generation_server.layers.awq.quantize import WQLinear
return WQLinear( return WQLinear(
w_bit=self.bits, w_bit=self.bits,
@ -60,8 +65,6 @@ class GPTQWeight(Weight):
return ExllamaQuantLinear(self, bias) return ExllamaQuantLinear(self, bias)
else: else:
from text_generation_server.layers.gptq.quant_linear import QuantLinear
return QuantLinear( return QuantLinear(
self.qweight, self.qweight,
self.qzeros, self.qzeros,
@ -298,6 +301,7 @@ class GPTQWeightsLoader(WeightsLoader):
self._get_gptq_params(weights) self._get_gptq_params(weights)
use_exllama = True use_exllama = True
desc_act = self.desc_act
if self.bits != 4: if self.bits != 4:
use_exllama = False use_exllama = False
@ -321,7 +325,8 @@ class GPTQWeightsLoader(WeightsLoader):
if g_idx is not None: if g_idx is not None:
if ( if (
not torch.equal( not torch.equal(
g_idx.cpu(), # Remove g_idx[0] to adapt the check with TP>1.
(g_idx - g_idx[0]).cpu(),
torch.tensor( torch.tensor(
[i // self.groupsize for i in range(g_idx.shape[0])], [i // self.groupsize for i in range(g_idx.shape[0])],
dtype=torch.int32, dtype=torch.int32,
@ -332,6 +337,7 @@ class GPTQWeightsLoader(WeightsLoader):
# Exllama implementation does not support row tensor parallelism with act-order, as # Exllama implementation does not support row tensor parallelism with act-order, as
# it would require to reorder input activations that are split unto several GPUs # it would require to reorder input activations that are split unto several GPUs
use_exllama = False use_exllama = False
desc_act = True
from text_generation_server.layers.gptq import ( from text_generation_server.layers.gptq import (
CAN_EXLLAMA, CAN_EXLLAMA,
@ -350,16 +356,16 @@ class GPTQWeightsLoader(WeightsLoader):
else: else:
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
if use_exllama and self.groupsize != -1: if not desc_act and self.groupsize != -1:
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0) qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
scales = weights.get_sharded(f"{prefix}.scales", dim=0) scales = weights.get_sharded(f"{prefix}.scales", dim=0)
if g_idx is not None:
# qzeros, scales sharded, and g_idx must be adjusted accordingly
g_idx = g_idx - g_idx[0]
else: else:
qzeros = weights.get_tensor(f"{prefix}.qzeros") qzeros = weights.get_tensor(f"{prefix}.qzeros")
scales = weights.get_tensor(f"{prefix}.scales") scales = weights.get_tensor(f"{prefix}.scales")
if use_exllama and g_idx is not None:
g_idx = g_idx - g_idx[0]
if self.quantize == "gptq" and self.quant_method == "awq": if self.quantize == "gptq" and self.quant_method == "awq":
log_once( log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format." logger.info, "Converting AWQ model to Exllama/GPTQ packing format."

View File

@ -0,0 +1,126 @@
import math
import numpy as np
import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex
class QuantLinear(nn.Module):
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
super().__init__()
self.register_buffer("qweight", qweight)
self.register_buffer("qzeros", qzeros)
self.register_buffer("scales", scales)
self.register_buffer("g_idx", g_idx)
if bias is not None:
self.register_buffer("bias", bias)
else:
self.bias = None
if bits not in [4]:
raise NotImplementedError("Only 4 bits are supported.")
self.bits = bits
self.maxq = 2**self.bits - 1
self.groupsize = groupsize
self.outfeatures = qweight.shape[1]
self.infeatures = qweight.shape[0] * 32 // bits
self.woq_linear = (
ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
self.qweight,
self.scales,
self.qzeros,
self.infeatures,
self.outfeatures,
bias=self.bias,
group_size=self.groupsize,
g_idx=g_idx,
quant_method=ipex.llm.quantization.QuantMethod.GPTQ_GEMM,
dtype=ipex.llm.quantization.QuantDtype.INT4,
)
)
@classmethod
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
if bits not in [4]:
raise NotImplementedError("Only 4 bits are supported.")
qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
qzeros = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
dtype=torch.int32,
)
scales = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
)
g_idx = torch.tensor(
[i // groupsize for i in range(infeatures)], dtype=torch.int32
)
if bias:
bias = torch.zeros((outfeatures), dtype=torch.float16)
else:
bias = None
return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
def pack(self, linear, scales, zeros, g_idx=None):
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
intweight = []
for idx in range(self.infeatures):
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
/ self.scales[self.g_idx[idx]]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
qweight = np.zeros(
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
)
i = 0
row = 0
while row < qweight.shape[0]:
if self.bits in [4]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
else:
raise NotImplementedError("Only 4 bits are supported.")
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros(
(zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [4]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
else:
raise NotImplementedError("Only 4 bits are supported.")
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,)
out = self.woq_linear(x.reshape(-1, x.shape[-1]))
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)

View File

@ -12,7 +12,7 @@ from huggingface_hub import HfApi
from accelerate import init_empty_weights from accelerate import init_empty_weights
from text_generation_server.utils import initialize_torch_distributed, Weights from text_generation_server.utils import initialize_torch_distributed, Weights
from text_generation_server.utils.hub import weight_files from text_generation_server.utils.hub import weight_files
from text_generation_server.layers.gptq.quant_linear import QuantLinear from text_generation_server.layers.gptq import QuantLinear
from loguru import logger from loguru import logger
from typing import Optional from typing import Optional
from text_generation_server.layers.gptq.utils import torch_snr_error from text_generation_server.layers.gptq.utils import torch_snr_error

View File

@ -400,8 +400,11 @@ def get_model(
if dtype is None: if dtype is None:
if quantize in ["awq", "exl2", "gptq", "marlin"]: if quantize in ["awq", "exl2", "gptq", "marlin"]:
# These quantizers only work with float16 params. if SYSTEM == "ipex" and not hasattr(torch, "xpu"):
dtype = torch.float16 dtype = torch.bfloat16
else:
# These quantizers only work with float16 params.
dtype = torch.float16
elif quantize == "fp8": elif quantize == "fp8":
from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE

View File

@ -1122,7 +1122,6 @@ class FlashCausalLM(Model):
dtype = default_dtype if dtype is None else dtype dtype = default_dtype if dtype is None else dtype
else: else:
device = torch.device("cpu") device = torch.device("cpu")
# Float16 doesn't exist on target.
dtype = torch.bfloat16 if dtype is None else dtype dtype = torch.bfloat16 if dtype is None else dtype
init_cpu_threads_env(rank_id=rank, world_size=world_size) init_cpu_threads_env(rank_id=rank, world_size=world_size)
else: else:
@ -1602,8 +1601,6 @@ class FlashCausalLM(Model):
max_s = batch.max_current_length max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices lm_head_indices = batch.prefill_head_indices
print(slots)
if cu_seqlen_prefill is None and self.max_past() is not None: if cu_seqlen_prefill is None and self.max_past() is not None:
# In decode, not prefill, we're actually overwriting the KV-cache # In decode, not prefill, we're actually overwriting the KV-cache
# in a circular buffer mode. # in a circular buffer mode.