add ipex moe implementation to support Mixtral and PhiMoe

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi A 2024-10-29 19:18:53 -07:00
parent 98330df65e
commit d9a8bbc183
4 changed files with 61 additions and 10 deletions

View File

@ -186,7 +186,7 @@ RUN pip install triton py-libnuma
WORKDIR /usr/src
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout f86e93e4890dc2c989024d148d415c9aa8a1649f
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout 2e1c98f74ec1b35ad8dd1ebe7dd4b25470f2fd41
RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout v2.4.0+cpu+rc0
RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install

View File

@ -29,6 +29,8 @@ if SYSTEM == "rocm":
from vllm.model_executor.layers.fused_moe import fused_topk
elif SYSTEM != "ipex":
from moe_kernels.fused_moe import fused_topk, grouped_topk
else:
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
# NOTE: we are using a protocol here, because multiple inherance is not nice.
@ -140,6 +142,10 @@ class DenseMoELayer(nn.Module):
)
for i in range(self.n_experts)
]
if SYSTEM == "ipex":
self.ipex_fused_moe = GatedMLPMOE(
W13=self.gate_proj, W2=self.down_proj, W3=self.up_proj, use_prepack=True
)
self.process_group = weights.process_group
@ -152,6 +158,17 @@ class DenseMoELayer(nn.Module):
input_shape = x.shape
x = x.view(-1, input_shape[-1])
if SYSTEM == "ipex":
return self.ipex_fused_moe(
hidden_states=x,
router_logits=gating_output,
top_k=self.topk,
renormalize=self.renormalize,
use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
)
if self.n_expert_group is not None and self.topk_group is not None:
topk_weights, topk_ids = grouped_topk(
x,

View File

@ -10,6 +10,8 @@ if SYSTEM == "rocm":
from vllm.model_executor.layers.fused_moe import fused_moe
elif SYSTEM != "ipex":
from moe_kernels.fused_moe import fused_moe
else:
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
class UnquantizedSparseMoELayer(nn.Module):
@ -52,6 +54,10 @@ class UnquantizedSparseMoELayer(nn.Module):
name=down_proj_name,
weights=weights,
)
if SYSTEM == "ipex":
self.ipex_fused_moe = GatedMLPMOE(
W13=self.gate_up_proj, W2=self.down_proj, use_prepack=True
)
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
if SYSTEM == "rocm":
@ -64,6 +70,16 @@ class UnquantizedSparseMoELayer(nn.Module):
renormalize=self.renormalize,
inplace=True,
)
elif SYSTEM == "ipex":
return self.ipex_fused_moe(
hidden_states=x,
router_logits=gating_output,
top_k=self.topk,
renormalize=self.renormalize,
use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
)
return fused_moe(
x,

View File

@ -25,6 +25,8 @@ from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM != "ipex":
from vllm.model_executor.layers.fused_moe import fused_moe
else:
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
from text_generation_server.layers.attention import (
paged_attention,
@ -488,19 +490,35 @@ class BlockSparseMoE(nn.Module):
)
self.process_group = weights.process_group
if SYSTEM == "ipex":
self.ipex_fused_moe = GatedMLPMOE(
W13=self.wv1, W2=self.w2, use_prepack=True
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(x)
out = fused_moe(
x,
self.wv1,
self.w2,
router_logits,
self.top_k,
renormalize=self.moe_normalize_expert_weights,
inplace=True,
)
if SYSTEM == "ipex":
out = self.ipex_fused_moe(
hidden_states=x,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.moe_normalize_expert_weights,
use_grouped_topk=False,
num_expert_group=None,
topk_group=None,
)
else:
out = fused_moe(
x,
self.wv1,
self.w2,
router_logits,
self.top_k,
renormalize=self.moe_normalize_expert_weights,
inplace=True,
)
# Reduce sum
if self.process_group.size() > 1: