diff --git a/server/text_generation_server/layers/moe/unquantized.py b/server/text_generation_server/layers/moe/unquantized.py
index 3d6a0b99..75af0409 100644
--- a/server/text_generation_server/layers/moe/unquantized.py
+++ b/server/text_generation_server/layers/moe/unquantized.py
@@ -8,10 +8,10 @@ from text_generation_server.utils.weights import UnquantizedWeight, Weights
 
 if SYSTEM == "rocm":
     from vllm.model_executor.layers.fused_moe import fused_moe
-elif SYSTEM != "ipex":
-    from moe_kernels.fused_moe import fused_moe
-else:
+elif SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
+else:
+    from moe_kernels.fused_moe import fused_moe
 
 
 class UnquantizedSparseMoELayer(nn.Module):
diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
index b8041671..2d1aa96c 100644
--- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -25,10 +25,10 @@ from text_generation_server.utils.import_utils import SYSTEM
 
 if SYSTEM == "rocm":
     from vllm.model_executor.layers.fused_moe import fused_moe
-elif SYSTEM != "ipex":
-    from moe_kernels.fused_moe import fused_moe
-else:
+elif SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
+else:
+    from moe_kernels.fused_moe import fused_moe
 
 from text_generation_server.layers.attention import (
     paged_attention,