Fix: make `moe_kernels` imports conditional
`moe-kernels` is an optional dependency, so make sure that we run without installing this package. Fixes #2621.
This commit is contained in:
parent
64142489b6
commit
e618ce3ada
|
@ -12,7 +12,10 @@ from text_generation_server.layers.marlin.gptq import (
|
|||
)
|
||||
|
||||
if SYSTEM == "cuda":
|
||||
try:
|
||||
from moe_kernels.fused_marlin_moe import fused_marlin_moe
|
||||
except ImportError:
|
||||
fused_marlin_moe = None
|
||||
else:
|
||||
fused_marlin_moe = None
|
||||
|
||||
|
@ -72,6 +75,11 @@ class GPTQMarlinSparseMoELayer(nn.Module):
|
|||
):
|
||||
super().__init__()
|
||||
|
||||
if fused_marlin_moe is None:
|
||||
raise ValueError(
|
||||
"Fused MoE kernels are not installed. Install the `moe_kernels` package"
|
||||
)
|
||||
|
||||
if not (
|
||||
isinstance(weights.loader, GPTQMarlinWeightsLoader)
|
||||
and can_use_marlin_moe_gemm(
|
||||
|
@ -107,6 +115,7 @@ class GPTQMarlinSparseMoELayer(nn.Module):
|
|||
self.bits = weights.loader.bits
|
||||
|
||||
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
|
||||
assert fused_marlin_moe is not None
|
||||
return fused_marlin_moe(
|
||||
hidden_states=x,
|
||||
w1=self.gate_up_proj.qweight,
|
||||
|
|
|
@ -9,7 +9,10 @@ from text_generation_server.utils.weights import UnquantizedWeight, Weights
|
|||
if SYSTEM == "rocm":
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
elif SYSTEM != "ipex":
|
||||
try:
|
||||
from moe_kernels.fused_moe import fused_moe
|
||||
except ImportError:
|
||||
fused_moe = None
|
||||
|
||||
|
||||
class UnquantizedSparseMoELayer(nn.Module):
|
||||
|
@ -29,6 +32,11 @@ class UnquantizedSparseMoELayer(nn.Module):
|
|||
):
|
||||
super().__init__()
|
||||
|
||||
if fused_moe is None:
|
||||
raise ValueError(
|
||||
"Fused MoE kernels are not installed. Install the `moe_kernels` package"
|
||||
)
|
||||
|
||||
assert (n_expert_group is None) == (
|
||||
topk_group is None
|
||||
), "n_expert_group and topk_group must both be None or have some value"
|
||||
|
@ -54,6 +62,8 @@ class UnquantizedSparseMoELayer(nn.Module):
|
|||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
|
||||
assert fused_moe is not None
|
||||
|
||||
if SYSTEM == "rocm":
|
||||
return fused_moe(
|
||||
x,
|
||||
|
|
Loading…
Reference in New Issue