hf_text-generation-inference/server/text_generation_server/layers/moe/unquantized.py

139 lines
3.6 KiB
Python
Raw Normal View History

from typing import Optional
import torch
import torch.nn as nn
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import UnquantizedWeight, Weights
if SYSTEM == "rocm":
from vllm.model_executor.layers.fused_moe import fused_moe
elif SYSTEM != "ipex":
from moe_kernels.fused_moe import fused_moe
class UnquantizedSparseMoELayer(nn.Module):
    """Sparse mixture-of-experts layer backed by unquantized expert weights.

    The per-expert gate/up and down projections are loaded once at
    construction time, stacked along a new leading expert dimension, and
    handed to the ``fused_moe`` kernel in :meth:`forward`.
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        """Load and stack the expert weights.

        Args:
            n_expert_group: Forwarded to the kernel as ``num_expert_group``.
                ``None`` disables grouped top-k. Must be ``None`` exactly
                when ``topk_group`` is ``None``.
            n_experts: Number of experts whose weights are loaded.
            prefix: Weight-name prefix; expert ``i`` is read from
                ``f"{prefix}.{i}.<projection name>"``.
            renormalize: Forwarded unchanged to the kernel.
            topk: Forwarded unchanged to the kernel.
            topk_group: Forwarded to the kernel as ``topk_group``.
            weights: Loader used to fetch every expert tensor.
            gate_proj_name: Name of the gate projection under each expert.
            up_proj_name: Name of the up projection under each expert.
            down_proj_name: Name of the down projection under each expert.

        Raises:
            AssertionError: If only one of ``n_expert_group`` and
                ``topk_group`` is ``None``.
        """
        super().__init__()

        assert (n_expert_group is None) == (
            topk_group is None
        ), "n_expert_group and topk_group must both be None or have some value"

        self.n_expert_group = n_expert_group
        self.topk = topk
        self.topk_group = topk_group
        self.renormalize = renormalize

        # Gate and up projections are fetched together for each expert and
        # stacked into a single tensor with a leading expert dimension.
        self.gate_up_proj = _load_expert_multi_weights_col(
            prefix=prefix,
            n_experts=n_experts,
            gate_proj_name=gate_proj_name,
            up_proj_name=up_proj_name,
            weights=weights,
        )

        self.down_proj = _load_expert_weights_row(
            prefix=prefix,
            n_experts=n_experts,
            name=down_proj_name,
            weights=weights,
        )

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        """Run the fused MoE kernel on ``x``.

        Args:
            x: Input hidden states handed to the kernel.
            gating_output: Router output handed to the kernel.

        Returns:
            Whatever the platform's ``fused_moe`` kernel returns.

        Raises:
            NotImplementedError: On the ``ipex`` system, for which this
                module imports no ``fused_moe`` kernel.
        """
        if SYSTEM == "ipex":
            # The module-level imports skip `fused_moe` on ipex, so the calls
            # below would otherwise die with an opaque NameError.
            raise NotImplementedError(
                "UnquantizedSparseMoELayer has no fused MoE kernel on ipex"
            )

        if SYSTEM == "rocm":
            # Positional arguments are kept as-is for the vLLM kernel.
            # NOTE(review): the grouped top-k settings (n_expert_group /
            # topk_group) are not passed on this path, so they have no
            # effect on rocm - confirm this is intended.
            return fused_moe(
                x,
                self.gate_up_proj,
                self.down_proj,
                gating_output,
                self.topk,
                renormalize=self.renormalize,
                inplace=True,
            )

        # `inplace=True` is forwarded to the kernel; presumably it lets the
        # kernel reuse `x` for the output - verify against the kernel docs.
        return fused_moe(
            x,
            w1=self.gate_up_proj,
            w2=self.down_proj,
            gating_output=gating_output,
            topk=self.topk,
            renormalize=self.renormalize,
            inplace=True,
            use_grouped_topk=self.n_expert_group is not None,
            num_expert_group=self.n_expert_group,
            topk_group=self.topk_group,
        )
def _load_expert_multi_weights_col(
    *,
    prefix: str,
    n_experts: int,
    gate_proj_name: str,
    up_proj_name: str,
    weights: Weights,
) -> torch.Tensor:
    """Load every expert's gate/up projection and stack them.

    For each expert index ``i`` in ``range(n_experts)`` the gate and up
    projections are fetched together through
    ``weights.get_multi_weights_col(..., 0)`` and copied into slot ``i`` of
    one tensor whose leading dimension is ``n_experts``.

    Args:
        prefix: Weight-name prefix shared by all experts.
        n_experts: Number of experts to load.
        gate_proj_name: Name of the gate projection under each expert.
        up_proj_name: Name of the up projection under each expert.
        weights: Loader used to fetch the tensors.

    Returns:
        The stacked tensor, with the dtype and device of the first expert's
        loaded weight.

    Raises:
        AssertionError: If a loaded weight is not an ``UnquantizedWeight``,
            or if no expert was loaded (``n_experts`` of zero).
    """
    stacked = None

    for expert_idx in range(n_experts):
        expert_prefix = f"{prefix}.{expert_idx}"
        names = [
            f"{expert_prefix}.{gate_proj_name}",
            f"{expert_prefix}.{up_proj_name}",
        ]
        loaded = weights.get_multi_weights_col(names, 0)
        assert isinstance(loaded, UnquantizedWeight)
        expert_tensor = loaded.weight

        # The output buffer is allocated lazily because its per-expert shape,
        # dtype and device are only known once the first expert is loaded.
        if stacked is None:
            stacked = torch.empty(
                (n_experts, *expert_tensor.shape),
                dtype=expert_tensor.dtype,
                device=expert_tensor.device,
            )

        stacked[expert_idx] = expert_tensor

    assert stacked is not None
    return stacked
def _load_expert_weights_row(
    *,
    prefix: str,
    n_experts: int,
    name: str,
    weights: Weights,
) -> torch.Tensor:
    """Load one named projection from every expert and stack the results.

    For each expert index ``i`` in ``range(n_experts)`` the tensor named
    ``f"{prefix}.{i}.{name}"`` is fetched through ``weights.get_weights_row``
    and copied into slot ``i`` of one tensor whose leading dimension is
    ``n_experts``.

    Args:
        prefix: Weight-name prefix shared by all experts.
        n_experts: Number of experts to load.
        name: Name of the projection under each expert.
        weights: Loader used to fetch the tensors.

    Returns:
        The stacked tensor, with the dtype and device of the first expert's
        loaded weight.

    Raises:
        AssertionError: If a loaded weight is not an ``UnquantizedWeight``,
            or if no expert was loaded (``n_experts`` of zero).
    """
    stacked = None

    for expert_idx in range(n_experts):
        loaded = weights.get_weights_row(f"{prefix}.{expert_idx}.{name}")
        assert isinstance(loaded, UnquantizedWeight)
        expert_tensor = loaded.weight

        # Allocate on the first iteration, once the per-expert shape, dtype
        # and device are known.
        if stacked is None:
            stacked = torch.empty(
                (n_experts, *expert_tensor.shape),
                dtype=expert_tensor.dtype,
                device=expert_tensor.device,
            )

        stacked[expert_idx] = expert_tensor

    assert stacked is not None
    return stacked