From 0b02d45a053fc61c7026b0a39075127553f87b8b Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Wed, 21 Aug 2024 22:47:34 -0700
Subject: [PATCH] add gptq and awq int4 support in intel platform

Signed-off-by: Wang, Yi A
---
 Dockerfile_intel                                  | 25 +++++++-----
 .../layers/awq/quantize/qmodule.py                | 30 ++++++++++++--
 .../layers/gptq/__init__.py                       | 12 +++---
 .../layers/gptq/quant_linear.py                   | 40 ++++++++++++++-----
 .../models/flash_causal_lm.py                     |  2 +
 5 files changed, 82 insertions(+), 27 deletions(-)

diff --git a/Dockerfile_intel b/Dockerfile_intel
index 12480c70..088f5307 100644
--- a/Dockerfile_intel
+++ b/Dockerfile_intel
@@ -103,12 +103,21 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
     curl \
     ca-certificates \
     make \
-    g++ \
+    g++-12 \
+    gcc-12 \
     git \
     wget \
     cmake \
     libnuma-dev

+RUN update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-12 12
+RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 12
+RUN update-alternatives --install /usr/bin/cc cc /usr/bin/gcc 30
+RUN update-alternatives --set cc /usr/bin/gcc
+
+RUN update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++ 30
+RUN update-alternatives --set c++ /usr/bin/g++
+
 ENV HUGGINGFACE_HUB_CACHE=/data \
     HF_HUB_ENABLE_HF_TRANSFER=1 \
     PORT=80
@@ -133,21 +142,19 @@ RUN chmod +x ~/mambaforge.sh && \

 RUN conda install -c conda-forge gperftools mkl

-RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
-RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
-RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.5.0.dev20240815%2Bcpu-cp310-cp310-linux_x86_64.whl
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.20.0.dev20240815%2Bcpu-cp310-cp310-linux_x86_64.whl
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240815%2Bcpu-cp310-cp310-linux_x86_64.whl
 RUN pip install triton numa

 WORKDIR /usr/src

-RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout eda7a7c42df6f9a64e0de9c2b69304ee02f2c32a
-
-RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout ccl_torch_dev_0131
-
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout f86e93e4890dc2c989024d148d415c9aa8a1649f
+RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout v2.4.0+cpu+rc0
 RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install
-
 RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install .
+
 ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so
 ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
 ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch

diff --git a/server/text_generation_server/layers/awq/quantize/qmodule.py b/server/text_generation_server/layers/awq/quantize/qmodule.py
index 391371a5..d59b1f18 100644
--- a/server/text_generation_server/layers/awq/quantize/qmodule.py
+++ b/server/text_generation_server/layers/awq/quantize/qmodule.py
@@ -3,7 +3,12 @@ from typing import Optional

 import torch
 import torch.nn as nn
-import awq_inference_engine  # with CUDA kernels
+from text_generation_server.utils.import_utils import SYSTEM
+
+if SYSTEM == "ipex":
+    import intel_extension_for_pytorch as ipex
+else:
+    import awq_inference_engine  # with CUDA kernels


 # class ScaledActivation(nn.Module):
@@ -38,12 +43,29 @@ class WQLinear(nn.Module):
         self.qzeros = qzeros
         self.scales = scales
         self.bias = bias
+        if SYSTEM == "ipex":
+            self.woq_linear = (
+                ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
+                    self.qweight,
+                    self.scales,
+                    self.qzeros,
+                    self.in_features,
+                    self.out_features,
+                    bias=self.bias,
+                    group_size=self.group_size,
+                    quant_method=ipex.llm.quantization.QuantMethod.AWQ_GEMM,
+                    dtype=ipex.llm.quantization.QuantDtype.INT4,
+                )
+            )

     @torch.no_grad()
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.out_features,)
-        out = awq_inference_engine.gemm_forward_cuda(
-            x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
-        )
+        if SYSTEM == "ipex":
+            out = self.woq_linear(x.reshape(-1, x.shape[-1]))
+        else:
+            out = awq_inference_engine.gemm_forward_cuda(
+                x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
+            )
         out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py
index 505caa59..13c07223 100644
--- a/server/text_generation_server/layers/gptq/__init__.py
+++ b/server/text_generation_server/layers/gptq/__init__.py
@@ -298,6 +298,7 @@ class GPTQWeightsLoader(WeightsLoader):
         self._get_gptq_params(weights)

         use_exllama = True
