Add `DenseMoELayer` and wire it up in Mixtral/Deepseek V2 (#2537)
This replaces the custom layers in both models.
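In short, both models now choose an MoE implementation at load time and pass it into a single MoE block: the fused `SparseMoELayer` when the weights loader supports it, otherwise the new (slower) `DenseMoELayer`, which runs every expert and weights the outputs by the routing probabilities. A minimal sketch of that selection pattern, using Mixtral-flavoured names and assuming `config`, `weights`, `prefix`, a `gate` linear and `hidden_states` are already in scope (the `renormalize` value is illustrative only):

    from text_generation_server.layers.moe import DenseMoELayer, SparseMoELayer

    # Prefer the fused sparse kernels; fall back to the dense layer when the
    # quantizer/loader has no fused support.
    moe_layer_cls = (
        SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer
    )
    moe = moe_layer_cls(
        prefix=f"{prefix}.experts",
        n_experts=config.num_local_experts,
        n_expert_group=None,
        renormalize=True,  # assumed value, for illustration only
        topk=config.num_experts_per_tok,
        topk_group=None,
        weights=weights,
        gate_proj_name="w1",
        up_proj_name="w3",
        down_proj_name="w2",
    )
    # Both implementations satisfy the MoELayer protocol below, so the caller
    # only supplies the gate logits and lets the layer do the routing.
    out = moe(hidden_states, gating_output=gate(hidden_states))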
This commit is contained in:
parent c29dc89c18, commit 3f14cd1420

server/text_generation_server/layers/moe/__init__.py

@@ -1,15 +1,178 @@
-from typing import Optional
+from typing import Optional, Protocol, runtime_checkable
 
 import torch
 import torch.nn as nn
+from loguru import logger
+from transformers.activations import ACT2FN
 
+from text_generation_server.layers import (
+    TensorParallelColumnLinear,
+    TensorParallelRowLinear,
+)
 from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
 from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
+from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import (
     DefaultWeightsLoader,
     UnquantizedWeight,
     Weights,
 )
 
+if SYSTEM != "ipex":
+    from moe_kernels.fused_moe import fused_topk, grouped_topk
+
+
+# NOTE: we are using a protocol here, because multiple inheritance is not nice.
+# We need `Module`, and `Module` -> some abstract class -> some concrete
+# class inheritance is whacky.
+
+
+@runtime_checkable
+class MoELayer(Protocol):
+    def __init__(
+        self,
+        *,
+        n_expert_group: Optional[int],
+        n_experts: int,
+        prefix: str,
+        renormalize: bool,
+        topk: int,
+        topk_group: Optional[int],
+        weights: Weights,
+        gate_proj_name: str = "gate_proj",
+        up_proj_name: str = "up_proj",
+        down_proj_name: str = "down_proj",
+        hidden_act: str = "silu",
+    ): ...
+
+    def forward(
+        self, x: torch.Tensor, *, gating_output: torch.Tensor
+    ) -> torch.Tensor: ...
+
+
+class DenseMoELayer(nn.Module):
+    """
+    Layer for MoE that applies *all* experts to each token and then weights
+    their outputs based on the calculated routing. This layer is much slower
+    than `SparseMoELayer` and should only be used when no fused kernels are
+    available (e.g. for unsupported quantizers).
+    """
+
+    def __init__(
+        self,
+        *,
+        n_expert_group: Optional[int],
+        n_experts: int,
+        prefix: str,
+        renormalize: bool,
+        topk: int,
+        topk_group: Optional[int],
+        weights: Weights,
+        gate_proj_name: str = "gate_proj",
+        up_proj_name: str = "up_proj",
+        down_proj_name: str = "down_proj",
+        hidden_act: str = "silu",
+    ):
+        super().__init__()
+
+        log_once(
+            logger.info,
+            "No fused layers are available for this model type, using (slower) dense MoE layer",
+        )
+
+        assert (n_expert_group is None) == (
+            topk_group is None
+        ), "n_expert_group and topk_group must both be None or have some value"
+
+        self.n_expert_group = n_expert_group
+        self.n_experts = n_experts
+        self.renormalize = renormalize
+        self.topk = topk
+        self.topk_group = topk_group
+
+        if "gelu" in hidden_act:
+            self.act = lambda x: torch.nn.functional.gelu(
+                x,
+                approximate=(
+                    "tanh"
+                    if hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
+                    else "none"
+                ),
+            )
+        elif "silu" in hidden_act:
+            self.act = torch.nn.functional.silu
+        else:
+            self.act = ACT2FN[hidden_act]
+
+        self.gate_proj = [
+            TensorParallelColumnLinear.load(
+                None,
+                prefix=f"{prefix}.{i}.{gate_proj_name}",
+                weights=weights,
+                bias=False,
+            )
+            for i in range(self.n_experts)
+        ]
+        self.up_proj = [
+            TensorParallelColumnLinear.load(
+                None,
+                prefix=f"{prefix}.{i}.{up_proj_name}",
+                weights=weights,
+                bias=False,
+            )
+            for i in range(self.n_experts)
+        ]
+        self.down_proj = [
+            TensorParallelRowLinear.load(
+                None,
+                prefix=f"{prefix}.{i}.{down_proj_name}",
+                weights=weights,
+                bias=False,
+            )
+            for i in range(self.n_experts)
+        ]
+
+        self.process_group = weights.process_group
+
+    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
+        """
+        x: (sequence_length, model_dim)
+        gating_output: (sequence_length, n_experts)
+        """
+        # optional reshape
+        input_shape = x.shape
+        x = x.view(-1, input_shape[-1])
+
+        if self.n_expert_group is not None and self.topk_group is not None:
+            topk_weights, topk_ids = grouped_topk(
+                x,
+                gating_output,
+                self.topk,
+                renormalize=self.renormalize,
+                num_expert_group=self.n_expert_group,
+                topk_group=self.topk_group,
+            )
+        else:
+            topk_weights, topk_ids = fused_topk(
+                x, gating_output, self.topk, self.renormalize
+            )
+        topk_weights = topk_weights.to(x.dtype)
+
+        weights = torch.zeros(
+            topk_ids.shape[0], self.n_experts, dtype=x.dtype, device=x.device
+        )
+
+        weights.scatter_(1, topk_ids.long(), topk_weights.to(weights.dtype))
+
+        out = torch.zeros_like(x)
+        for i in range(self.n_experts):
+            h = self.act(self.gate_proj[i](x)) * self.up_proj[i](x)
+            h = self.down_proj[i](h, reduce=False)
+            out += h * weights[:, i].view(-1, 1)
+
+        return out
+
 
 class SparseMoELayer(nn.Module):
     """

server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py

@@ -13,10 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Type
 
 import torch
 import torch.distributed
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
 
 from text_generation_server.layers import (
     FastLinear,
     SpeculativeHead,
@@ -26,22 +30,16 @@ from text_generation_server.layers import (
     get_linear,
 )
 from text_generation_server.layers.attention import (
+    Seqlen,
     attention,
     paged_attention,
     reshape_and_cache,
-    Seqlen,
 )
 from text_generation_server.layers.layernorm import FastRMSNorm
-from text_generation_server.layers.moe import SparseMoELayer
+from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import Weights
-from torch import nn
-from transformers.activations import ACT2FN
-from transformers.configuration_utils import PretrainedConfig
-
-if SYSTEM != "ipex":
-    from moe_kernels.fused_moe import grouped_topk
 
 if SYSTEM == "rocm":
     try:
@@ -410,8 +408,14 @@ class DeepseekV2MLP(nn.Module):
         )
 
 
-class BlockSparseMoE(nn.Module):
-    def __init__(self, prefix, config: DeepseekV2Config, weights):
+class DeepseekV2MoE(nn.Module):
+    def __init__(
+        self,
+        prefix,
+        config: DeepseekV2Config,
+        moe_layer_cls: Type[MoELayer],
+        weights,
+    ):
         super().__init__()
 
         self.hidden_dim = config.hidden_size
@@ -423,7 +427,7 @@ class BlockSparseMoE(nn.Module):
         # Gating
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
 
-        self.moe_layer = SparseMoELayer(
+        self.moe_layer = moe_layer_cls(
             prefix=f"{prefix}.experts",
             n_experts=config.n_routed_experts,
             n_expert_group=config.n_group,
@@ -432,6 +436,7 @@ class BlockSparseMoE(nn.Module):
             topk_group=config.topk_group,
             weights=weights,
         )
+        assert isinstance(self.moe_layer, MoELayer)
 
         if config.n_shared_experts is not None:
             self.shared_experts = DeepseekV2MLP(
@@ -466,96 +471,6 @@ class BlockSparseMoE(nn.Module):
         return out.view(*x.shape)
 
 
-class DenseMoE(nn.Module):
-    def __init__(self, prefix: str, config: DeepseekV2Config, weights: Weights):
-        super().__init__()
-
-        self.hidden_dim = config.hidden_size
-        self.moe_intermediate_size = config.moe_intermediate_size
-        self.n_routed_experts = config.n_routed_experts
-        self.n_expert_group = config.n_group
-        self.topk_group = config.topk_group
-        self.top_k = config.num_experts_per_tok
-        self.norm_topk_prob = config.norm_topk_prob
-        self.routed_scaling_factor = config.routed_scaling_factor
-
-        # Gating
-        #
-        # Seems like no one quantizes the gate.
-        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
-
-        self.experts = [
-            DeepseekV2MLP(
-                f"{prefix}.experts.{i}", config, weights, self.moe_intermediate_size
-            )
-            for i in range(self.n_routed_experts)
-        ]
-
-        if config.n_shared_experts is not None:
-            self.shared_experts = DeepseekV2MLP(
-                prefix=f"{prefix}.shared_experts",
-                config=config,
-                weights=weights,
-                intermediate_size=config.moe_intermediate_size
-                * config.n_shared_experts,
-            )
-        else:
-            self.shared_experts = None
-
-        self.process_group = weights.process_group
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        x: (sequence_length, model_dim)
-        gate_logits: (sequence_length, n_experts)
-        """
-        # optional reshape
-        input_shape = x.shape
-        x = x.view(-1, input_shape[-1])
-
-        if self.shared_experts is not None:
-            shared_output = self.shared_experts(x, reduce=False)
-        else:
-            shared_output = None
-
-        # gate_logits: (sequence_length, n_experts)
-        router_logits = self.gate(x)
-
-        topk_weights, topk_ids = grouped_topk(
-            x,
-            router_logits,
-            self.top_k,
-            renormalize=self.norm_topk_prob,
-            num_expert_group=self.n_expert_group,
-            topk_group=self.topk_group,
-        )
-
-        out = self.moe_infer_gpu(x, topk_ids, topk_weights) * self.routed_scaling_factor
-
-        if shared_output is not None:
-            out = out + shared_output
-
-        # Reduce sum
-        if self.process_group.size() > 1:
-            torch.distributed.all_reduce(out, group=self.process_group)
-
-        return out
-
-    def moe_infer_gpu(
-        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
-    ):
-        weights = torch.zeros(
-            topk_ids.shape[0], len(self.experts), dtype=x.dtype, device=x.device
-        )
-        weights.scatter_(1, topk_ids, topk_weight)
-
-        out = x.new_zeros(x.shape[0], self.hidden_dim)
-        for i, expert in enumerate(self.experts):
-            # Add expert output to out with masking
-            out += expert(x, reduce=False) * weights[:, i].view(-1, 1)
-        return out
-
-
 class DeepseekV2Layer(nn.Module):
     def __init__(self, prefix, layer_id, config, weights):
         super().__init__()
@@ -572,10 +487,12 @@ class DeepseekV2Layer(nn.Module):
             and layer_id >= config.first_k_dense_replace
             and layer_id % config.moe_layer_freq == 0
         ):
-            moe_cls = (
-                BlockSparseMoE if SparseMoELayer.is_supported(weights) else DenseMoE
+            moe_layer_cls = (
+                SparseMoELayer
+                if SparseMoELayer.is_supported(weights)
+                else DenseMoELayer
             )
-            self.mlp = moe_cls(f"{prefix}.mlp", config, weights)
+            self.mlp = DeepseekV2MoE(f"{prefix}.mlp", config, moe_layer_cls, weights)
         else:
             self.mlp = DeepseekV2MLP(
                 prefix=f"{prefix}.mlp",

server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py

@@ -18,38 +18,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import List, Optional, Tuple, Type
+
 import torch
 import torch.distributed
 
 from torch import nn
-from text_generation_server.utils.import_utils import SYSTEM
-
-from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
-from typing import Optional, List, Tuple
-
-from text_generation_server.layers.attention import (
-    paged_attention,
-    attention,
-    reshape_and_cache,
-    Seqlen,
-)
 from text_generation_server.layers import (
     FastLinear,
-    TensorParallelRowLinear,
+    SpeculativeHead,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    SpeculativeHead,
+    TensorParallelRowLinear,
     get_linear,
 )
-from text_generation_server.layers.moe import SparseMoELayer
-from text_generation_server.layers.layernorm import (
-    FastRMSNorm,
-)
-from text_generation_server.layers.rotary import (
-    PositionRotaryEmbedding,
-)
+from text_generation_server.layers.attention import (
+    Seqlen,
+    attention,
+    paged_attention,
+    reshape_and_cache,
+)
+from text_generation_server.layers.layernorm import FastRMSNorm
+from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import UnquantizedWeight
 
 
@@ -315,14 +308,16 @@ def round_up(x: torch.Tensor, value: int):
     return torch.div(x + (value - 1), value, rounding_mode="trunc") * value
 
 
-class BlockSparseMoE(nn.Module):
-    def __init__(self, prefix, config: MixtralConfig, weights):
+class MixtralMoE(nn.Module):
+    def __init__(
+        self, prefix, config: MixtralConfig, moe_layer_cls: Type[MoELayer], weights
+    ):
         super().__init__()
 
         # gating
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
 
-        self.moe = SparseMoELayer(
+        self.moe = moe_layer_cls(
             n_expert_group=None,
             n_experts=config.num_local_experts,
             prefix=f"{prefix}.experts",
@@ -334,6 +329,7 @@ class BlockSparseMoE(nn.Module):
             up_proj_name="w3",
             down_proj_name="w2",
         )
+        assert isinstance(self.moe, MoELayer)
 
         self.process_group = weights.process_group
 
@@ -349,95 +345,6 @@ class BlockSparseMoE(nn.Module):
         return out.view(*x.shape)
 
 
-class DenseMoE(nn.Module):
-    def __init__(self, prefix, config: MixtralConfig, weights):
-        super().__init__()
-        self.hidden_dim = config.hidden_size
-        self.ffn_dim = config.intermediate_size // weights.process_group.size()
-        self.num_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
-
-        act = config.hidden_act
-        if "gelu" in act:
-            self.act = lambda x: torch.nn.functional.gelu(
-                x,
-                approximate=(
-                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
-                ),
-            )
-        elif "silu" in act:
-            self.act = torch.nn.functional.silu
-        else:
-            self.act = ACT2FN[act]
-
-        # gating
-        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
-
-        self.w1 = [
-            TensorParallelColumnLinear.load(
-                config, prefix=f"{prefix}.experts.{i}.w1", weights=weights, bias=False
-            )
-            for i in range(self.num_experts)
-        ]
-        self.w3 = [
-            TensorParallelColumnLinear.load(
-                config, prefix=f"{prefix}.experts.{i}.w3", weights=weights, bias=False
-            )
-            for i in range(self.num_experts)
-        ]
-        self.w2 = [
-            TensorParallelRowLinear.load(
-                config, prefix=f"{prefix}.experts.{i}.w2", weights=weights, bias=False
-            )
-            for i in range(self.num_experts)
-        ]
-
-        self.process_group = weights.process_group
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        x: (sequence_length, model_dim)
-        gate_logits: (sequence_length, n_experts)
-        """
-        # optional reshape
-        input_shape = x.shape
-        x = x.view(-1, input_shape[-1])
-
-        # gate_logits: (sequence_length, n_experts)
-        gate_logits = self.gate(x)
-        # all_probs: (sequence_length, n_experts) and upcast for softmax
-        all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
-
-        if self.top_k < self.num_experts:
-            _, not_selected_experts = torch.topk(
-                all_probs,
-                self.num_experts - self.top_k,
-                largest=False,
-                sorted=False,
-                dim=1,
-            )
-            # Mask not selected experts
-            all_probs.scatter_(1, not_selected_experts, 0)
-
-        # Re-normalize
-        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
-        weights = weights.to(x.dtype)
-
-        # Final output tensor
-        out = x.new_zeros(x.shape[0], self.hidden_dim)
-        for i in range(self.num_experts):
-            h = self.act(self.w1[i](x)) * self.w3[i](x)
-            h = self.w2[i](h, reduce=False)
-            # Add expert output to out with masking
-            out += h * weights[:, i].view(-1, 1)
-
-        # Reduce sum
-        if self.process_group.size() > 1:
-            torch.distributed.all_reduce(out, group=self.process_group)
-
-        return out
-
-
 class MixtralLayer(nn.Module):
     def __init__(self, prefix: str, layer_id, config, weights):
         super().__init__()
@@ -447,8 +354,12 @@ class MixtralLayer(nn.Module):
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
 
-        moe_cls = BlockSparseMoE if SparseMoELayer.is_supported(weights) else DenseMoE
-        self.moe = moe_cls(f"{prefix}.block_sparse_moe", config, weights)
+        moe_layer_cls = (
+            SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer
+        )
+        self.moe = MixtralMoE(
+            f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
+        )
 
         self.input_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps