add ipex moe implementation to support Mixtral and PhiMoe (#2707)

* add ipex moe implementation to support Mixtral and PhiMoe

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* update to ipex xpu 2.5

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* torch has xpu support in 2.5

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix oneapi basekit version

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* Apply suggestions from code review

Co-authored-by: Daniël de Kok <me@github.danieldk.eu>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Daniël de Kok <me@github.danieldk.eu>
This commit is contained in:
Wang, Yi 2024-11-19 00:16:55 +08:00 committed by GitHub
parent fea62e928f
commit a5ecd6e586
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 80 additions and 17 deletions

View File

@ -83,7 +83,11 @@ RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dea
RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
| gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y intel-basekit xpu-smi cmake ninja-build pciutils
RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/intel-for-pytorch-gpu-dev all main" > /tmp/intel-for-pytorch-gpu-dev.list
RUN mv /tmp/intel-for-pytorch-gpu-dev.list /etc/apt/sources.list.d
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y intel-basekit=2024.2.1-98 xpu-smi cmake ninja-build pciutils intel-pti-dev-0.9
# Text Generation Inference base env
ENV HF_HOME=/data \
@ -91,8 +95,14 @@ ENV HF_HOME=/data \
PORT=80
WORKDIR /usr/src
RUN pip install torch==2.3.1+cxx11.abi torchvision==0.18.1+cxx11.abi torchaudio==2.3.1+cxx11.abi intel-extension-for-pytorch==2.3.110+xpu oneccl_bind_pt==2.3.100+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --no-cache-dir
RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torch-2.5.0a0%2Bgite84e33f-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torchaudio-2.5.0a0%2B56bc006-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torchvision-0.20.0a0%2B8e8a208-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.5.10%2Bgit9d489a8-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/oneccl_bind_pt-2.5.0%2Bxpu-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install triton-xpu==3.0.0b2 --no-cache-dir
# Install server
@ -108,13 +118,13 @@ ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest
ENV I_MPI_ROOT=/opt/intel/oneapi/mpi/latest
ENV FI_PROVIDER_PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib/prov:/usr/lib/x86_64-linux-gnu/libfabric
ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mkl/latest/lib/:/opt/intel/oneapi/compiler/latest/lib
ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:/opt/conda/lib
ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:/opt/intel/oneapi/pti/0.9/lib:/opt/conda/lib
ENV PATH=/opt/conda/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
ENV CCL_ZE_IPC_EXCHANGE=sockets
ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include
ENV TORCH_LLM_ALLREDUCE=1
ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
#ENV TORCH_LLM_ALLREDUCE=1
#ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@ -187,7 +197,7 @@ RUN pip install triton py-libnuma
WORKDIR /usr/src
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout f86e93e4890dc2c989024d148d415c9aa8a1649f
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout 2e1c98f74ec1b35ad8dd1ebe7dd4b25470f2fd41
RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout v2.4.0+cpu+rc0
RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install

View File

@ -27,7 +27,9 @@ from text_generation_server.utils.weights import (
if SYSTEM == "rocm":
from .fused_moe_rocm import grouped_topk
from vllm.model_executor.layers.fused_moe import fused_topk
elif SYSTEM != "ipex":
elif SYSTEM == "ipex":
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
else:
from moe_kernels.fused_moe import fused_topk, grouped_topk
@ -140,6 +142,10 @@ class DenseMoELayer(nn.Module):
)
for i in range(self.n_experts)
]
if SYSTEM == "ipex":
self.ipex_fused_moe = GatedMLPMOE(
W13=self.gate_proj, W2=self.down_proj, W3=self.up_proj, use_prepack=True
)
self.process_group = weights.process_group
@ -152,6 +158,17 @@ class DenseMoELayer(nn.Module):
input_shape = x.shape
x = x.view(-1, input_shape[-1])
if SYSTEM == "ipex":
return self.ipex_fused_moe(
hidden_states=x,
router_logits=gating_output,
top_k=self.topk,
renormalize=self.renormalize,
use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
)
if self.n_expert_group is not None and self.topk_group is not None:
topk_weights, topk_ids = grouped_topk(
x,

View File

@ -10,6 +10,8 @@ if SYSTEM == "rocm":
from vllm.model_executor.layers.fused_moe import fused_moe
elif SYSTEM != "ipex":
from moe_kernels.fused_moe import fused_moe
else:
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
class UnquantizedSparseMoELayer(nn.Module):
@ -52,6 +54,10 @@ class UnquantizedSparseMoELayer(nn.Module):
name=down_proj_name,
weights=weights,
)
if SYSTEM == "ipex":
self.ipex_fused_moe = GatedMLPMOE(
W13=self.gate_up_proj, W2=self.down_proj, use_prepack=True
)
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
if SYSTEM == "rocm":
@ -64,6 +70,16 @@ class UnquantizedSparseMoELayer(nn.Module):
renormalize=self.renormalize,
inplace=True,
)
elif SYSTEM == "ipex":
return self.ipex_fused_moe(
hidden_states=x,
router_logits=gating_output,
top_k=self.topk,
renormalize=self.renormalize,
use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
)
return fused_moe(
x,

View File

@ -390,7 +390,9 @@ def get_model(
if dtype is None:
if quantize in ["awq", "exl2", "gptq", "marlin"]:
if SYSTEM == "ipex" and not hasattr(torch, "xpu"):
if SYSTEM == "ipex" and not (
hasattr(torch, "xpu") and torch.xpu.is_available()
):
dtype = torch.bfloat16
else:
# These quantizers only work with float16 params.

View File

@ -27,6 +27,8 @@ if SYSTEM == "rocm":
from vllm.model_executor.layers.fused_moe import fused_moe
elif SYSTEM != "ipex":
from moe_kernels.fused_moe import fused_moe
else:
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
from text_generation_server.layers.attention import (
paged_attention,
@ -490,10 +492,26 @@ class BlockSparseMoE(nn.Module):
)
self.process_group = weights.process_group
if SYSTEM == "ipex":
self.ipex_fused_moe = GatedMLPMOE(
W13=self.wv1, W2=self.w2, use_prepack=True
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(x)
if SYSTEM == "ipex":
out = self.ipex_fused_moe(
hidden_states=x,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.moe_normalize_expert_weights,
use_grouped_topk=False,
num_expert_group=None,
topk_group=None,
)
else:
out = fused_moe(
x,
self.wv1,