Add `DenseMoELayer` and wire it up in Mixtral/Deepseek V2 (#2537)

This replaces the custom layers in both models.
Daniël de Kok 2024-09-24 14:27:06 +02:00 committed by GitHub
parent c29dc89c18
commit 3f14cd1420
3 changed files with 211 additions and 220 deletions
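
For orientation, the pattern this change introduces is sketched below: the per-model MoE modules (`MixtralMoE` and `DeepseekV2MoE` in the diffs that follow) receive an `MoELayer` implementation as a constructor argument, picking `SparseMoELayer` when the fused kernels can handle the loaded weights and falling back to the new `DenseMoELayer` otherwise. This is an illustrative sketch, not code from the commit; the `build_moe_layer` helper and the specific keyword values are assumptions for the example.

```python
from typing import Type

from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer


def build_moe_layer(prefix: str, config, weights) -> MoELayer:
    # Use the sparse (fused-kernel) implementation when it supports the
    # weights' quantization; otherwise fall back to the slower dense layer.
    moe_layer_cls: Type[MoELayer] = (
        SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer
    )
    return moe_layer_cls(
        prefix=f"{prefix}.experts",
        n_experts=config.num_local_experts,
        n_expert_group=None,
        renormalize=True,  # assumed value for the sketch
        topk=config.num_experts_per_tok,
        topk_group=None,
        weights=weights,
    )
```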


@@ -1,15 +1,178 @@
-from typing import Optional
+from typing import Optional, Protocol, runtime_checkable

 import torch
 import torch.nn as nn
+from loguru import logger
+from transformers.activations import ACT2FN

+from text_generation_server.layers import (
+    TensorParallelColumnLinear,
+    TensorParallelRowLinear,
+)
 from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
 from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
+from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import (
     DefaultWeightsLoader,
     UnquantizedWeight,
     Weights,
 )

+if SYSTEM != "ipex":
+    from moe_kernels.fused_moe import fused_topk, grouped_topk
+
+
+# NOTE: we are using a protocol here, because multiple inheritance is not nice.
+# We need `Module`, and `Module` -> some abstract class -> some concrete
+# class inheritance is whacky.
+
+
+@runtime_checkable
+class MoELayer(Protocol):
+    def __init__(
+        self,
+        *,
+        n_expert_group: Optional[int],
+        n_experts: int,
+        prefix: str,
+        renormalize: bool,
+        topk: int,
+        topk_group: Optional[int],
+        weights: Weights,
+        gate_proj_name: str = "gate_proj",
+        up_proj_name: str = "up_proj",
+        down_proj_name: str = "down_proj",
+        hidden_act: str = "silu",
+    ): ...
+
+    def forward(
+        self, x: torch.Tensor, *, gating_output: torch.Tensor
+    ) -> torch.Tensor: ...
+
+
+class DenseMoELayer(nn.Module):
+    """
+    Layer for MoE that applies *all* experts to each token and then weights
+    their outputs based on the calculated routing. This layer is much slower
+    than `SparseMoELayer` and should only be used when no fused kernels are
+    available (e.g. for unsupported quantizers).
+    """
+
+    def __init__(
+        self,
+        *,
+        n_expert_group: Optional[int],
+        n_experts: int,
+        prefix: str,
+        renormalize: bool,
+        topk: int,
+        topk_group: Optional[int],
+        weights: Weights,
+        gate_proj_name: str = "gate_proj",
+        up_proj_name: str = "up_proj",
+        down_proj_name: str = "down_proj",
+        hidden_act: str = "silu",
+    ):
+        super().__init__()
+
+        log_once(
+            logger.info,
+            "No fused layers are available for this model type, using (slower) dense MoE layer",
+        )
+
+        assert (n_expert_group is None) == (
+            topk_group is None
+        ), "n_expert_group and topk_group must both be None or have some value"
+
+        self.n_expert_group = n_expert_group
+        self.n_experts = n_experts
+        self.renormalize = renormalize
+        self.topk = topk
+        self.topk_group = topk_group
+
+        if "gelu" in hidden_act:
+            self.act = lambda x: torch.nn.functional.gelu(
+                x,
+                approximate=(
+                    "tanh"
+                    if hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
+                    else "none"
+                ),
+            )
+        elif "silu" in hidden_act:
+            self.act = torch.nn.functional.silu
+        else:
+            self.act = ACT2FN[hidden_act]
+
+        self.gate_proj = [
+            TensorParallelColumnLinear.load(
+                None,
+                prefix=f"{prefix}.{i}.{gate_proj_name}",
+                weights=weights,
+                bias=False,
+            )
+            for i in range(self.n_experts)
+        ]
+        self.up_proj = [
+            TensorParallelColumnLinear.load(
+                None,
+                prefix=f"{prefix}.{i}.{up_proj_name}",
+                weights=weights,
+                bias=False,
+            )
+            for i in range(self.n_experts)
+        ]
+        self.down_proj = [
+            TensorParallelRowLinear.load(
+                None,
+                prefix=f"{prefix}.{i}.{down_proj_name}",
+                weights=weights,
+                bias=False,
+            )
+            for i in range(self.n_experts)
+        ]
+
+        self.process_group = weights.process_group
+
+    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
+        """
+        x: (sequence_length, model_dim)
+        gating_output: (sequence_length, n_experts)
+        """
+        # optional reshape
+        input_shape = x.shape
+        x = x.view(-1, input_shape[-1])
+
+        if self.n_expert_group is not None and self.topk_group is not None:
+            topk_weights, topk_ids = grouped_topk(
+                x,
+                gating_output,
+                self.topk,
+                renormalize=self.renormalize,
+                num_expert_group=self.n_expert_group,
+                topk_group=self.topk_group,
+            )
+        else:
+            topk_weights, topk_ids = fused_topk(
+                x, gating_output, self.topk, self.renormalize
+            )
+
+        topk_weights = topk_weights.to(x.dtype)
+        weights = torch.zeros(
+            topk_ids.shape[0], self.n_experts, dtype=x.dtype, device=x.device
+        )
+        weights.scatter_(1, topk_ids.long(), topk_weights.to(weights.dtype))
+
+        out = torch.zeros_like(x)
+        for i in range(self.n_experts):
+            h = self.act(self.gate_proj[i](x)) * self.up_proj[i](x)
+            h = self.down_proj[i](h, reduce=False)
+            out += h * weights[:, i].view(-1, 1)
+
+        return out
+
+
 class SparseMoELayer(nn.Module):
     """


@@ -13,10 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Type

 import torch
 import torch.distributed
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig

 from text_generation_server.layers import (
     FastLinear,
     SpeculativeHead,
@@ -26,22 +30,16 @@ from text_generation_server.layers import (
     get_linear,
 )
 from text_generation_server.layers.attention import (
+    Seqlen,
     attention,
     paged_attention,
     reshape_and_cache,
-    Seqlen,
 )
 from text_generation_server.layers.layernorm import FastRMSNorm
-from text_generation_server.layers.moe import SparseMoELayer
+from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import Weights
-from torch import nn
-from transformers.activations import ACT2FN
-from transformers.configuration_utils import PretrainedConfig
-
-if SYSTEM != "ipex":
-    from moe_kernels.fused_moe import grouped_topk

 if SYSTEM == "rocm":
     try:
@@ -410,8 +408,14 @@ class DeepseekV2MLP(nn.Module):
         )


-class BlockSparseMoE(nn.Module):
-    def __init__(self, prefix, config: DeepseekV2Config, weights):
+class DeepseekV2MoE(nn.Module):
+    def __init__(
+        self,
+        prefix,
+        config: DeepseekV2Config,
+        moe_layer_cls: Type[MoELayer],
+        weights,
+    ):
         super().__init__()

         self.hidden_dim = config.hidden_size
@@ -423,7 +427,7 @@ class BlockSparseMoE(nn.Module):
         # Gating
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)

-        self.moe_layer = SparseMoELayer(
+        self.moe_layer = moe_layer_cls(
             prefix=f"{prefix}.experts",
             n_experts=config.n_routed_experts,
             n_expert_group=config.n_group,
@@ -432,6 +436,7 @@ class BlockSparseMoE(nn.Module):
             topk_group=config.topk_group,
             weights=weights,
         )
+        assert isinstance(self.moe_layer, MoELayer)

         if config.n_shared_experts is not None:
             self.shared_experts = DeepseekV2MLP(
@@ -466,96 +471,6 @@ class BlockSparseMoE(nn.Module):
         return out.view(*x.shape)


-class DenseMoE(nn.Module):
-    def __init__(self, prefix: str, config: DeepseekV2Config, weights: Weights):
-        super().__init__()
-
-        self.hidden_dim = config.hidden_size
-        self.moe_intermediate_size = config.moe_intermediate_size
-        self.n_routed_experts = config.n_routed_experts
-        self.n_expert_group = config.n_group
-        self.topk_group = config.topk_group
-        self.top_k = config.num_experts_per_tok
-        self.norm_topk_prob = config.norm_topk_prob
-        self.routed_scaling_factor = config.routed_scaling_factor
-
-        # Gating
-        #
-        # Seems like no one quantizes the gate.
-        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
-
-        self.experts = [
-            DeepseekV2MLP(
-                f"{prefix}.experts.{i}", config, weights, self.moe_intermediate_size
-            )
-            for i in range(self.n_routed_experts)
-        ]
-
-        if config.n_shared_experts is not None:
-            self.shared_experts = DeepseekV2MLP(
-                prefix=f"{prefix}.shared_experts",
-                config=config,
-                weights=weights,
-                intermediate_size=config.moe_intermediate_size
-                * config.n_shared_experts,
-            )
-        else:
-            self.shared_experts = None
-
-        self.process_group = weights.process_group
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        x: (sequence_length, model_dim)
-        gate_logits: (sequence_length, n_experts)
-        """
-        # optional reshape
-        input_shape = x.shape
-        x = x.view(-1, input_shape[-1])
-
-        if self.shared_experts is not None:
-            shared_output = self.shared_experts(x, reduce=False)
-        else:
-            shared_output = None
-
-        # gate_logits: (sequence_length, n_experts)
-        router_logits = self.gate(x)
-
-        topk_weights, topk_ids = grouped_topk(
-            x,
-            router_logits,
-            self.top_k,
-            renormalize=self.norm_topk_prob,
-            num_expert_group=self.n_expert_group,
-            topk_group=self.topk_group,
-        )
-        out = self.moe_infer_gpu(x, topk_ids, topk_weights) * self.routed_scaling_factor
-
-        if shared_output is not None:
-            out = out + shared_output
-
-        # Reduce sum
-        if self.process_group.size() > 1:
-            torch.distributed.all_reduce(out, group=self.process_group)
-
-        return out
-
-    def moe_infer_gpu(
-        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
-    ):
-        weights = torch.zeros(
-            topk_ids.shape[0], len(self.experts), dtype=x.dtype, device=x.device
-        )
-        weights.scatter_(1, topk_ids, topk_weight)
-
-        out = x.new_zeros(x.shape[0], self.hidden_dim)
-        for i, expert in enumerate(self.experts):
-            # Add expert output to out with masking
-            out += expert(x, reduce=False) * weights[:, i].view(-1, 1)
-        return out
-
-
 class DeepseekV2Layer(nn.Module):
     def __init__(self, prefix, layer_id, config, weights):
         super().__init__()
@@ -572,10 +487,12 @@ class DeepseekV2Layer(nn.Module):
             and layer_id >= config.first_k_dense_replace
             and layer_id % config.moe_layer_freq == 0
         ):
-            moe_cls = (
-                BlockSparseMoE if SparseMoELayer.is_supported(weights) else DenseMoE
+            moe_layer_cls = (
+                SparseMoELayer
+                if SparseMoELayer.is_supported(weights)
+                else DenseMoELayer
             )
-            self.mlp = moe_cls(f"{prefix}.mlp", config, weights)
+            self.mlp = DeepseekV2MoE(f"{prefix}.mlp", config, moe_layer_cls, weights)
         else:
             self.mlp = DeepseekV2MLP(
                 prefix=f"{prefix}.mlp",


@@ -18,38 +18,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import List, Optional, Tuple, Type
+
 import torch
 import torch.distributed
 from torch import nn
-from text_generation_server.utils.import_utils import SYSTEM
-from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
-from typing import Optional, List, Tuple
-
-from text_generation_server.layers.attention import (
-    paged_attention,
-    attention,
-    reshape_and_cache,
-    Seqlen,
-)
+
 from text_generation_server.layers import (
     FastLinear,
-    TensorParallelRowLinear,
+    SpeculativeHead,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    SpeculativeHead,
+    TensorParallelRowLinear,
     get_linear,
 )
-from text_generation_server.layers.moe import SparseMoELayer
-from text_generation_server.layers.layernorm import (
-    FastRMSNorm,
-)
-from text_generation_server.layers.rotary import (
-    PositionRotaryEmbedding,
+from text_generation_server.layers.attention import (
+    Seqlen,
+    attention,
+    paged_attention,
+    reshape_and_cache,
 )
+from text_generation_server.layers.layernorm import FastRMSNorm
+from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import UnquantizedWeight
@@ -315,14 +308,16 @@ def round_up(x: torch.Tensor, value: int):
     return torch.div(x + (value - 1), value, rounding_mode="trunc") * value


-class BlockSparseMoE(nn.Module):
-    def __init__(self, prefix, config: MixtralConfig, weights):
+class MixtralMoE(nn.Module):
+    def __init__(
+        self, prefix, config: MixtralConfig, moe_layer_cls: Type[MoELayer], weights
+    ):
         super().__init__()

         # gating
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)

-        self.moe = SparseMoELayer(
+        self.moe = moe_layer_cls(
             n_expert_group=None,
             n_experts=config.num_local_experts,
             prefix=f"{prefix}.experts",
@@ -334,6 +329,7 @@ class BlockSparseMoE(nn.Module):
             up_proj_name="w3",
             down_proj_name="w2",
         )
+        assert isinstance(self.moe, MoELayer)

         self.process_group = weights.process_group
@@ -349,95 +345,6 @@ class BlockSparseMoE(nn.Module):
         return out.view(*x.shape)


-class DenseMoE(nn.Module):
-    def __init__(self, prefix, config: MixtralConfig, weights):
-        super().__init__()
-        self.hidden_dim = config.hidden_size
-        self.ffn_dim = config.intermediate_size // weights.process_group.size()
-        self.num_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
-
-        act = config.hidden_act
-        if "gelu" in act:
-            self.act = lambda x: torch.nn.functional.gelu(
-                x,
-                approximate=(
-                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
-                ),
-            )
-        elif "silu" in act:
-            self.act = torch.nn.functional.silu
-        else:
-            self.act = ACT2FN[act]
-
-        # gating
-        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
-
-        self.w1 = [
-            TensorParallelColumnLinear.load(
-                config, prefix=f"{prefix}.experts.{i}.w1", weights=weights, bias=False
-            )
-            for i in range(self.num_experts)
-        ]
-        self.w3 = [
-            TensorParallelColumnLinear.load(
-                config, prefix=f"{prefix}.experts.{i}.w3", weights=weights, bias=False
-            )
-            for i in range(self.num_experts)
-        ]
-        self.w2 = [
-            TensorParallelRowLinear.load(
-                config, prefix=f"{prefix}.experts.{i}.w2", weights=weights, bias=False
-            )
-            for i in range(self.num_experts)
-        ]
-
-        self.process_group = weights.process_group
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        x: (sequence_length, model_dim)
-        gate_logits: (sequence_length, n_experts)
-        """
-        # optional reshape
-        input_shape = x.shape
-        x = x.view(-1, input_shape[-1])
-
-        # gate_logits: (sequence_length, n_experts)
-        gate_logits = self.gate(x)
-        # all_probs: (sequence_length, n_experts) and upcast for softmax
-        all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
-
-        if self.top_k < self.num_experts:
-            _, not_selected_experts = torch.topk(
-                all_probs,
-                self.num_experts - self.top_k,
-                largest=False,
-                sorted=False,
-                dim=1,
-            )
-            # Mask not selected experts
-            all_probs.scatter_(1, not_selected_experts, 0)
-
-        # Re-normalize
-        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
-        weights = weights.to(x.dtype)
-
-        # Final output tensor
-        out = x.new_zeros(x.shape[0], self.hidden_dim)
-        for i in range(self.num_experts):
-            h = self.act(self.w1[i](x)) * self.w3[i](x)
-            h = self.w2[i](h, reduce=False)
-            # Add expert output to out with masking
-            out += h * weights[:, i].view(-1, 1)
-
-        # Reduce sum
-        if self.process_group.size() > 1:
-            torch.distributed.all_reduce(out, group=self.process_group)
-
-        return out
-
-
 class MixtralLayer(nn.Module):
     def __init__(self, prefix: str, layer_id, config, weights):
         super().__init__()
@@ -447,8 +354,12 @@ class MixtralLayer(nn.Module):
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )

-        moe_cls = BlockSparseMoE if SparseMoELayer.is_supported(weights) else DenseMoE
-        self.moe = moe_cls(f"{prefix}.block_sparse_moe", config, weights)
+        moe_layer_cls = (
+            SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer
+        )
+        self.moe = MixtralMoE(
+            f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
+        )

         self.input_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps