add ipex moe implementation to support Mixtral and PhiMoe
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
98330df65e
commit
d9a8bbc183
|
@ -186,7 +186,7 @@ RUN pip install triton py-libnuma
|
||||||
|
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
|
|
||||||
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout f86e93e4890dc2c989024d148d415c9aa8a1649f
|
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout 2e1c98f74ec1b35ad8dd1ebe7dd4b25470f2fd41
|
||||||
RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout v2.4.0+cpu+rc0
|
RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout v2.4.0+cpu+rc0
|
||||||
|
|
||||||
RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install
|
RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install
|
||||||
|
|
|
@ -29,6 +29,8 @@ if SYSTEM == "rocm":
|
||||||
from vllm.model_executor.layers.fused_moe import fused_topk
|
from vllm.model_executor.layers.fused_moe import fused_topk
|
||||||
elif SYSTEM != "ipex":
|
elif SYSTEM != "ipex":
|
||||||
from moe_kernels.fused_moe import fused_topk, grouped_topk
|
from moe_kernels.fused_moe import fused_topk, grouped_topk
|
||||||
|
else:
|
||||||
|
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
|
||||||
|
|
||||||
|
|
||||||
# NOTE: we are using a protocol here, because multiple inherance is not nice.
|
# NOTE: we are using a protocol here, because multiple inherance is not nice.
|
||||||
|
@ -140,6 +142,10 @@ class DenseMoELayer(nn.Module):
|
||||||
)
|
)
|
||||||
for i in range(self.n_experts)
|
for i in range(self.n_experts)
|
||||||
]
|
]
|
||||||
|
if SYSTEM == "ipex":
|
||||||
|
self.ipex_fused_moe = GatedMLPMOE(
|
||||||
|
W13=self.gate_proj, W2=self.down_proj, W3=self.up_proj, use_prepack=True
|
||||||
|
)
|
||||||
|
|
||||||
self.process_group = weights.process_group
|
self.process_group = weights.process_group
|
||||||
|
|
||||||
|
@ -152,6 +158,17 @@ class DenseMoELayer(nn.Module):
|
||||||
input_shape = x.shape
|
input_shape = x.shape
|
||||||
x = x.view(-1, input_shape[-1])
|
x = x.view(-1, input_shape[-1])
|
||||||
|
|
||||||
|
if SYSTEM == "ipex":
|
||||||
|
return self.ipex_fused_moe(
|
||||||
|
hidden_states=x,
|
||||||
|
router_logits=gating_output,
|
||||||
|
top_k=self.topk,
|
||||||
|
renormalize=self.renormalize,
|
||||||
|
use_grouped_topk=self.n_expert_group is not None,
|
||||||
|
num_expert_group=self.n_expert_group,
|
||||||
|
topk_group=self.topk_group,
|
||||||
|
)
|
||||||
|
|
||||||
if self.n_expert_group is not None and self.topk_group is not None:
|
if self.n_expert_group is not None and self.topk_group is not None:
|
||||||
topk_weights, topk_ids = grouped_topk(
|
topk_weights, topk_ids = grouped_topk(
|
||||||
x,
|
x,
|
||||||
|
|
|
@ -10,6 +10,8 @@ if SYSTEM == "rocm":
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||||
elif SYSTEM != "ipex":
|
elif SYSTEM != "ipex":
|
||||||
from moe_kernels.fused_moe import fused_moe
|
from moe_kernels.fused_moe import fused_moe
|
||||||
|
else:
|
||||||
|
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
|
||||||
|
|
||||||
|
|
||||||
class UnquantizedSparseMoELayer(nn.Module):
|
class UnquantizedSparseMoELayer(nn.Module):
|
||||||
|
@ -52,6 +54,10 @@ class UnquantizedSparseMoELayer(nn.Module):
|
||||||
name=down_proj_name,
|
name=down_proj_name,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
|
if SYSTEM == "ipex":
|
||||||
|
self.ipex_fused_moe = GatedMLPMOE(
|
||||||
|
W13=self.gate_up_proj, W2=self.down_proj, use_prepack=True
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
|
||||||
if SYSTEM == "rocm":
|
if SYSTEM == "rocm":
|
||||||
|
@ -64,6 +70,16 @@ class UnquantizedSparseMoELayer(nn.Module):
|
||||||
renormalize=self.renormalize,
|
renormalize=self.renormalize,
|
||||||
inplace=True,
|
inplace=True,
|
||||||
)
|
)
|
||||||
|
elif SYSTEM == "ipex":
|
||||||
|
return self.ipex_fused_moe(
|
||||||
|
hidden_states=x,
|
||||||
|
router_logits=gating_output,
|
||||||
|
top_k=self.topk,
|
||||||
|
renormalize=self.renormalize,
|
||||||
|
use_grouped_topk=self.n_expert_group is not None,
|
||||||
|
num_expert_group=self.n_expert_group,
|
||||||
|
topk_group=self.topk_group,
|
||||||
|
)
|
||||||
|
|
||||||
return fused_moe(
|
return fused_moe(
|
||||||
x,
|
x,
|
||||||
|
|
|
@ -25,6 +25,8 @@ from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
|
||||||
if SYSTEM != "ipex":
|
if SYSTEM != "ipex":
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||||
|
else:
|
||||||
|
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
|
||||||
|
|
||||||
from text_generation_server.layers.attention import (
|
from text_generation_server.layers.attention import (
|
||||||
paged_attention,
|
paged_attention,
|
||||||
|
@ -488,19 +490,35 @@ class BlockSparseMoE(nn.Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.process_group = weights.process_group
|
self.process_group = weights.process_group
|
||||||
|
if SYSTEM == "ipex":
|
||||||
|
self.ipex_fused_moe = GatedMLPMOE(
|
||||||
|
W13=self.wv1, W2=self.w2, use_prepack=True
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits = self.gate(x)
|
router_logits = self.gate(x)
|
||||||
out = fused_moe(
|
|
||||||
x,
|
if SYSTEM == "ipex":
|
||||||
self.wv1,
|
out = self.ipex_fused_moe(
|
||||||
self.w2,
|
hidden_states=x,
|
||||||
router_logits,
|
router_logits=router_logits,
|
||||||
self.top_k,
|
top_k=self.top_k,
|
||||||
renormalize=self.moe_normalize_expert_weights,
|
renormalize=self.moe_normalize_expert_weights,
|
||||||
inplace=True,
|
use_grouped_topk=False,
|
||||||
)
|
num_expert_group=None,
|
||||||
|
topk_group=None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
out = fused_moe(
|
||||||
|
x,
|
||||||
|
self.wv1,
|
||||||
|
self.w2,
|
||||||
|
router_logits,
|
||||||
|
self.top_k,
|
||||||
|
renormalize=self.moe_normalize_expert_weights,
|
||||||
|
inplace=True,
|
||||||
|
)
|
||||||
|
|
||||||
# Reduce sum
|
# Reduce sum
|
||||||
if self.process_group.size() > 1:
|
if self.process_group.size() > 1:
|
||||||
|
|
Loading…
Reference in New Issue