add GPTQ and AWQ INT4 support on the Intel platform

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2024-08-21 22:47:34 -07:00
parent e4201f44cf
commit 0b02d45a05
5 changed files with 82 additions and 27 deletions


@@ -103,12 +103,21 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
     curl \
     ca-certificates \
     make \
-    g++ \
+    g++-12 \
+    gcc-12 \
     git \
     wget \
     cmake \
     libnuma-dev
+RUN update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-12 12
+RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 12
+RUN update-alternatives --install /usr/bin/cc cc /usr/bin/gcc 30
+RUN update-alternatives --set cc /usr/bin/gcc
+RUN update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++ 30
+RUN update-alternatives --set c++ /usr/bin/g++
 
 ENV HUGGINGFACE_HUB_CACHE=/data \
     HF_HUB_ENABLE_HF_TRANSFER=1 \
     PORT=80
@@ -133,21 +142,19 @@ RUN chmod +x ~/mambaforge.sh && \
 RUN conda install -c conda-forge gperftools mkl
 
-RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
-RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
-RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.5.0.dev20240815%2Bcpu-cp310-cp310-linux_x86_64.whl
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.20.0.dev20240815%2Bcpu-cp310-cp310-linux_x86_64.whl
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240815%2Bcpu-cp310-cp310-linux_x86_64.whl
 RUN pip install triton numa
 
 WORKDIR /usr/src
 
-RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout eda7a7c42df6f9a64e0de9c2b69304ee02f2c32a
-RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout ccl_torch_dev_0131
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout f86e93e4890dc2c989024d148d415c9aa8a1649f
+RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout v2.4.0+cpu+rc0
 
 RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install
 RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install .
 
 ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so
 ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
 ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
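
Note (not part of the commit): a minimal sanity check one might run inside the freshly built image to confirm that the pinned nightly wheels and the pinned IPEX build expose the weight-only-quantization entry points the code changes below rely on. The module attributes are taken from the diff itself; everything else here is illustrative.

# Hypothetical post-build check, executed inside the image.
import torch
import intel_extension_for_pytorch as ipex

print("torch:", torch.__version__)  # expected: a 2.5.0.dev+cpu nightly
print("ipex:", ipex.__version__)

# Entry points used by the new AWQ/GPTQ INT4 paths in this commit:
assert hasattr(ipex.llm.quantization, "IPEXWeightOnlyQuantizedLinear")
assert hasattr(ipex.llm.quantization, "QuantMethod")
assert hasattr(ipex.llm.quantization, "QuantDtype")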


@@ -3,7 +3,12 @@
 from typing import Optional
 import torch
 import torch.nn as nn
-import awq_inference_engine  # with CUDA kernels
+from text_generation_server.utils.import_utils import SYSTEM
+
+if SYSTEM == "ipex":
+    import intel_extension_for_pytorch as ipex
+else:
+    import awq_inference_engine  # with CUDA kernels
 
 
 # class ScaledActivation(nn.Module):
@@ -38,12 +43,29 @@ class WQLinear(nn.Module):
         self.qzeros = qzeros
         self.scales = scales
         self.bias = bias
+        if SYSTEM == "ipex":
+            self.woq_linear = (
+                ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
+                    self.qweight,
+                    self.scales,
+                    self.qzeros,
+                    self.in_features,
+                    self.out_features,
+                    bias=self.bias,
+                    group_size=self.group_size,
+                    quant_method=ipex.llm.quantization.QuantMethod.AWQ_GEMM,
+                    dtype=ipex.llm.quantization.QuantDtype.INT4,
+                )
+            )
 
     @torch.no_grad()
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.out_features,)
-        out = awq_inference_engine.gemm_forward_cuda(
-            x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
-        )
+        if SYSTEM == "ipex":
+            out = self.woq_linear(x.reshape(-1, x.shape[-1]))
+        else:
+            out = awq_inference_engine.gemm_forward_cuda(
+                x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
+            )
         out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
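
Note (illustrative, not part of the commit): WQLinear now hands the same packed AWQ tensors either to ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight (when SYSTEM == "ipex") or to awq_inference_engine.gemm_forward_cuda, so no repacking is needed for the Intel path. The sketch below only spells out the packed-tensor shapes a 4-bit AWQ checkpoint is assumed to use (8 nibbles per int32 along the output dimension); the layer sizes are made up.

import torch

in_features, out_features, group_size, bits = 4096, 4096, 128, 4
pack_factor = 32 // bits  # 8 packed 4-bit values per int32

# Assumed standard AWQ GEMM layout for the tensors passed into WQLinear:
qweight = torch.zeros(in_features, out_features // pack_factor, dtype=torch.int32)
qzeros = torch.zeros(in_features // group_size, out_features // pack_factor, dtype=torch.int32)
scales = torch.zeros(in_features // group_size, out_features, dtype=torch.float16)

# Both backends consume these shapes directly; only the kernel differs.
assert qweight.shape == (in_features, out_features // pack_factor)
assert qzeros.shape == (in_features // group_size, out_features // pack_factor)
assert scales.shape == (in_features // group_size, out_features)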


@@ -298,6 +298,7 @@ class GPTQWeightsLoader(WeightsLoader):
         self._get_gptq_params(weights)
 
         use_exllama = True
+        desc_act = self.desc_act
         if self.bits != 4:
             use_exllama = False
@@ -321,7 +322,7 @@ class GPTQWeightsLoader(WeightsLoader):
         if g_idx is not None:
             if (
                 not torch.equal(
-                    g_idx.cpu(),
+                    (g_idx - g_idx[0]).cpu(),
                     torch.tensor(
                         [i // self.groupsize for i in range(g_idx.shape[0])],
                         dtype=torch.int32,
@@ -332,6 +333,7 @@ class GPTQWeightsLoader(WeightsLoader):
                 # Exllama implementation does not support row tensor parallelism with act-order, as
                 # it would require to reorder input activations that are split unto several GPUs
                 use_exllama = False
+                desc_act = True
 
         from text_generation_server.layers.gptq import (
             CAN_EXLLAMA,
@@ -350,16 +352,16 @@ class GPTQWeightsLoader(WeightsLoader):
             else:
                 log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
 
-        if use_exllama and self.groupsize != -1:
+        if not desc_act and self.groupsize != -1:
             qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
             scales = weights.get_sharded(f"{prefix}.scales", dim=0)
+            if g_idx is not None:
+                # qzeros, scales sharded, and g_idx must be adjusted accordingly
+                g_idx = g_idx - g_idx[0]
         else:
             qzeros = weights.get_tensor(f"{prefix}.qzeros")
             scales = weights.get_tensor(f"{prefix}.scales")
 
-        if use_exllama and g_idx is not None:
-            g_idx = g_idx - g_idx[0]
-
         if self.quantize == "gptq" and self.quant_method == "awq":
             log_once(
                 logger.info, "Converting AWQ model to Exllama/GPTQ packing format."


@@ -7,6 +7,10 @@ from torch.cuda.amp import custom_fwd
 import triton
 import triton.language as tl
 
 from . import custom_autotune
+from text_generation_server.utils.import_utils import SYSTEM
+
+if SYSTEM == "ipex":
+    import intel_extension_for_pytorch as ipex
 
 # code based https://github.com/fpgaminer/GPTQ-triton
@@ -264,6 +268,21 @@ class QuantLinear(nn.Module):
         self.outfeatures = qweight.shape[1]
         self.infeatures = qweight.shape[0] * 32 // bits
+        if SYSTEM == "ipex" and bits == 4:
+            self.woq_linear = (
+                ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
+                    self.qweight,
+                    self.scales,
+                    self.qzeros,
+                    self.infeatures,
+                    self.outfeatures,
+                    bias=self.bias,
+                    group_size=self.groupsize,
+                    g_idx=g_idx,
+                    quant_method=ipex.llm.quantization.QuantMethod.GPTQ_GEMM,
+                    dtype=ipex.llm.quantization.QuantDtype.INT4,
+                )
+            )
 
     @classmethod
     def new(cls, bits, groupsize, infeatures, outfeatures, bias):
@@ -346,14 +365,17 @@ class QuantLinear(nn.Module):
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.outfeatures,)
-        out = QuantLinearFunction.apply(
-            x.reshape(-1, x.shape[-1]),
-            self.qweight,
-            self.scales,
-            self.qzeros,
-            self.g_idx,
-            self.bits,
-            self.maxq,
-        )
+        if SYSTEM == "ipex" and self.bits == 4:
+            out = self.woq_linear(x.reshape(-1, x.shape[-1]))
+        else:
+            out = QuantLinearFunction.apply(
+                x.reshape(-1, x.shape[-1]),
+                self.qweight,
+                self.scales,
+                self.qzeros,
+                self.g_idx,
+                self.bits,
+                self.maxq,
+            )
         out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
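
Note (illustrative, not part of the commit): QuantLinear recovers the logical layer sizes straight from the GPTQ packing (each int32 along dim 0 of qweight holds 32 // bits quantized inputs), then routes the unmodified packed tensors either to the IPEX INT4 weight-only kernel (SYSTEM == "ipex" and bits == 4) or to the Triton QuantLinearFunction. A small shape sketch of that packing, with made-up layer sizes:

import torch

bits, groupsize = 4, 128
infeatures, outfeatures = 4096, 4096

# Assumed standard GPTQ layout for the tensors handed to QuantLinear:
qweight = torch.zeros(infeatures * bits // 32, outfeatures, dtype=torch.int32)
qzeros = torch.zeros(infeatures // groupsize, outfeatures * bits // 32, dtype=torch.int32)
scales = torch.zeros(infeatures // groupsize, outfeatures, dtype=torch.float16)
g_idx = torch.tensor([i // groupsize for i in range(infeatures)], dtype=torch.int32)

# Mirrors the constructor above: logical sizes recovered from the packed weight.
assert qweight.shape[1] == outfeatures
assert qweight.shape[0] * 32 // bits == infeatures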


@@ -885,6 +885,8 @@ class FlashCausalLM(Model):
                 device = torch.device("cpu")
                 # Float16 doesn't exist on target.
                 dtype = torch.bfloat16 if dtype is None else dtype
+                if quantize in ["awq", "exl2", "gptq", "marlin"]:
+                    dtype = torch.bfloat16
                 init_cpu_threads_env(rank_id=rank, world_size=world_size)
         else:
             raise NotImplementedError(f"{model_class} is only available on GPU")
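
Note (illustrative, not part of the commit): the two added lines pin quantized checkpoints to bfloat16 on the CPU path even when a caller explicitly requested float16, since float16 is not supported on that target (per the existing comment). A condensed restatement of the selection logic; the helper name is purely illustrative:

from typing import Optional
import torch

def resolve_cpu_dtype(requested: Optional[torch.dtype], quantize: Optional[str]) -> torch.dtype:
    # Default to bfloat16 when nothing was requested (float16 is unavailable on CPU) ...
    dtype = torch.bfloat16 if requested is None else requested
    # ... and force it for quantized checkpoints, matching the added branch above.
    if quantize in ["awq", "exl2", "gptq", "marlin"]:
        dtype = torch.bfloat16
    return dtype

print(resolve_cpu_dtype(torch.float16, "gptq"))  # torch.bfloat16
print(resolve_cpu_dtype(None, None))             # torch.bfloat16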