From a5ecd6e586d94ecac46a814e23c7fa7cfd518c21 Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Tue, 19 Nov 2024 00:16:55 +0800
Subject: [PATCH] add ipex moe implementation to support Mixtral and PhiMoe (#2707)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* add ipex moe implementation to support Mixtral and PhiMoe

Signed-off-by: Wang, Yi A

* update to ipex xpu 2.5

Signed-off-by: Wang, Yi A

* torch has xpu support in 2.5

Signed-off-by: Wang, Yi A

* fix oneapi basekit version

Signed-off-by: Wang, Yi A

* Apply suggestions from code review

Co-authored-by: Daniël de Kok

---------

Signed-off-by: Wang, Yi A
Co-authored-by: Daniël de Kok
---
 Dockerfile_intel                                   | 22 ++++++++----
 .../layers/moe/__init__.py                         | 19 +++++++++-
 .../layers/moe/unquantized.py                      | 16 +++++++++
 .../text_generation_server/models/__init__.py      |  4 ++-
 .../custom_modeling/flash_dbrx_modeling.py         | 36 ++++++++++++++-----
 5 files changed, 80 insertions(+), 17 deletions(-)

diff --git a/Dockerfile_intel b/Dockerfile_intel
index c3555eab..ea38b081 100644
--- a/Dockerfile_intel
+++ b/Dockerfile_intel
@@ -83,7 +83,11 @@ RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dea
 RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
 | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list
 
-RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y intel-basekit xpu-smi cmake ninja-build pciutils
+RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/intel-for-pytorch-gpu-dev all main" > /tmp/intel-for-pytorch-gpu-dev.list
+
+RUN mv /tmp/intel-for-pytorch-gpu-dev.list /etc/apt/sources.list.d
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y intel-basekit=2024.2.1-98 xpu-smi cmake ninja-build pciutils intel-pti-dev-0.9
 
 # Text Generation Inference base env
 ENV HF_HOME=/data \
@@ -91,8 +95,14 @@ ENV HF_HOME=/data \
     PORT=80
 
 
+
 WORKDIR /usr/src
-RUN pip install torch==2.3.1+cxx11.abi torchvision==0.18.1+cxx11.abi torchaudio==2.3.1+cxx11.abi intel-extension-for-pytorch==2.3.110+xpu oneccl_bind_pt==2.3.100+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --no-cache-dir
+RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torch-2.5.0a0%2Bgite84e33f-cp311-cp311-linux_x86_64.whl --no-cache-dir
+RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torchaudio-2.5.0a0%2B56bc006-cp311-cp311-linux_x86_64.whl --no-cache-dir
+RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torchvision-0.20.0a0%2B8e8a208-cp311-cp311-linux_x86_64.whl --no-cache-dir
+RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.5.10%2Bgit9d489a8-cp311-cp311-linux_x86_64.whl --no-cache-dir
+RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/oneccl_bind_pt-2.5.0%2Bxpu-cp311-cp311-linux_x86_64.whl --no-cache-dir
+
 RUN pip install triton-xpu==3.0.0b2 --no-cache-dir
 
 # Install server
@@ -108,13 +118,13 @@ ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest
 ENV I_MPI_ROOT=/opt/intel/oneapi/mpi/latest
 ENV FI_PROVIDER_PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib/prov:/usr/lib/x86_64-linux-gnu/libfabric
 ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mkl/latest/lib/:/opt/intel/oneapi/compiler/latest/lib
-ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:/opt/conda/lib
+ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:/opt/intel/oneapi/pti/0.9/lib:/opt/conda/lib
 ENV PATH=/opt/conda/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
 ENV CCL_ZE_IPC_EXCHANGE=sockets
 ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
 ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include
-ENV TORCH_LLM_ALLREDUCE=1
-ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
+#ENV TORCH_LLM_ALLREDUCE=1
+#ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
 
 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@@ -187,7 +197,7 @@ RUN pip install triton py-libnuma
 
 WORKDIR /usr/src
 
-RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout f86e93e4890dc2c989024d148d415c9aa8a1649f
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout 2e1c98f74ec1b35ad8dd1ebe7dd4b25470f2fd41
 RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout v2.4.0+cpu+rc0
 
 RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install
diff --git a/server/text_generation_server/layers/moe/__init__.py b/server/text_generation_server/layers/moe/__init__.py
index 558d9ed9..a5ae7ff4 100644
--- a/server/text_generation_server/layers/moe/__init__.py
+++ b/server/text_generation_server/layers/moe/__init__.py
@@ -27,7 +27,9 @@ from text_generation_server.utils.weights import (
 if SYSTEM == "rocm":
     from .fused_moe_rocm import grouped_topk
     from vllm.model_executor.layers.fused_moe import fused_topk
-elif SYSTEM != "ipex":
+elif SYSTEM == "ipex":
+    from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
+else:
     from moe_kernels.fused_moe import fused_topk, grouped_topk
 
 
@@ -140,6 +142,10 @@ class DenseMoELayer(nn.Module):
             )
             for i in range(self.n_experts)
         ]
+        if SYSTEM == "ipex":
+            self.ipex_fused_moe = GatedMLPMOE(
+                W13=self.gate_proj, W2=self.down_proj, W3=self.up_proj, use_prepack=True
+            )
 
         self.process_group = weights.process_group
 
@@ -152,6 +158,17 @@ class DenseMoELayer(nn.Module):
         input_shape = x.shape
         x = x.view(-1, input_shape[-1])
 
+        if SYSTEM == "ipex":
+            return self.ipex_fused_moe(
+                hidden_states=x,
+                router_logits=gating_output,
+                top_k=self.topk,
+                renormalize=self.renormalize,
+                use_grouped_topk=self.n_expert_group is not None,
+                num_expert_group=self.n_expert_group,
+                topk_group=self.topk_group,
+            )
+
         if self.n_expert_group is not None and self.topk_group is not None:
             topk_weights, topk_ids = grouped_topk(
                 x,
diff --git a/server/text_generation_server/layers/moe/unquantized.py b/server/text_generation_server/layers/moe/unquantized.py
index d9d62c0e..3d6a0b99 100644
--- a/server/text_generation_server/layers/moe/unquantized.py
+++ b/server/text_generation_server/layers/moe/unquantized.py
@@ -10,6 +10,8 @@ if SYSTEM == "rocm":
     from vllm.model_executor.layers.fused_moe import fused_moe
 elif SYSTEM != "ipex":
     from moe_kernels.fused_moe import fused_moe
+else:
+    from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 
 
 class UnquantizedSparseMoELayer(nn.Module):
@@ -52,6 +54,10 @@ class UnquantizedSparseMoELayer(nn.Module):
             name=down_proj_name,
             weights=weights,
         )
+        if SYSTEM == "ipex":
+            self.ipex_fused_moe = GatedMLPMOE(
+                W13=self.gate_up_proj, W2=self.down_proj, use_prepack=True
+            )
 
     def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
         if SYSTEM == "rocm":
@@ -64,6 +70,16 @@ class UnquantizedSparseMoELayer(nn.Module):
                 renormalize=self.renormalize,
                 inplace=True,
             )
+        elif SYSTEM == "ipex":
+            return self.ipex_fused_moe(
+                hidden_states=x,
+                router_logits=gating_output,
+                top_k=self.topk,
+                renormalize=self.renormalize,
+                use_grouped_topk=self.n_expert_group is not None,
+                num_expert_group=self.n_expert_group,
+                topk_group=self.topk_group,
+            )
 
         return fused_moe(
             x,
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index c6e406c9..89164577 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -390,7 +390,9 @@ def get_model(
 
     if dtype is None:
         if quantize in ["awq", "exl2", "gptq", "marlin"]:
-            if SYSTEM == "ipex" and not hasattr(torch, "xpu"):
+            if SYSTEM == "ipex" and not (
+                hasattr(torch, "xpu") and torch.xpu.is_available()
+            ):
                 dtype = torch.bfloat16
             else:
                 # These quantizers only work with float16 params.
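
Note on the `get_model` hunk above: since torch 2.5 ships a `torch.xpu` namespace even on builds without an XPU device (see "torch has xpu support in 2.5" in the commit message), `hasattr(torch, "xpu")` alone no longer implies a device is present, so the patch also checks `torch.xpu.is_available()`. A minimal sketch of that dtype selection follows; the helper name `default_dtype_for_quantizer` is hypothetical and introduced purely for illustration.

```python
import torch


def default_dtype_for_quantizer(system: str) -> torch.dtype:
    # On IPEX without a usable XPU device (CPU-only path), these quantizers
    # default to bfloat16; otherwise they require float16 parameters.
    if system == "ipex" and not (hasattr(torch, "xpu") and torch.xpu.is_available()):
        return torch.bfloat16
    return torch.float16
```
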
diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
index 57118362..b8041671 100644
--- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -27,6 +27,8 @@ if SYSTEM == "rocm":
     from vllm.model_executor.layers.fused_moe import fused_moe
 elif SYSTEM != "ipex":
     from moe_kernels.fused_moe import fused_moe
+else:
+    from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 
 from text_generation_server.layers.attention import (
     paged_attention,
@@ -490,19 +492,35 @@ class BlockSparseMoE(nn.Module):
         )
 
         self.process_group = weights.process_group
+        if SYSTEM == "ipex":
+            self.ipex_fused_moe = GatedMLPMOE(
+                W13=self.wv1, W2=self.w2, use_prepack=True
+            )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(x)
-        out = fused_moe(
-            x,
-            self.wv1,
-            self.w2,
-            router_logits,
-            self.top_k,
-            renormalize=self.moe_normalize_expert_weights,
-            inplace=True,
-        )
+
+        if SYSTEM == "ipex":
+            out = self.ipex_fused_moe(
+                hidden_states=x,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                renormalize=self.moe_normalize_expert_weights,
+                use_grouped_topk=False,
+                num_expert_group=None,
+                topk_group=None,
+            )
+        else:
+            out = fused_moe(
+                x,
+                self.wv1,
+                self.w2,
+                router_logits,
+                self.top_k,
+                renormalize=self.moe_normalize_expert_weights,
+                inplace=True,
+            )
 
         # Reduce sum
         if self.process_group.size() > 1:
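
For readers who want the shape of the change without the per-file context: the patch applies the same pattern to each sparse MoE layer, building IPEX's `GatedMLPMOE` once from the fused expert weights at load time and calling it in `forward` instead of the generic `fused_moe` kernel. Below is a minimal, self-contained sketch of that pattern; the `SimpleSparseMoE` class, the weight shapes, and the `renormalize=True` default are illustrative assumptions, while the `GatedMLPMOE` keyword arguments mirror the calls introduced in the diffs above.

```python
import torch
import torch.nn as nn

# Assumed import path for the SYSTEM flag, as used elsewhere in TGI.
from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM == "ipex":
    from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
else:
    from moe_kernels.fused_moe import fused_moe


class SimpleSparseMoE(nn.Module):
    def __init__(self, gate_up_proj: torch.Tensor, down_proj: torch.Tensor, top_k: int):
        super().__init__()
        # Illustrative shapes:
        # gate_up_proj: (n_experts, 2 * intermediate_size, hidden_size)
        # down_proj:    (n_experts, hidden_size, intermediate_size)
        self.gate_up_proj = gate_up_proj
        self.down_proj = down_proj
        self.top_k = top_k
        if SYSTEM == "ipex":
            # Prepack the expert weights once; GatedMLPMOE then fuses top-k routing,
            # the expert MLPs, and the weighted reduction into a single call.
            self.ipex_fused_moe = GatedMLPMOE(
                W13=self.gate_up_proj, W2=self.down_proj, use_prepack=True
            )

    def forward(self, x: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor:
        if SYSTEM == "ipex":
            return self.ipex_fused_moe(
                hidden_states=x,
                router_logits=router_logits,
                top_k=self.top_k,
                renormalize=True,
                use_grouped_topk=False,
                num_expert_group=None,
                topk_group=None,
            )
        # Fallback path on non-IPEX systems, as in the original code.
        return fused_moe(
            x,
            self.gate_up_proj,
            self.down_proj,
            router_logits,
            self.top_k,
            renormalize=True,
            inplace=True,
        )
```
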