Add `DenseMoELayer` and wire it up in Mixtral/Deepseek V2 (#2537)
This replaces the custom layers in both models.
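In both models the per-layer wiring now reduces to picking an `MoELayer` implementation once and handing that class to a shared wrapper module. Condensed from the Mixtral changes further down (not a verbatim excerpt):

```python
# Use the fused sparse implementation when the quantizer supports it,
# otherwise fall back to the new (slower) dense implementation.
moe_layer_cls = (
    SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer
)
self.moe = MixtralMoE(f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights)
```

Deepseek V2 follows the same pattern with `DeepseekV2MoE` as the wrapper.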
This commit is contained in:
parent c29dc89c18
commit 3f14cd1420
@@ -1,15 +1,178 @@
from typing import Optional
from typing import Optional, Protocol, runtime_checkable

import torch
import torch.nn as nn
from loguru import logger
from transformers.activations import ACT2FN

from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
    DefaultWeightsLoader,
    UnquantizedWeight,
    Weights,
)

if SYSTEM != "ipex":
    from moe_kernels.fused_moe import fused_topk, grouped_topk


# NOTE: we are using a protocol here, because multiple inheritance is not nice.
# We need `Module`, and `Module` -> some abstract class -> some concrete
# class inheritance is whacky.


@runtime_checkable
class MoELayer(Protocol):
    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
        hidden_act: str = "silu",
    ): ...

    def forward(
        self, x: torch.Tensor, *, gating_output: torch.Tensor
    ) -> torch.Tensor: ...


class DenseMoELayer(nn.Module):
    """
    Layer for MoE that applies *all* experts to each token and then weights
    their outputs based on the calculated routing. This layer is much slower
    than `SparseMoELayer` and should only be used when no fused kernels are
    available (e.g. for unsupported quantizers).
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
        hidden_act: str = "silu",
    ):
        super().__init__()

        log_once(
            logger.info,
            "No fused layers are available for this model type, using (slower) dense MoE layer",
        )

        assert (n_expert_group is None) == (
            topk_group is None
        ), "n_expert_group and topk_group must both be None or have some value"

        self.n_expert_group = n_expert_group
        self.n_experts = n_experts
        self.renormalize = renormalize
        self.topk = topk
        self.topk_group = topk_group

        if "gelu" in hidden_act:
            self.act = lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh"
                    if hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
                    else "none"
                ),
            )
        elif "silu" in hidden_act:
            self.act = torch.nn.functional.silu
        else:
            self.act = ACT2FN[hidden_act]

        self.gate_proj = [
            TensorParallelColumnLinear.load(
                None,
                prefix=f"{prefix}.{i}.{gate_proj_name}",
                weights=weights,
                bias=False,
            )
            for i in range(self.n_experts)
        ]
        self.up_proj = [
            TensorParallelColumnLinear.load(
                None,
                prefix=f"{prefix}.{i}.{up_proj_name}",
                weights=weights,
                bias=False,
            )
            for i in range(self.n_experts)
        ]
        self.down_proj = [
            TensorParallelRowLinear.load(
                None,
                prefix=f"{prefix}.{i}.{down_proj_name}",
                weights=weights,
                bias=False,
            )
            for i in range(self.n_experts)
        ]

        self.process_group = weights.process_group

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        """
        x: (sequence_length, model_dim)
        gating_output: (sequence_length, n_experts)
        """
        # optional reshape
        input_shape = x.shape
        x = x.view(-1, input_shape[-1])

        if self.n_expert_group is not None and self.topk_group is not None:
            topk_weights, topk_ids = grouped_topk(
                x,
                gating_output,
                self.topk,
                renormalize=self.renormalize,
                num_expert_group=self.n_expert_group,
                topk_group=self.topk_group,
            )
        else:
            topk_weights, topk_ids = fused_topk(
                x, gating_output, self.topk, self.renormalize
            )
        topk_weights = topk_weights.to(x.dtype)

        weights = torch.zeros(
            topk_ids.shape[0], self.n_experts, dtype=x.dtype, device=x.device
        )

        weights.scatter_(1, topk_ids.long(), topk_weights.to(weights.dtype))

        out = torch.zeros_like(x)
        for i in range(self.n_experts):
            h = self.act(self.gate_proj[i](x)) * self.up_proj[i](x)
            h = self.down_proj[i](h, reduce=False)
            out += h * weights[:, i].view(-1, 1)

        return out


class SparseMoELayer(nn.Module):
    """
@@ -13,10 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Type

import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig

from text_generation_server.layers import (
    FastLinear,
    SpeculativeHead,
@@ -26,22 +30,16 @@ from text_generation_server.layers import (
    get_linear,
)
from text_generation_server.layers.attention import (
    Seqlen,
    attention,
    paged_attention,
    reshape_and_cache,
    Seqlen,
)
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import SparseMoELayer
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import Weights
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig

if SYSTEM != "ipex":
    from moe_kernels.fused_moe import grouped_topk

if SYSTEM == "rocm":
    try:
@@ -410,8 +408,14 @@ class DeepseekV2MLP(nn.Module):
        )


class BlockSparseMoE(nn.Module):
    def __init__(self, prefix, config: DeepseekV2Config, weights):
class DeepseekV2MoE(nn.Module):
    def __init__(
        self,
        prefix,
        config: DeepseekV2Config,
        moe_layer_cls: Type[MoELayer],
        weights,
    ):
        super().__init__()

        self.hidden_dim = config.hidden_size
@@ -423,7 +427,7 @@ class BlockSparseMoE(nn.Module):
        # Gating
        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)

        self.moe_layer = SparseMoELayer(
        self.moe_layer = moe_layer_cls(
            prefix=f"{prefix}.experts",
            n_experts=config.n_routed_experts,
            n_expert_group=config.n_group,
@@ -432,6 +436,7 @@ class BlockSparseMoE(nn.Module):
            topk_group=config.topk_group,
            weights=weights,
        )
        assert isinstance(self.moe_layer, MoELayer)

        if config.n_shared_experts is not None:
            self.shared_experts = DeepseekV2MLP(
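The `assert isinstance(self.moe_layer, MoELayer)` works because `MoELayer` is declared as a `runtime_checkable` `Protocol`: the `isinstance` check is purely structural (it only verifies that the named methods exist, not their signatures), so `SparseMoELayer` and `DenseMoELayer` satisfy it without inheriting from it. A minimal standalone illustration with made-up names, unrelated to this repository:

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class Greeter(Protocol):
    def greet(self, name: str) -> str: ...


class EnglishGreeter:  # no Greeter base class anywhere
    def greet(self, name: str) -> str:
        return f"hello {name}"


# Structural check: only the presence of `greet` is verified, not its signature.
assert isinstance(EnglishGreeter(), Greeter)
```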
@@ -466,96 +471,6 @@ class BlockSparseMoE(nn.Module):
        return out.view(*x.shape)


class DenseMoE(nn.Module):
    def __init__(self, prefix: str, config: DeepseekV2Config, weights: Weights):
        super().__init__()

        self.hidden_dim = config.hidden_size
        self.moe_intermediate_size = config.moe_intermediate_size
        self.n_routed_experts = config.n_routed_experts
        self.n_expert_group = config.n_group
        self.topk_group = config.topk_group
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob
        self.routed_scaling_factor = config.routed_scaling_factor

        # Gating
        #
        # Seems like no one quantizes the gate.
        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)

        self.experts = [
            DeepseekV2MLP(
                f"{prefix}.experts.{i}", config, weights, self.moe_intermediate_size
            )
            for i in range(self.n_routed_experts)
        ]

        if config.n_shared_experts is not None:
            self.shared_experts = DeepseekV2MLP(
                prefix=f"{prefix}.shared_experts",
                config=config,
                weights=weights,
                intermediate_size=config.moe_intermediate_size
                * config.n_shared_experts,
            )
        else:
            self.shared_experts = None

        self.process_group = weights.process_group

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (sequence_length, model_dim)
        gate_logits: (sequence_length, n_experts)
        """
        # optional reshape
        input_shape = x.shape
        x = x.view(-1, input_shape[-1])

        if self.shared_experts is not None:
            shared_output = self.shared_experts(x, reduce=False)
        else:
            shared_output = None

        # gate_logits: (sequence_length, n_experts)
        router_logits = self.gate(x)

        topk_weights, topk_ids = grouped_topk(
            x,
            router_logits,
            self.top_k,
            renormalize=self.norm_topk_prob,
            num_expert_group=self.n_expert_group,
            topk_group=self.topk_group,
        )

        out = self.moe_infer_gpu(x, topk_ids, topk_weights) * self.routed_scaling_factor

        if shared_output is not None:
            out = out + shared_output

        # Reduce sum
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)

        return out

    def moe_infer_gpu(
        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
    ):
        weights = torch.zeros(
            topk_ids.shape[0], len(self.experts), dtype=x.dtype, device=x.device
        )
        weights.scatter_(1, topk_ids, topk_weight)

        out = x.new_zeros(x.shape[0], self.hidden_dim)
        for i, expert in enumerate(self.experts):
            # Add expert output to out with masking
            out += expert(x, reduce=False) * weights[:, i].view(-1, 1)
        return out


class DeepseekV2Layer(nn.Module):
    def __init__(self, prefix, layer_id, config, weights):
        super().__init__()
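Both the removed `DenseMoE` above and the new `DenseMoELayer` route DeepseekV2 tokens with `grouped_topk`, i.e. group-limited routing controlled by `n_expert_group`/`topk_group`: the best expert groups are selected first, and the top-k experts are then taken only from those groups. A simplified plain-PyTorch sketch of that selection (the real kernel comes from `moe_kernels`/vLLM and may differ in details such as renormalization):

```python
import torch

tokens, n_experts, n_expert_group, topk_group, topk = 4, 16, 4, 2, 6
gating_output = torch.randn(tokens, n_experts)

scores = torch.softmax(gating_output, dim=-1)

# Score each group of experts by its best expert and keep `topk_group` groups.
group_scores = scores.view(tokens, n_expert_group, -1).max(dim=-1).values
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False).indices
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
score_mask = (
    group_mask.unsqueeze(-1)
    .expand(tokens, n_expert_group, n_experts // n_expert_group)
    .reshape(tokens, n_experts)
)

# The top-k experts are then chosen only inside the selected groups.
masked_scores = scores.masked_fill(score_mask == 0, 0.0)
topk_weights, topk_ids = torch.topk(masked_scores, k=topk, dim=-1, sorted=False)
```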
@@ -572,10 +487,12 @@ class DeepseekV2Layer(nn.Module):
            and layer_id >= config.first_k_dense_replace
            and layer_id % config.moe_layer_freq == 0
        ):
            moe_cls = (
                BlockSparseMoE if SparseMoELayer.is_supported(weights) else DenseMoE
            moe_layer_cls = (
                SparseMoELayer
                if SparseMoELayer.is_supported(weights)
                else DenseMoELayer
            )
            self.mlp = moe_cls(f"{prefix}.mlp", config, weights)
            self.mlp = DeepseekV2MoE(f"{prefix}.mlp", config, moe_layer_cls, weights)
        else:
            self.mlp = DeepseekV2MLP(
                prefix=f"{prefix}.mlp",
@@ -18,38 +18,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple, Type

import torch
import torch.distributed


from torch import nn
from text_generation_server.utils.import_utils import SYSTEM

from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    reshape_and_cache,
    Seqlen,
)
from text_generation_server.layers import (
    FastLinear,
    TensorParallelRowLinear,
    SpeculativeHead,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    SpeculativeHead,
    TensorParallelRowLinear,
    get_linear,
)
from text_generation_server.layers.moe import SparseMoELayer
from text_generation_server.layers.layernorm import (
    FastRMSNorm,
)
from text_generation_server.layers.rotary import (
    PositionRotaryEmbedding,
from text_generation_server.layers.attention import (
    Seqlen,
    attention,
    paged_attention,
    reshape_and_cache,
)
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import UnquantizedWeight
@@ -315,14 +308,16 @@ def round_up(x: torch.Tensor, value: int):
    return torch.div(x + (value - 1), value, rounding_mode="trunc") * value


class BlockSparseMoE(nn.Module):
    def __init__(self, prefix, config: MixtralConfig, weights):
class MixtralMoE(nn.Module):
    def __init__(
        self, prefix, config: MixtralConfig, moe_layer_cls: Type[MoELayer], weights
    ):
        super().__init__()

        # gating
        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)

        self.moe = SparseMoELayer(
        self.moe = moe_layer_cls(
            n_expert_group=None,
            n_experts=config.num_local_experts,
            prefix=f"{prefix}.experts",
@@ -334,6 +329,7 @@ class BlockSparseMoE(nn.Module):
            up_proj_name="w3",
            down_proj_name="w2",
        )
        assert isinstance(self.moe, MoELayer)

        self.process_group = weights.process_group
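The wrapper's `forward` is not shown in these hunks; whichever implementation was injected is driven through the protocol's keyword-only `gating_output` argument, fed with the gate's logits. A self-contained sketch of that calling convention, where `DummyMoELayer` is a stand-in for `SparseMoELayer`/`DenseMoELayer` (not a class from this repository):

```python
import torch
import torch.nn as nn


class DummyMoELayer(nn.Module):
    # Same call signature as the MoELayer protocol defined in this PR.
    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        # Toy behaviour: scale each token by its best router probability.
        best = gating_output.softmax(dim=-1).max(dim=-1, keepdim=True).values
        return x * best


gate = nn.Linear(8, 4, bias=False)  # hidden_size -> n_experts
moe = DummyMoELayer()
x = torch.randn(3, 8)
out = moe(x, gating_output=gate(x))
```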
@@ -349,95 +345,6 @@ class BlockSparseMoE(nn.Module):
        return out.view(*x.shape)


class DenseMoE(nn.Module):
    def __init__(self, prefix, config: MixtralConfig, weights):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size // weights.process_group.size()
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok

        act = config.hidden_act
        if "gelu" in act:
            self.act = lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                ),
            )
        elif "silu" in act:
            self.act = torch.nn.functional.silu
        else:
            self.act = ACT2FN[act]

        # gating
        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)

        self.w1 = [
            TensorParallelColumnLinear.load(
                config, prefix=f"{prefix}.experts.{i}.w1", weights=weights, bias=False
            )
            for i in range(self.num_experts)
        ]
        self.w3 = [
            TensorParallelColumnLinear.load(
                config, prefix=f"{prefix}.experts.{i}.w3", weights=weights, bias=False
            )
            for i in range(self.num_experts)
        ]
        self.w2 = [
            TensorParallelRowLinear.load(
                config, prefix=f"{prefix}.experts.{i}.w2", weights=weights, bias=False
            )
            for i in range(self.num_experts)
        ]

        self.process_group = weights.process_group

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (sequence_length, model_dim)
        gate_logits: (sequence_length, n_experts)
        """
        # optional reshape
        input_shape = x.shape
        x = x.view(-1, input_shape[-1])

        # gate_logits: (sequence_length, n_experts)
        gate_logits = self.gate(x)
        # all_probs: (sequence_length, n_experts) and upcast for softmax
        all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)

        if self.top_k < self.num_experts:
            _, not_selected_experts = torch.topk(
                all_probs,
                self.num_experts - self.top_k,
                largest=False,
                sorted=False,
                dim=1,
            )
            # Mask not selected experts
            all_probs.scatter_(1, not_selected_experts, 0)

        # Re-normalize
        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
        weights = weights.to(x.dtype)

        # Final output tensor
        out = x.new_zeros(x.shape[0], self.hidden_dim)
        for i in range(self.num_experts):
            h = self.act(self.w1[i](x)) * self.w3[i](x)
            h = self.w2[i](h, reduce=False)
            # Add expert output to out with masking
            out += h * weights[:, i].view(-1, 1)

        # Reduce sum
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)

        return out


class MixtralLayer(nn.Module):
    def __init__(self, prefix: str, layer_id, config, weights):
        super().__init__()
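For comparison, the removed `DenseMoE` above routes by softmaxing over all experts, zeroing everything outside the top-k, and renormalizing, which produces essentially the same per-token expert weights that the new layer derives from `fused_topk`. A toy run of that recipe:

```python
import torch

top_k, num_experts = 2, 4
gate_logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
all_probs = torch.softmax(gate_logits, dim=1, dtype=torch.float)

# Zero out the experts that are *not* in the top-k, then renormalize.
_, not_selected = torch.topk(all_probs, num_experts - top_k, largest=False, dim=1)
all_probs.scatter_(1, not_selected, 0)
weights = all_probs / all_probs.sum(dim=1, keepdim=True)
# `weights` now has non-zero routing weights only for the two best experts.
```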
@@ -447,8 +354,12 @@ class MixtralLayer(nn.Module):
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )

        moe_cls = BlockSparseMoE if SparseMoELayer.is_supported(weights) else DenseMoE
        self.moe = moe_cls(f"{prefix}.block_sparse_moe", config, weights)
        moe_layer_cls = (
            SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer
        )
        self.moe = MixtralMoE(
            f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
        )

        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps