Fix: make `moe_kernels` imports conditional

`moe-kernels` is an optional dependency, so make sure that we run
without installing this package.

Fixes #2621.
This commit is contained in:
Daniël de Kok 2024-10-08 11:05:28 +00:00
parent 64142489b6
commit e618ce3ada
2 changed files with 21 additions and 2 deletions

View File

@ -12,7 +12,10 @@ from text_generation_server.layers.marlin.gptq import (
)
if SYSTEM == "cuda":
from moe_kernels.fused_marlin_moe import fused_marlin_moe
try:
from moe_kernels.fused_marlin_moe import fused_marlin_moe
except ImportError:
fused_marlin_moe = None
else:
fused_marlin_moe = None
@ -72,6 +75,11 @@ class GPTQMarlinSparseMoELayer(nn.Module):
):
super().__init__()
if fused_marlin_moe is None:
raise ValueError(
"Fused MoE kernels are not installed. Install the `moe_kernels` package"
)
if not (
isinstance(weights.loader, GPTQMarlinWeightsLoader)
and can_use_marlin_moe_gemm(
@ -107,6 +115,7 @@ class GPTQMarlinSparseMoELayer(nn.Module):
self.bits = weights.loader.bits
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
assert fused_marlin_moe is not None
return fused_marlin_moe(
hidden_states=x,
w1=self.gate_up_proj.qweight,

View File

@ -9,7 +9,10 @@ from text_generation_server.utils.weights import UnquantizedWeight, Weights
if SYSTEM == "rocm":
from vllm.model_executor.layers.fused_moe import fused_moe
elif SYSTEM != "ipex":
from moe_kernels.fused_moe import fused_moe
try:
from moe_kernels.fused_moe import fused_moe
except ImportError:
fused_moe = None
class UnquantizedSparseMoELayer(nn.Module):
@ -29,6 +32,11 @@ class UnquantizedSparseMoELayer(nn.Module):
):
super().__init__()
if fused_moe is None:
raise ValueError(
"Fused MoE kernels are not installed. Install the `moe_kernels` package"
)
assert (n_expert_group is None) == (
topk_group is None
), "n_expert_group and topk_group must both be None or have some value"
@ -54,6 +62,8 @@ class UnquantizedSparseMoELayer(nn.Module):
)
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
assert fused_moe is not None
if SYSTEM == "rocm":
return fused_moe(
x,