From d9a8bbc18346872f33a4f79a68ad2a3f5c103ab5 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Tue, 29 Oct 2024 19:18:53 -0700
Subject: [PATCH] add ipex moe implementation to support Mixtral and PhiMoe

Signed-off-by: Wang, Yi A
---
 Dockerfile_intel                            |  2 +-
 .../layers/moe/__init__.py                  | 17 +++++++++
 .../layers/moe/unquantized.py               | 16 +++++++++
 .../custom_modeling/flash_dbrx_modeling.py  | 36 ++++++++++++++-----
 4 files changed, 61 insertions(+), 10 deletions(-)

diff --git a/Dockerfile_intel b/Dockerfile_intel
index 96f24248..d04d6329 100644
--- a/Dockerfile_intel
+++ b/Dockerfile_intel
@@ -186,7 +186,7 @@ RUN pip install triton py-libnuma
 
 WORKDIR /usr/src
 
-RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout f86e93e4890dc2c989024d148d415c9aa8a1649f
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout 2e1c98f74ec1b35ad8dd1ebe7dd4b25470f2fd41
 RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout v2.4.0+cpu+rc0
 
 RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install
diff --git a/server/text_generation_server/layers/moe/__init__.py b/server/text_generation_server/layers/moe/__init__.py
index 558d9ed9..f528fcd0 100644
--- a/server/text_generation_server/layers/moe/__init__.py
+++ b/server/text_generation_server/layers/moe/__init__.py
@@ -29,6 +29,8 @@ if SYSTEM == "rocm":
     from vllm.model_executor.layers.fused_moe import fused_topk
 elif SYSTEM != "ipex":
     from moe_kernels.fused_moe import fused_topk, grouped_topk
+else:
+    from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 
 
 # NOTE: we are using a protocol here, because multiple inheritance is not nice.
@@ -140,6 +142,10 @@ class DenseMoELayer(nn.Module):
             )
             for i in range(self.n_experts)
         ]
+        if SYSTEM == "ipex":
+            self.ipex_fused_moe = GatedMLPMOE(
+                W13=self.gate_proj, W2=self.down_proj, W3=self.up_proj, use_prepack=True
+            )
 
         self.process_group = weights.process_group
 
@@ -152,6 +158,17 @@ class DenseMoELayer(nn.Module):
         input_shape = x.shape
         x = x.view(-1, input_shape[-1])
 
+        if SYSTEM == "ipex":
+            return self.ipex_fused_moe(
+                hidden_states=x,
+                router_logits=gating_output,
+                top_k=self.topk,
+                renormalize=self.renormalize,
+                use_grouped_topk=self.n_expert_group is not None,
+                num_expert_group=self.n_expert_group,
+                topk_group=self.topk_group,
+            )
+
         if self.n_expert_group is not None and self.topk_group is not None:
             topk_weights, topk_ids = grouped_topk(
                 x,
diff --git a/server/text_generation_server/layers/moe/unquantized.py b/server/text_generation_server/layers/moe/unquantized.py
index d9d62c0e..3d6a0b99 100644
--- a/server/text_generation_server/layers/moe/unquantized.py
+++ b/server/text_generation_server/layers/moe/unquantized.py
@@ -10,6 +10,8 @@ if SYSTEM == "rocm":
     from vllm.model_executor.layers.fused_moe import fused_moe
 elif SYSTEM != "ipex":
     from moe_kernels.fused_moe import fused_moe
+else:
+    from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 
 
 class UnquantizedSparseMoELayer(nn.Module):
@@ -52,6 +54,10 @@ class UnquantizedSparseMoELayer(nn.Module):
             name=down_proj_name,
             weights=weights,
         )
+        if SYSTEM == "ipex":
+            self.ipex_fused_moe = GatedMLPMOE(
+                W13=self.gate_up_proj, W2=self.down_proj, use_prepack=True
+            )
 
     def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
         if SYSTEM == "rocm":
@@ -64,6 +70,16 @@ class UnquantizedSparseMoELayer(nn.Module):
                 renormalize=self.renormalize,
                 inplace=True,
             )
+        elif SYSTEM == "ipex":
+            return self.ipex_fused_moe(
+                hidden_states=x,
+                router_logits=gating_output,
+                top_k=self.topk,
+                renormalize=self.renormalize,
+                use_grouped_topk=self.n_expert_group is not None,
+                num_expert_group=self.n_expert_group,
+                topk_group=self.topk_group,
+            )
 
         return fused_moe(
             x,
diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
index f70bff4f..5cc486db 100644
--- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -25,6 +25,8 @@ from text_generation_server.utils.import_utils import SYSTEM
 
 if SYSTEM != "ipex":
     from vllm.model_executor.layers.fused_moe import fused_moe
+else:
+    from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 
 from text_generation_server.layers.attention import (
     paged_attention,
@@ -488,19 +490,35 @@ class BlockSparseMoE(nn.Module):
         )
 
         self.process_group = weights.process_group
+        if SYSTEM == "ipex":
+            self.ipex_fused_moe = GatedMLPMOE(
+                W13=self.wv1, W2=self.w2, use_prepack=True
+            )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(x)
-        out = fused_moe(
-            x,
-            self.wv1,
-            self.w2,
-            router_logits,
-            self.top_k,
-            renormalize=self.moe_normalize_expert_weights,
-            inplace=True,
-        )
+
+        if SYSTEM == "ipex":
+            out = self.ipex_fused_moe(
+                hidden_states=x,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                renormalize=self.moe_normalize_expert_weights,
+                use_grouped_topk=False,
+                num_expert_group=None,
+                topk_group=None,
+            )
+        else:
+            out = fused_moe(
+                x,
+                self.wv1,
+                self.w2,
+                router_logits,
+                self.top_k,
+                renormalize=self.moe_normalize_expert_weights,
+                inplace=True,
+            )
 
         # Reduce sum
         if self.process_group.size() > 1:
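
Note for reviewers: the patch applies the same GatedMLPMOE dispatch pattern at three call sites. The sketch below condenses that pattern into one self-contained module for reference. Only the GatedMLPMOE import path, its constructor arguments (W13, W2, optional W3, use_prepack), and the forward keywords are taken from the diff above; the ToyMoE wrapper, its dimensions, the assumed weight layouts, and renormalize=True are hypothetical illustrations, not part of the patch.

# Minimal sketch (assumptions flagged inline) of the IPEX MoE dispatch
# added by this patch. NOT from the patch: the toy module name, the weight
# shapes [n_experts, 2 * inter, hidden] for the fused gate/up projection
# and [n_experts, hidden, inter] for the down projection.
import torch
import torch.nn as nn

from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM == "ipex":
    from intel_extension_for_pytorch.llm.modules import GatedMLPMOE


class ToyMoE(nn.Module):
    def __init__(self, n_experts: int, hidden: int, inter: int, top_k: int):
        super().__init__()
        self.top_k = top_k
        self.gate_up_proj = nn.Parameter(torch.empty(n_experts, 2 * inter, hidden))
        self.down_proj = nn.Parameter(torch.empty(n_experts, hidden, inter))
        if SYSTEM == "ipex":
            # As in the patch: build the prepacked fused-MoE module once at
            # init time, handing it the per-expert weights directly.
            self.ipex_fused_moe = GatedMLPMOE(
                W13=self.gate_up_proj, W2=self.down_proj, use_prepack=True
            )

    def forward(self, x: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor:
        if SYSTEM == "ipex":
            # Routing (top-k selection and weight renormalization) happens
            # inside the IPEX kernel; plain top-k is requested by disabling
            # grouped top-k, as the DBRX call site does.
            return self.ipex_fused_moe(
                hidden_states=x,
                router_logits=router_logits,
                top_k=self.top_k,
                renormalize=True,
                use_grouped_topk=False,
                num_expert_group=None,
                topk_group=None,
            )
        # Non-IPEX backends keep the existing fused_moe/grouped_topk paths
        # shown in the diff; they are elided from this sketch.
        raise NotImplementedError

One design choice visible in all three call sites: the routing parameters (top_k, renormalize, and the grouped top-k settings) are passed per forward call rather than at construction, so a single prepacked GatedMLPMOE instance can serve both plain top-k models (Mixtral, PhiMoe, DBRX) and grouped top-k models, which pass num_expert_group and topk_group through from the layer config.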