add ipex moe implementation to support Mixtral and PhiMoe
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
98330df65e
commit
d9a8bbc183
|
@ -186,7 +186,7 @@ RUN pip install triton py-libnuma
|
|||
|
||||
WORKDIR /usr/src
|
||||
|
||||
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout f86e93e4890dc2c989024d148d415c9aa8a1649f
|
||||
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout 2e1c98f74ec1b35ad8dd1ebe7dd4b25470f2fd41
|
||||
RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout v2.4.0+cpu+rc0
|
||||
|
||||
RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install
|
||||
|
|
|
@ -29,6 +29,8 @@ if SYSTEM == "rocm":
|
|||
from vllm.model_executor.layers.fused_moe import fused_topk
|
||||
elif SYSTEM != "ipex":
|
||||
from moe_kernels.fused_moe import fused_topk, grouped_topk
|
||||
else:
|
||||
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
|
||||
|
||||
|
||||
# NOTE: we are using a protocol here, because multiple inherance is not nice.
|
||||
|
@ -140,6 +142,10 @@ class DenseMoELayer(nn.Module):
|
|||
)
|
||||
for i in range(self.n_experts)
|
||||
]
|
||||
if SYSTEM == "ipex":
|
||||
self.ipex_fused_moe = GatedMLPMOE(
|
||||
W13=self.gate_proj, W2=self.down_proj, W3=self.up_proj, use_prepack=True
|
||||
)
|
||||
|
||||
self.process_group = weights.process_group
|
||||
|
||||
|
@ -152,6 +158,17 @@ class DenseMoELayer(nn.Module):
|
|||
input_shape = x.shape
|
||||
x = x.view(-1, input_shape[-1])
|
||||
|
||||
if SYSTEM == "ipex":
|
||||
return self.ipex_fused_moe(
|
||||
hidden_states=x,
|
||||
router_logits=gating_output,
|
||||
top_k=self.topk,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.n_expert_group is not None,
|
||||
num_expert_group=self.n_expert_group,
|
||||
topk_group=self.topk_group,
|
||||
)
|
||||
|
||||
if self.n_expert_group is not None and self.topk_group is not None:
|
||||
topk_weights, topk_ids = grouped_topk(
|
||||
x,
|
||||
|
|
|
@ -10,6 +10,8 @@ if SYSTEM == "rocm":
|
|||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
elif SYSTEM != "ipex":
|
||||
from moe_kernels.fused_moe import fused_moe
|
||||
else:
|
||||
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
|
||||
|
||||
|
||||
class UnquantizedSparseMoELayer(nn.Module):
|
||||
|
@ -52,6 +54,10 @@ class UnquantizedSparseMoELayer(nn.Module):
|
|||
name=down_proj_name,
|
||||
weights=weights,
|
||||
)
|
||||
if SYSTEM == "ipex":
|
||||
self.ipex_fused_moe = GatedMLPMOE(
|
||||
W13=self.gate_up_proj, W2=self.down_proj, use_prepack=True
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
|
||||
if SYSTEM == "rocm":
|
||||
|
@ -64,6 +70,16 @@ class UnquantizedSparseMoELayer(nn.Module):
|
|||
renormalize=self.renormalize,
|
||||
inplace=True,
|
||||
)
|
||||
elif SYSTEM == "ipex":
|
||||
return self.ipex_fused_moe(
|
||||
hidden_states=x,
|
||||
router_logits=gating_output,
|
||||
top_k=self.topk,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.n_expert_group is not None,
|
||||
num_expert_group=self.n_expert_group,
|
||||
topk_group=self.topk_group,
|
||||
)
|
||||
|
||||
return fused_moe(
|
||||
x,
|
||||
|
|
|
@ -25,6 +25,8 @@ from text_generation_server.utils.import_utils import SYSTEM
|
|||
|
||||
if SYSTEM != "ipex":
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
else:
|
||||
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
|
||||
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
|
@ -488,19 +490,35 @@ class BlockSparseMoE(nn.Module):
|
|||
)
|
||||
|
||||
self.process_group = weights.process_group
|
||||
if SYSTEM == "ipex":
|
||||
self.ipex_fused_moe = GatedMLPMOE(
|
||||
W13=self.wv1, W2=self.w2, use_prepack=True
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(x)
|
||||
out = fused_moe(
|
||||
x,
|
||||
self.wv1,
|
||||
self.w2,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
renormalize=self.moe_normalize_expert_weights,
|
||||
inplace=True,
|
||||
)
|
||||
|
||||
if SYSTEM == "ipex":
|
||||
out = self.ipex_fused_moe(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
renormalize=self.moe_normalize_expert_weights,
|
||||
use_grouped_topk=False,
|
||||
num_expert_group=None,
|
||||
topk_group=None,
|
||||
)
|
||||
else:
|
||||
out = fused_moe(
|
||||
x,
|
||||
self.wv1,
|
||||
self.w2,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
renormalize=self.moe_normalize_expert_weights,
|
||||
inplace=True,
|
||||
)
|
||||
|
||||
# Reduce sum
|
||||
if self.process_group.size() > 1:
|
||||
|
|
Loading…
Reference in New Issue