Fix: make `moe_kernels` imports conditional

`moe-kernels` is an optional dependency, so make sure that we run
without installing this package.

Fixes #2621.
This commit is contained in:
Daniël de Kok 2024-10-08 11:05:28 +00:00
parent 64142489b6
commit e618ce3ada
2 changed files with 21 additions and 2 deletions

View File

@ -12,7 +12,10 @@ from text_generation_server.layers.marlin.gptq import (
) )
if SYSTEM == "cuda": if SYSTEM == "cuda":
try:
from moe_kernels.fused_marlin_moe import fused_marlin_moe from moe_kernels.fused_marlin_moe import fused_marlin_moe
except ImportError:
fused_marlin_moe = None
else: else:
fused_marlin_moe = None fused_marlin_moe = None
@ -72,6 +75,11 @@ class GPTQMarlinSparseMoELayer(nn.Module):
): ):
super().__init__() super().__init__()
if fused_marlin_moe is None:
raise ValueError(
"Fused MoE kernels are not installed. Install the `moe_kernels` package"
)
if not ( if not (
isinstance(weights.loader, GPTQMarlinWeightsLoader) isinstance(weights.loader, GPTQMarlinWeightsLoader)
and can_use_marlin_moe_gemm( and can_use_marlin_moe_gemm(
@ -107,6 +115,7 @@ class GPTQMarlinSparseMoELayer(nn.Module):
self.bits = weights.loader.bits self.bits = weights.loader.bits
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
assert fused_marlin_moe is not None
return fused_marlin_moe( return fused_marlin_moe(
hidden_states=x, hidden_states=x,
w1=self.gate_up_proj.qweight, w1=self.gate_up_proj.qweight,

View File

@ -9,7 +9,10 @@ from text_generation_server.utils.weights import UnquantizedWeight, Weights
if SYSTEM == "rocm": if SYSTEM == "rocm":
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
elif SYSTEM != "ipex": elif SYSTEM != "ipex":
try:
from moe_kernels.fused_moe import fused_moe from moe_kernels.fused_moe import fused_moe
except ImportError:
fused_moe = None
class UnquantizedSparseMoELayer(nn.Module): class UnquantizedSparseMoELayer(nn.Module):
@ -29,6 +32,11 @@ class UnquantizedSparseMoELayer(nn.Module):
): ):
super().__init__() super().__init__()
if fused_moe is None:
raise ValueError(
"Fused MoE kernels are not installed. Install the `moe_kernels` package"
)
assert (n_expert_group is None) == ( assert (n_expert_group is None) == (
topk_group is None topk_group is None
), "n_expert_group and topk_group must both be None or have some value" ), "n_expert_group and topk_group must both be None or have some value"
@ -54,6 +62,8 @@ class UnquantizedSparseMoELayer(nn.Module):
) )
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
assert fused_moe is not None
if SYSTEM == "rocm": if SYSTEM == "rocm":
return fused_moe( return fused_moe(
x, x,