diff --git a/server/text_generation_server/layers/moe/gptq_marlin.py b/server/text_generation_server/layers/moe/gptq_marlin.py index 3d4ca9d8..0ea4bcaf 100644 --- a/server/text_generation_server/layers/moe/gptq_marlin.py +++ b/server/text_generation_server/layers/moe/gptq_marlin.py @@ -12,7 +12,10 @@ from text_generation_server.layers.marlin.gptq import ( ) if SYSTEM == "cuda": - from moe_kernels.fused_marlin_moe import fused_marlin_moe + try: + from moe_kernels.fused_marlin_moe import fused_marlin_moe + except ImportError: + fused_marlin_moe = None else: fused_marlin_moe = None @@ -72,6 +75,11 @@ class GPTQMarlinSparseMoELayer(nn.Module): ): super().__init__() + if fused_marlin_moe is None: + raise ValueError( + "Fused MoE kernels are not installed. Install the `moe_kernels` package" + ) + if not ( isinstance(weights.loader, GPTQMarlinWeightsLoader) and can_use_marlin_moe_gemm( @@ -107,6 +115,7 @@ class GPTQMarlinSparseMoELayer(nn.Module): self.bits = weights.loader.bits def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: + assert fused_marlin_moe is not None return fused_marlin_moe( hidden_states=x, w1=self.gate_up_proj.qweight, diff --git a/server/text_generation_server/layers/moe/unquantized.py b/server/text_generation_server/layers/moe/unquantized.py index d9d62c0e..5402123d 100644 --- a/server/text_generation_server/layers/moe/unquantized.py +++ b/server/text_generation_server/layers/moe/unquantized.py @@ -9,7 +9,10 @@ from text_generation_server.utils.weights import UnquantizedWeight, Weights if SYSTEM == "rocm": from vllm.model_executor.layers.fused_moe import fused_moe elif SYSTEM != "ipex": - from moe_kernels.fused_moe import fused_moe + try: + from moe_kernels.fused_moe import fused_moe + except ImportError: + fused_moe = None class UnquantizedSparseMoELayer(nn.Module): @@ -29,6 +32,11 @@ class UnquantizedSparseMoELayer(nn.Module): ): super().__init__() + if fused_moe is None: + raise ValueError( + "Fused MoE kernels are not installed. Install the `moe_kernels` package" + ) + assert (n_expert_group is None) == ( topk_group is None ), "n_expert_group and topk_group must both be None or have some value" @@ -54,6 +62,8 @@ class UnquantizedSparseMoELayer(nn.Module): ) def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: + assert fused_moe is not None + if SYSTEM == "rocm": return fused_moe( x,