+        desc_act = self.desc_act
         if self.bits != 4:
             use_exllama = False

@@ -321,7 +322,7 @@
         if g_idx is not None:
             if (
                 not torch.equal(
-                    g_idx.cpu(),
+                    (g_idx - g_idx[0]).cpu(),
                     torch.tensor(
                         [i // self.groupsize for i in range(g_idx.shape[0])],
                         dtype=torch.int32,
@@ -332,6 +333,7 @@
                 # Exllama implementation does not support row tensor parallelism with act-order, as
                 # it would require to reorder input activations that are split unto several GPUs
                 use_exllama = False
+                desc_act = True

         from text_generation_server.layers.gptq import (
             CAN_EXLLAMA,
@@ -350,16 +352,16 @@
             else:
                 log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")

-        if use_exllama and self.groupsize != -1:
+        if not desc_act and self.groupsize != -1:
             qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
             scales = weights.get_sharded(f"{prefix}.scales", dim=0)
+            if g_idx is not None:
+                # qzeros, scales sharded, and g_idx must be adjusted accordingly
+                g_idx = g_idx - g_idx[0]
         else:
             qzeros = weights.get_tensor(f"{prefix}.qzeros")
             scales = weights.get_tensor(f"{prefix}.scales")

-        if use_exllama and g_idx is not None:
-            g_idx = g_idx - g_idx[0]
-
         if self.quantize == "gptq" and self.quant_method == "awq":
             log_once(
                 logger.info,
                 "Converting AWQ model to Exllama/GPTQ packing format."
diff --git a/server/text_generation_server/layers/gptq/quant_linear.py b/server/text_generation_server/layers/gptq/quant_linear.py
index 736c357b..9dc7615e 100644
--- a/server/text_generation_server/layers/gptq/quant_linear.py
+++ b/server/text_generation_server/layers/gptq/quant_linear.py
@@ -7,6 +7,10 @@ from torch.cuda.amp import custom_fwd
 import triton
 import triton.language as tl
 from . import custom_autotune
+from text_generation_server.utils.import_utils import SYSTEM
+
+if SYSTEM == "ipex":
+    import intel_extension_for_pytorch as ipex


 # code based https://github.com/fpgaminer/GPTQ-triton
@@ -264,6 +268,21 @@
         self.outfeatures = qweight.shape[1]
         self.infeatures = qweight.shape[0] * 32 // bits

+        if SYSTEM == "ipex" and bits == 4:
+            self.woq_linear = (
+                ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
+                    self.qweight,
+                    self.scales,
+                    self.qzeros,
+                    self.infeatures,
+                    self.outfeatures,
+                    bias=self.bias,
+                    group_size=self.groupsize,
+                    g_idx=g_idx,
+                    quant_method=ipex.llm.quantization.QuantMethod.GPTQ_GEMM,
+                    dtype=ipex.llm.quantization.QuantDtype.INT4,
+                )
+            )

     @classmethod
     def new(cls, bits, groupsize, infeatures, outfeatures, bias):
@@ -346,14 +365,17 @@

     def forward(self, x):
         out_shape = x.shape[:-1] + (self.outfeatures,)
-        out = QuantLinearFunction.apply(
-            x.reshape(-1, x.shape[-1]),
-            self.qweight,
-            self.scales,
-            self.qzeros,
-            self.g_idx,
-            self.bits,
-            self.maxq,
-        )
+        if SYSTEM == "ipex" and self.bits == 4:
+            out = self.woq_linear(x.reshape(-1, x.shape[-1]))
+        else:
+            out = QuantLinearFunction.apply(
+                x.reshape(-1, x.shape[-1]),
+                self.qweight,
+                self.scales,
+                self.qzeros,
+                self.g_idx,
+                self.bits,
+                self.maxq,
+            )
         out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 5e2fd20a..19b9d554 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -885,6 +885,8 @@ class FlashCausalLM(Model):
             device = torch.device("cpu")
             # Float16 doesn't exist on target.
             dtype = torch.bfloat16 if dtype is None else dtype
+            if quantize in ["awq", "exl2", "gptq", "marlin"]:
+                dtype = torch.bfloat16
             init_cpu_threads_env(rank_id=rank, world_size=world_size)
         else:
             raise NotImplementedError(f"{model_class} is only available on GPU")