From b4ec427ad0d8935f83428998a0a9d0d0e532e90c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Tue, 19 Nov 2024 08:04:23 +0100
Subject: [PATCH] Simplify two ipex conditions (#2755)

---
 server/text_generation_server/layers/moe/unquantized.py | 6 +++---
 .../models/custom_modeling/flash_dbrx_modeling.py       | 6 +++---
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/server/text_generation_server/layers/moe/unquantized.py b/server/text_generation_server/layers/moe/unquantized.py
index 3d6a0b99..75af0409 100644
--- a/server/text_generation_server/layers/moe/unquantized.py
+++ b/server/text_generation_server/layers/moe/unquantized.py
@@ -8,10 +8,10 @@ from text_generation_server.utils.weights import UnquantizedWeight, Weights
 
 if SYSTEM == "rocm":
     from vllm.model_executor.layers.fused_moe import fused_moe
-elif SYSTEM != "ipex":
-    from moe_kernels.fused_moe import fused_moe
-else:
+elif SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
+else:
+    from moe_kernels.fused_moe import fused_moe
 
 
 class UnquantizedSparseMoELayer(nn.Module):
diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
index b8041671..2d1aa96c 100644
--- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -25,10 +25,10 @@ from text_generation_server.utils.import_utils import SYSTEM
 
 if SYSTEM == "rocm":
     from vllm.model_executor.layers.fused_moe import fused_moe
-elif SYSTEM != "ipex":
-    from moe_kernels.fused_moe import fused_moe
-else:
+elif SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
+else:
+    from moe_kernels.fused_moe import fused_moe
 
 from text_generation_server.layers.attention import (
     paged_attention,