add ipex moe implementation to support Mixtral and PhiMoe (#2707)
* add ipex moe implementation to support Mixtral and PhiMoe
* update to ipex xpu 2.5
* torch has xpu support in 2.5
* fix oneapi basekit version
* Apply suggestions from code review

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Daniël de Kok <me@github.danieldk.eu>
parent fea62e928f
commit a5ecd6e586
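At a glance, the diff below routes MoE execution through IPEX's prepacked GatedMLPMOE module whenever SYSTEM == "ipex", instead of the Triton/vLLM fused_moe kernels. A minimal sketch of that dispatch pattern, assuming intel-extension-for-pytorch 2.5 with the llm.modules API used in the diff; the SparseMoE class, weight names and the hard-coded SYSTEM value are illustrative, not TGI's actual code:

import torch

SYSTEM = "ipex"  # in TGI this comes from hardware detection; hard-coded for the sketch

if SYSTEM == "ipex":
    from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
else:
    from moe_kernels.fused_moe import fused_moe


class SparseMoE(torch.nn.Module):
    # Illustrative wrapper: prepacked IPEX MoE when SYSTEM == "ipex", fused_moe kernels elsewhere.
    def __init__(self, gate_up_proj, down_proj, top_k, renormalize=True):
        super().__init__()
        self.gate_up_proj = gate_up_proj  # (n_experts, 2 * intermediate_size, hidden_size)
        self.down_proj = down_proj        # (n_experts, hidden_size, intermediate_size)
        self.top_k = top_k
        self.renormalize = renormalize
        if SYSTEM == "ipex":
            # Expert weights are prepacked once at load time.
            self.ipex_fused_moe = GatedMLPMOE(
                W13=self.gate_up_proj, W2=self.down_proj, use_prepack=True
            )

    def forward(self, x, router_logits):
        if SYSTEM == "ipex":
            # IPEX handles top-k routing, renormalization and the expert GEMMs internally.
            return self.ipex_fused_moe(
                hidden_states=x,
                router_logits=router_logits,
                top_k=self.top_k,
                renormalize=self.renormalize,
                use_grouped_topk=False,
                num_expert_group=None,
                topk_group=None,
            )
        return fused_moe(
            x,
            self.gate_up_proj,
            self.down_proj,
            router_logits,
            self.top_k,
            renormalize=self.renormalize,
            inplace=True,
        )

The diff applies this same pattern to DenseMoELayer (additionally passing W3 for the separate up projection), UnquantizedSparseMoELayer, and the BlockSparseMoE forward pass shown in the last hunk.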
@@ -83,7 +83,11 @@ RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dea
 RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
   | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list

-RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y intel-basekit xpu-smi cmake ninja-build pciutils
+RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/intel-for-pytorch-gpu-dev all main" > /tmp/intel-for-pytorch-gpu-dev.list
+RUN mv /tmp/intel-for-pytorch-gpu-dev.list /etc/apt/sources.list.d
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y intel-basekit=2024.2.1-98 xpu-smi cmake ninja-build pciutils intel-pti-dev-0.9

 # Text Generation Inference base env
 ENV HF_HOME=/data \
@@ -91,8 +95,14 @@ ENV HF_HOME=/data \
     PORT=80


 WORKDIR /usr/src
-RUN pip install torch==2.3.1+cxx11.abi torchvision==0.18.1+cxx11.abi torchaudio==2.3.1+cxx11.abi intel-extension-for-pytorch==2.3.110+xpu oneccl_bind_pt==2.3.100+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --no-cache-dir
+RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torch-2.5.0a0%2Bgite84e33f-cp311-cp311-linux_x86_64.whl --no-cache-dir
+RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torchaudio-2.5.0a0%2B56bc006-cp311-cp311-linux_x86_64.whl --no-cache-dir
+RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torchvision-0.20.0a0%2B8e8a208-cp311-cp311-linux_x86_64.whl --no-cache-dir
+RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.5.10%2Bgit9d489a8-cp311-cp311-linux_x86_64.whl --no-cache-dir
+RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/oneccl_bind_pt-2.5.0%2Bxpu-cp311-cp311-linux_x86_64.whl --no-cache-dir
 RUN pip install triton-xpu==3.0.0b2 --no-cache-dir

 # Install server
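The pip installs above pull a prebuilt torch 2.5 XPU stack (per the commit message, torch has native xpu support in 2.5). A quick post-install sanity check, mirroring the torch.xpu capability test this diff adds to get_model later on (a sketch, not part of the change itself):

import torch
import intel_extension_for_pytorch as ipex  # registers IPEX extensions on top of torch

print(torch.__version__, ipex.__version__)
# torch >= 2.5 ships torch.xpu; availability still depends on the GPU driver / Level Zero stack.
print("xpu available:", hasattr(torch, "xpu") and torch.xpu.is_available())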
@@ -108,13 +118,13 @@ ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest
 ENV I_MPI_ROOT=/opt/intel/oneapi/mpi/latest
 ENV FI_PROVIDER_PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib/prov:/usr/lib/x86_64-linux-gnu/libfabric
 ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mkl/latest/lib/:/opt/intel/oneapi/compiler/latest/lib
-ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:/opt/conda/lib
+ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:/opt/intel/oneapi/pti/0.9/lib:/opt/conda/lib
 ENV PATH=/opt/conda/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
 ENV CCL_ZE_IPC_EXCHANGE=sockets
 ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
 ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include
-ENV TORCH_LLM_ALLREDUCE=1
-ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
+#ENV TORCH_LLM_ALLREDUCE=1
+#ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0

 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@@ -187,7 +197,7 @@ RUN pip install triton py-libnuma

 WORKDIR /usr/src

-RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout f86e93e4890dc2c989024d148d415c9aa8a1649f
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout 2e1c98f74ec1b35ad8dd1ebe7dd4b25470f2fd41
 RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout v2.4.0+cpu+rc0

 RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install
@@ -27,7 +27,9 @@ from text_generation_server.utils.weights import (
 if SYSTEM == "rocm":
     from .fused_moe_rocm import grouped_topk
     from vllm.model_executor.layers.fused_moe import fused_topk
-elif SYSTEM != "ipex":
+elif SYSTEM == "ipex":
+    from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
+else:
     from moe_kernels.fused_moe import fused_topk, grouped_topk


@@ -140,6 +142,10 @@ class DenseMoELayer(nn.Module):
             )
             for i in range(self.n_experts)
         ]
+        if SYSTEM == "ipex":
+            self.ipex_fused_moe = GatedMLPMOE(
+                W13=self.gate_proj, W2=self.down_proj, W3=self.up_proj, use_prepack=True
+            )

         self.process_group = weights.process_group

@@ -152,6 +158,17 @@ class DenseMoELayer(nn.Module):
         input_shape = x.shape
         x = x.view(-1, input_shape[-1])

+        if SYSTEM == "ipex":
+            return self.ipex_fused_moe(
+                hidden_states=x,
+                router_logits=gating_output,
+                top_k=self.topk,
+                renormalize=self.renormalize,
+                use_grouped_topk=self.n_expert_group is not None,
+                num_expert_group=self.n_expert_group,
+                topk_group=self.topk_group,
+            )
+
         if self.n_expert_group is not None and self.topk_group is not None:
             topk_weights, topk_ids = grouped_topk(
                 x,
@@ -10,6 +10,8 @@ if SYSTEM == "rocm":
     from vllm.model_executor.layers.fused_moe import fused_moe
 elif SYSTEM != "ipex":
     from moe_kernels.fused_moe import fused_moe
+else:
+    from intel_extension_for_pytorch.llm.modules import GatedMLPMOE


 class UnquantizedSparseMoELayer(nn.Module):
@@ -52,6 +54,10 @@ class UnquantizedSparseMoELayer(nn.Module):
             name=down_proj_name,
             weights=weights,
         )
+        if SYSTEM == "ipex":
+            self.ipex_fused_moe = GatedMLPMOE(
+                W13=self.gate_up_proj, W2=self.down_proj, use_prepack=True
+            )

     def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
         if SYSTEM == "rocm":
@@ -64,6 +70,16 @@ class UnquantizedSparseMoELayer(nn.Module):
                 renormalize=self.renormalize,
                 inplace=True,
             )
+        elif SYSTEM == "ipex":
+            return self.ipex_fused_moe(
+                hidden_states=x,
+                router_logits=gating_output,
+                top_k=self.topk,
+                renormalize=self.renormalize,
+                use_grouped_topk=self.n_expert_group is not None,
+                num_expert_group=self.n_expert_group,
+                topk_group=self.topk_group,
+            )

         return fused_moe(
             x,
@@ -390,7 +390,9 @@ def get_model(

     if dtype is None:
         if quantize in ["awq", "exl2", "gptq", "marlin"]:
-            if SYSTEM == "ipex" and not hasattr(torch, "xpu"):
+            if SYSTEM == "ipex" and not (
+                hasattr(torch, "xpu") and torch.xpu.is_available()
+            ):
                 dtype = torch.bfloat16
             else:
                 # These quantizers only work with float16 params.
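With the change above, the quantizer dtype default checks that an XPU device is actually usable rather than only that torch.xpu exists, so an IPEX build that falls back to CPU picks bfloat16. Condensed into a standalone helper (a sketch reusing the diff's condition, not TGI's real get_model function):

import torch

def quantizer_default_dtype(system: str) -> torch.dtype:
    # Default dtype for awq/exl2/gptq/marlin checkpoints when none is requested.
    if system == "ipex" and not (hasattr(torch, "xpu") and torch.xpu.is_available()):
        # IPEX without a usable XPU device runs on CPU, where bfloat16 is preferred.
        return torch.bfloat16
    # These quantizers only work with float16 params.
    return torch.float16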
@@ -27,6 +27,8 @@ if SYSTEM == "rocm":
     from vllm.model_executor.layers.fused_moe import fused_moe
 elif SYSTEM != "ipex":
     from moe_kernels.fused_moe import fused_moe
+else:
+    from intel_extension_for_pytorch.llm.modules import GatedMLPMOE

 from text_generation_server.layers.attention import (
     paged_attention,
@@ -490,19 +492,35 @@ class BlockSparseMoE(nn.Module):
         )

         self.process_group = weights.process_group
+        if SYSTEM == "ipex":
+            self.ipex_fused_moe = GatedMLPMOE(
+                W13=self.wv1, W2=self.w2, use_prepack=True
+            )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(x)
-        out = fused_moe(
-            x,
-            self.wv1,
-            self.w2,
-            router_logits,
-            self.top_k,
-            renormalize=self.moe_normalize_expert_weights,
-            inplace=True,
-        )
+
+        if SYSTEM == "ipex":
+            out = self.ipex_fused_moe(
+                hidden_states=x,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                renormalize=self.moe_normalize_expert_weights,
+                use_grouped_topk=False,
+                num_expert_group=None,
+                topk_group=None,
+            )
+        else:
+            out = fused_moe(
+                x,
+                self.wv1,
+                self.w2,
+                router_logits,
+                self.top_k,
+                renormalize=self.moe_normalize_expert_weights,
+                inplace=True,
+            )

         # Reduce sum
         if self.process_group.size() > 1: