add ipex moe implementation to support Mixtral and PhiMoe (#2707)

* add ipex moe implementation to support Mixtral and PhiMoe

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* update to ipex xpu 2.5

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* torch has xpu support in 2.5

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix oneapi basekit version

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* Apply suggestions from code review

Co-authored-by: Daniël de Kok <me@github.danieldk.eu>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Daniël de Kok <me@github.danieldk.eu>
This commit is contained in:
Wang, Yi 2024-11-19 00:16:55 +08:00 committed by GitHub
parent fea62e928f
commit a5ecd6e586
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 80 additions and 17 deletions

View File

@ -83,7 +83,11 @@ RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dea
RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
| gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y intel-basekit xpu-smi cmake ninja-build pciutils RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/intel-for-pytorch-gpu-dev all main" > /tmp/intel-for-pytorch-gpu-dev.list
RUN mv /tmp/intel-for-pytorch-gpu-dev.list /etc/apt/sources.list.d
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y intel-basekit=2024.2.1-98 xpu-smi cmake ninja-build pciutils intel-pti-dev-0.9
# Text Generation Inference base env # Text Generation Inference base env
ENV HF_HOME=/data \ ENV HF_HOME=/data \
@ -91,8 +95,14 @@ ENV HF_HOME=/data \
PORT=80 PORT=80
WORKDIR /usr/src WORKDIR /usr/src
RUN pip install torch==2.3.1+cxx11.abi torchvision==0.18.1+cxx11.abi torchaudio==2.3.1+cxx11.abi intel-extension-for-pytorch==2.3.110+xpu oneccl_bind_pt==2.3.100+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --no-cache-dir RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torch-2.5.0a0%2Bgite84e33f-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torchaudio-2.5.0a0%2B56bc006-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torchvision-0.20.0a0%2B8e8a208-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.5.10%2Bgit9d489a8-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/oneccl_bind_pt-2.5.0%2Bxpu-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install triton-xpu==3.0.0b2 --no-cache-dir RUN pip install triton-xpu==3.0.0b2 --no-cache-dir
# Install server # Install server
@ -108,13 +118,13 @@ ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest
ENV I_MPI_ROOT=/opt/intel/oneapi/mpi/latest ENV I_MPI_ROOT=/opt/intel/oneapi/mpi/latest
ENV FI_PROVIDER_PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib/prov:/usr/lib/x86_64-linux-gnu/libfabric ENV FI_PROVIDER_PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib/prov:/usr/lib/x86_64-linux-gnu/libfabric
ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mkl/latest/lib/:/opt/intel/oneapi/compiler/latest/lib ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mkl/latest/lib/:/opt/intel/oneapi/compiler/latest/lib
ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:/opt/conda/lib ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:/opt/intel/oneapi/pti/0.9/lib:/opt/conda/lib
ENV PATH=/opt/conda/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin ENV PATH=/opt/conda/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
ENV CCL_ZE_IPC_EXCHANGE=sockets ENV CCL_ZE_IPC_EXCHANGE=sockets
ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include
ENV TORCH_LLM_ALLREDUCE=1 #ENV TORCH_LLM_ALLREDUCE=1
ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0 #ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
# Install benchmarker # Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@ -187,7 +197,7 @@ RUN pip install triton py-libnuma
WORKDIR /usr/src WORKDIR /usr/src
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout f86e93e4890dc2c989024d148d415c9aa8a1649f RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout 2e1c98f74ec1b35ad8dd1ebe7dd4b25470f2fd41
RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout v2.4.0+cpu+rc0 RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout v2.4.0+cpu+rc0
RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install

View File

@ -27,7 +27,9 @@ from text_generation_server.utils.weights import (
if SYSTEM == "rocm": if SYSTEM == "rocm":
from .fused_moe_rocm import grouped_topk from .fused_moe_rocm import grouped_topk
from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe import fused_topk
elif SYSTEM != "ipex": elif SYSTEM == "ipex":
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
else:
from moe_kernels.fused_moe import fused_topk, grouped_topk from moe_kernels.fused_moe import fused_topk, grouped_topk
@ -140,6 +142,10 @@ class DenseMoELayer(nn.Module):
) )
for i in range(self.n_experts) for i in range(self.n_experts)
] ]
if SYSTEM == "ipex":
self.ipex_fused_moe = GatedMLPMOE(
W13=self.gate_proj, W2=self.down_proj, W3=self.up_proj, use_prepack=True
)
self.process_group = weights.process_group self.process_group = weights.process_group
@ -152,6 +158,17 @@ class DenseMoELayer(nn.Module):
input_shape = x.shape input_shape = x.shape
x = x.view(-1, input_shape[-1]) x = x.view(-1, input_shape[-1])
if SYSTEM == "ipex":
return self.ipex_fused_moe(
hidden_states=x,
router_logits=gating_output,
top_k=self.topk,
renormalize=self.renormalize,
use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
)
if self.n_expert_group is not None and self.topk_group is not None: if self.n_expert_group is not None and self.topk_group is not None:
topk_weights, topk_ids = grouped_topk( topk_weights, topk_ids = grouped_topk(
x, x,

View File

@ -10,6 +10,8 @@ if SYSTEM == "rocm":
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
elif SYSTEM != "ipex": elif SYSTEM != "ipex":
from moe_kernels.fused_moe import fused_moe from moe_kernels.fused_moe import fused_moe
else:
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
class UnquantizedSparseMoELayer(nn.Module): class UnquantizedSparseMoELayer(nn.Module):
@ -52,6 +54,10 @@ class UnquantizedSparseMoELayer(nn.Module):
name=down_proj_name, name=down_proj_name,
weights=weights, weights=weights,
) )
if SYSTEM == "ipex":
self.ipex_fused_moe = GatedMLPMOE(
W13=self.gate_up_proj, W2=self.down_proj, use_prepack=True
)
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
if SYSTEM == "rocm": if SYSTEM == "rocm":
@ -64,6 +70,16 @@ class UnquantizedSparseMoELayer(nn.Module):
renormalize=self.renormalize, renormalize=self.renormalize,
inplace=True, inplace=True,
) )
elif SYSTEM == "ipex":
return self.ipex_fused_moe(
hidden_states=x,
router_logits=gating_output,
top_k=self.topk,
renormalize=self.renormalize,
use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
)
return fused_moe( return fused_moe(
x, x,

View File

@ -390,7 +390,9 @@ def get_model(
if dtype is None: if dtype is None:
if quantize in ["awq", "exl2", "gptq", "marlin"]: if quantize in ["awq", "exl2", "gptq", "marlin"]:
if SYSTEM == "ipex" and not hasattr(torch, "xpu"): if SYSTEM == "ipex" and not (
hasattr(torch, "xpu") and torch.xpu.is_available()
):
dtype = torch.bfloat16 dtype = torch.bfloat16
else: else:
# These quantizers only work with float16 params. # These quantizers only work with float16 params.

View File

@ -27,6 +27,8 @@ if SYSTEM == "rocm":
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
elif SYSTEM != "ipex": elif SYSTEM != "ipex":
from moe_kernels.fused_moe import fused_moe from moe_kernels.fused_moe import fused_moe
else:
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
@ -490,19 +492,35 @@ class BlockSparseMoE(nn.Module):
) )
self.process_group = weights.process_group self.process_group = weights.process_group
if SYSTEM == "ipex":
self.ipex_fused_moe = GatedMLPMOE(
W13=self.wv1, W2=self.w2, use_prepack=True
)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits = self.gate(x) router_logits = self.gate(x)
out = fused_moe(
x, if SYSTEM == "ipex":
self.wv1, out = self.ipex_fused_moe(
self.w2, hidden_states=x,
router_logits, router_logits=router_logits,
self.top_k, top_k=self.top_k,
renormalize=self.moe_normalize_expert_weights, renormalize=self.moe_normalize_expert_weights,
inplace=True, use_grouped_topk=False,
) num_expert_group=None,
topk_group=None,
)
else:
out = fused_moe(
x,
self.wv1,
self.w2,
router_logits,
self.top_k,
renormalize=self.moe_normalize_expert_weights,
inplace=True,
)
# Reduce sum # Reduce sum
if self.process_group.size() > 1: if self.process_group.size() > 1: