Move to moe-kernels package and switch to common MoE layer
This change introduces the new `moe-kernels` package: - Add `moe-kernels` as a dependency. - Introduce a `SparseMoELayer` module that can be used by MoE models. - Port over Mixtral and Deepseek.
This commit is contained in:
parent
7774655297
commit
5726a9ca81
|
@ -258,7 +258,7 @@ COPY server/Makefile server/Makefile
|
|||
RUN cd server && \
|
||||
make gen-server && \
|
||||
pip install -r requirements_cuda.txt && \
|
||||
pip install ".[bnb, accelerate, marlin, quantize, peft, outlines]" --no-cache-dir && \
|
||||
pip install ".[bnb, accelerate, marlin, moe, quantize, peft, outlines]" --no-cache-dir && \
|
||||
pip install nvidia-nccl-cu12==2.22.3
|
||||
|
||||
ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
|
||||
|
|
|
@ -16,17 +16,17 @@
|
|||
},
|
||||
{
|
||||
"id": 28804,
|
||||
"logprob": -7.4335938,
|
||||
"logprob": -7.4375,
|
||||
"text": "?"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.8017578,
|
||||
"logprob": -0.8046875,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.32958984,
|
||||
"logprob": -0.33032227,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
|
@ -64,7 +64,7 @@
|
|||
},
|
||||
{
|
||||
"id": 369,
|
||||
"logprob": -0.06585693,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " that"
|
||||
},
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
loguru,
|
||||
mamba-ssm,
|
||||
marlin-kernels,
|
||||
moe-kernels,
|
||||
opentelemetry-api,
|
||||
opentelemetry-exporter-otlp,
|
||||
opentelemetry-instrumentation-grpc,
|
||||
|
@ -88,6 +89,7 @@ buildPythonPackage {
|
|||
loguru
|
||||
mamba-ssm
|
||||
marlin-kernels
|
||||
moe-kernels
|
||||
opentelemetry-api
|
||||
opentelemetry-exporter-otlp
|
||||
opentelemetry-instrumentation-grpc
|
||||
|
|
|
@ -1242,6 +1242,82 @@ files = [
|
|||
{file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "moe-kernels"
|
||||
version = "0.2.2"
|
||||
description = "MoE kernels"
|
||||
optional = true
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "moe_kernels-0.2.2+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:d268d818932ddcbca9bc71021dc63b008aae832827a7c0484cf206bd59cfc9ab"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
nvidia-ml-py = "*"
|
||||
torch = "*"
|
||||
triton = "*"
|
||||
|
||||
[package.source]
|
||||
type = "url"
|
||||
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
|
||||
|
||||
[[package]]
|
||||
name = "moe-kernels"
|
||||
version = "0.2.2"
|
||||
description = "MoE kernels"
|
||||
optional = true
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "moe_kernels-0.2.2+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:614bbc3f41b707b0c40372f0bb00e218ad0842d306f90bef28ce8e98e7fcb7cb"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
nvidia-ml-py = "*"
|
||||
torch = "*"
|
||||
triton = "*"
|
||||
|
||||
[package.source]
|
||||
type = "url"
|
||||
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
|
||||
|
||||
[[package]]
|
||||
name = "moe-kernels"
|
||||
version = "0.2.2"
|
||||
description = "MoE kernels"
|
||||
optional = true
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "moe_kernels-0.2.2+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:c2f48ed541353be03157d4015270dff797f7b7b8a664babdcbdf7414867d5abd"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
nvidia-ml-py = "*"
|
||||
torch = "*"
|
||||
triton = "*"
|
||||
|
||||
[package.source]
|
||||
type = "url"
|
||||
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
|
||||
|
||||
[[package]]
|
||||
name = "moe-kernels"
|
||||
version = "0.2.2"
|
||||
description = "MoE kernels"
|
||||
optional = true
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "moe_kernels-0.2.2+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:d5f0339b73426c422872f7ff060433df6cd8e881451baf85ee7454e0e905f9d8"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
nvidia-ml-py = "*"
|
||||
torch = "*"
|
||||
triton = "*"
|
||||
|
||||
[package.source]
|
||||
type = "url"
|
||||
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
|
||||
|
||||
[[package]]
|
||||
name = "mpmath"
|
||||
version = "1.3.0"
|
||||
|
@ -1600,6 +1676,17 @@ files = [
|
|||
[package.dependencies]
|
||||
nvidia-nvjitlink-cu12 = "*"
|
||||
|
||||
[[package]]
|
||||
name = "nvidia-ml-py"
|
||||
version = "12.560.30"
|
||||
description = "Python Bindings for the NVIDIA Management Library"
|
||||
optional = true
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "nvidia-ml-py-12.560.30.tar.gz", hash = "sha256:f0254dc7400647680a072ee02509bfd46102b60bdfeca321576d4d4817e7fe97"},
|
||||
{file = "nvidia_ml_py-12.560.30-py3-none-any.whl", hash = "sha256:fea371c94d63e38a611c17bbb85fe400e9c8ddb9e8684a9cd0e47786a4bc3c73"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nvidia-nccl-cu12"
|
||||
version = "2.20.5"
|
||||
|
@ -3638,6 +3725,7 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
|
|||
accelerate = ["accelerate"]
|
||||
bnb = ["bitsandbytes"]
|
||||
marlin = ["marlin-kernels", "marlin-kernels", "marlin-kernels", "marlin-kernels"]
|
||||
moe = ["moe-kernels", "moe-kernels", "moe-kernels", "moe-kernels"]
|
||||
outlines = ["outlines"]
|
||||
peft = ["peft"]
|
||||
quantize = ["accelerate", "datasets", "texttable"]
|
||||
|
|
|
@ -46,6 +46,12 @@ marlin-kernels = [
|
|||
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
|
||||
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
|
||||
]
|
||||
moe-kernels = [
|
||||
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
|
||||
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
|
||||
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
|
||||
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
|
||||
]
|
||||
rich = "^13.7.1"
|
||||
|
||||
[tool.poetry.extras]
|
||||
|
@ -53,6 +59,7 @@ torch = ["torch"]
|
|||
accelerate = ["accelerate"]
|
||||
bnb = ["bitsandbytes"]
|
||||
marlin = ["marlin-kernels"]
|
||||
moe = ["moe-kernels"]
|
||||
peft = ["peft"]
|
||||
quantize = ["texttable", "datasets", "accelerate"]
|
||||
outlines = ["outlines"]
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
|
||||
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
|
||||
from text_generation_server.utils.weights import (
|
||||
DefaultWeightsLoader,
|
||||
UnquantizedWeight,
|
||||
Weights,
|
||||
)
|
||||
|
||||
|
||||
class SparseMoELayer(nn.Module):
|
||||
"""
|
||||
Layer for MoE that uses fused kernels to only apply the active experts
|
||||
for each token (rather than applying all experts and selecting the
|
||||
outputs of active experts).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
n_expert_group: Optional[int],
|
||||
n_experts: int,
|
||||
prefix: str,
|
||||
renormalize: bool,
|
||||
topk: int,
|
||||
topk_group: Optional[int],
|
||||
weights: Weights,
|
||||
gate_proj_name: str = "gate_proj",
|
||||
up_proj_name: str = "up_proj",
|
||||
down_proj_name: str = "down_proj",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if (
|
||||
isinstance(weights.loader, DefaultWeightsLoader)
|
||||
and isinstance(weights.loader.weight_class, UnquantizedWeight)
|
||||
) or isinstance(weights.loader, HybridFP8UnquantLoader):
|
||||
cls = UnquantizedSparseMoELayer
|
||||
# Once we wire up GPTQ-Marlin MoE:
|
||||
# elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym:
|
||||
# cls = GPTQMarlinSparseMoELayer
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights"
|
||||
)
|
||||
|
||||
self.moe = cls(
|
||||
n_expert_group=n_expert_group,
|
||||
n_experts=n_experts,
|
||||
prefix=prefix,
|
||||
renormalize=renormalize,
|
||||
topk=topk,
|
||||
topk_group=topk_group,
|
||||
weights=weights,
|
||||
gate_proj_name=gate_proj_name,
|
||||
up_proj_name=up_proj_name,
|
||||
down_proj_name=down_proj_name,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
|
||||
return self.moe(x, gating_output=gating_output)
|
||||
|
||||
@staticmethod
|
||||
def is_supported(weights: Weights) -> bool:
|
||||
return (
|
||||
(
|
||||
isinstance(weights.loader, DefaultWeightsLoader)
|
||||
and isinstance(weights.loader.weight_class, UnquantizedWeight)
|
||||
)
|
||||
or isinstance(weights.loader, HybridFP8UnquantLoader)
|
||||
# Once we wire up GPTQ-Marlin MoE:
|
||||
# or isinstance(weights.loader, GPTQMarlinWeightsLoader)
|
||||
)
|
|
@ -0,0 +1,125 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.weights import UnquantizedWeight, Weights
|
||||
|
||||
if SYSTEM != "ipex":
|
||||
from moe_kernels.fused_moe import fused_moe
|
||||
|
||||
|
||||
class UnquantizedSparseMoELayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
n_expert_group: Optional[int],
|
||||
n_experts: int,
|
||||
prefix: str,
|
||||
renormalize: bool,
|
||||
topk: int,
|
||||
topk_group: Optional[int],
|
||||
weights: Weights,
|
||||
gate_proj_name: str = "gate_proj",
|
||||
up_proj_name: str = "up_proj",
|
||||
down_proj_name: str = "down_proj",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert (n_expert_group is None) == (
|
||||
topk_group is None
|
||||
), "n_expert_group and topk_group must both be None or have some value"
|
||||
|
||||
self.n_expert_group = n_expert_group
|
||||
self.topk = topk
|
||||
self.topk_group = topk_group
|
||||
self.renormalize = renormalize
|
||||
|
||||
self.gate_up_proj = _load_expert_multi_weights_col(
|
||||
prefix=prefix,
|
||||
n_experts=n_experts,
|
||||
gate_proj_name=gate_proj_name,
|
||||
up_proj_name=up_proj_name,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.down_proj = _load_expert_weights_row(
|
||||
prefix=prefix,
|
||||
n_experts=n_experts,
|
||||
name=down_proj_name,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
|
||||
return fused_moe(
|
||||
x,
|
||||
w1=self.gate_up_proj,
|
||||
w2=self.down_proj,
|
||||
gating_output=gating_output,
|
||||
topk=self.topk,
|
||||
renormalize=self.renormalize,
|
||||
inplace=True,
|
||||
use_grouped_topk=self.n_expert_group is not None,
|
||||
num_expert_group=self.n_expert_group,
|
||||
topk_group=self.topk_group,
|
||||
)
|
||||
|
||||
|
||||
def _load_expert_multi_weights_col(
|
||||
*,
|
||||
prefix: str,
|
||||
n_experts: int,
|
||||
gate_proj_name: str,
|
||||
up_proj_name: str,
|
||||
weights: Weights,
|
||||
) -> torch.Tensor:
|
||||
all_weight = None
|
||||
for i in range(n_experts):
|
||||
weight = weights.get_multi_weights_col(
|
||||
[f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
|
||||
)
|
||||
|
||||
assert isinstance(weight, UnquantizedWeight)
|
||||
|
||||
if all_weight is None:
|
||||
all_weight = torch.empty(
|
||||
(n_experts,) + weight.weight.shape,
|
||||
dtype=weight.weight.dtype,
|
||||
device=weight.weight.device,
|
||||
)
|
||||
|
||||
all_weight[i] = weight.weight
|
||||
|
||||
assert all_weight is not None
|
||||
|
||||
return all_weight
|
||||
|
||||
|
||||
def _load_expert_weights_row(
|
||||
*,
|
||||
prefix: str,
|
||||
n_experts: int,
|
||||
name: str,
|
||||
weights: Weights,
|
||||
) -> torch.Tensor:
|
||||
all_weight = None
|
||||
for i in range(n_experts):
|
||||
weight = weights.get_weights_row(
|
||||
f"{prefix}.{i}.{name}",
|
||||
)
|
||||
|
||||
assert isinstance(weight, UnquantizedWeight)
|
||||
|
||||
if all_weight is None:
|
||||
all_weight = torch.empty(
|
||||
(n_experts,) + weight.weight.shape,
|
||||
dtype=weight.weight.dtype,
|
||||
device=weight.weight.device,
|
||||
)
|
||||
|
||||
all_weight[i] = weight.weight
|
||||
|
||||
assert all_weight is not None
|
||||
|
||||
return all_weight
|
|
@ -13,8 +13,9 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from moe_kernels.fused_moe import grouped_topk
|
||||
import torch
|
||||
import torch.distributed
|
||||
from text_generation_server.layers import (
|
||||
|
@ -32,6 +33,7 @@ from text_generation_server.layers.attention import (
|
|||
Seqlen,
|
||||
)
|
||||
from text_generation_server.layers.layernorm import FastRMSNorm
|
||||
from text_generation_server.layers.moe import SparseMoELayer
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.weights import Weights
|
||||
|
@ -152,44 +154,6 @@ class DeepseekV2Config(PretrainedConfig):
|
|||
)
|
||||
|
||||
|
||||
def _load_experts(config, prefix: str, mat: str, weights: Weights):
|
||||
if config.quantize is not None:
|
||||
raise NotImplementedError(
|
||||
"Deepseek V2 does not support weight quantization yet."
|
||||
)
|
||||
|
||||
assert mat in ["gate_proj", "up_proj", "down_proj"]
|
||||
|
||||
world_size = weights.process_group.size()
|
||||
rank = weights.process_group.rank()
|
||||
|
||||
assert (
|
||||
config.moe_intermediate_size % world_size == 0
|
||||
), f"The chosen size {config.moe_intermediate_size} is not compatible with sharding on {world_size} shards"
|
||||
|
||||
block_size = config.moe_intermediate_size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
|
||||
tensor = torch.empty(
|
||||
(config.n_routed_experts * block_size, config.hidden_size),
|
||||
dtype=weights.dtype,
|
||||
device=weights.device,
|
||||
)
|
||||
|
||||
for i in range(config.n_routed_experts):
|
||||
slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
|
||||
|
||||
if mat == "down_proj":
|
||||
expert_slice = slice_[:, start:stop].t().contiguous()
|
||||
else:
|
||||
expert_slice = slice_[start:stop]
|
||||
tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
|
||||
dtype=weights.dtype
|
||||
).to(device=weights.device)
|
||||
return tensor
|
||||
|
||||
|
||||
class DeepseekV2Attention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -452,33 +416,21 @@ class BlockSparseMoE(nn.Module):
|
|||
self.moe_intermediate_size = (
|
||||
config.moe_intermediate_size // weights.process_group.size()
|
||||
)
|
||||
self.n_routed_experts = config.n_routed_experts
|
||||
self.n_expert_group = config.n_group
|
||||
self.topk_group = config.topk_group
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.norm_topk_prob = config.norm_topk_prob
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
|
||||
gate_proj = _load_experts(
|
||||
config, f"{prefix}.experts", "gate_proj", weights
|
||||
).view(self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim)
|
||||
|
||||
up_proj = _load_experts(config, f"{prefix}.experts", "up_proj", weights).view(
|
||||
self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim
|
||||
)
|
||||
|
||||
self.gate_up_proj = torch.cat([gate_proj, up_proj], dim=1)
|
||||
|
||||
self.down_proj = (
|
||||
_load_experts(config, f"{prefix}.experts", "down_proj", weights)
|
||||
.view(self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim)
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
# Gating
|
||||
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
|
||||
|
||||
self.moe_layer = SparseMoELayer(
|
||||
prefix=f"{prefix}.experts",
|
||||
n_experts=config.n_routed_experts,
|
||||
n_expert_group=config.n_group,
|
||||
renormalize=config.norm_topk_prob,
|
||||
topk=config.num_experts_per_tok,
|
||||
topk_group=config.topk_group,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
if config.n_shared_experts is not None:
|
||||
self.shared_experts = DeepseekV2MLP(
|
||||
prefix=f"{prefix}.shared_experts",
|
||||
|
@ -499,25 +451,8 @@ class BlockSparseMoE(nn.Module):
|
|||
shared_output = None
|
||||
|
||||
router_logits = self.gate(x)
|
||||
topk_weights, topk_ids = grouped_topk(
|
||||
x,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
renormalize=self.norm_topk_prob,
|
||||
num_expert_group=self.n_expert_group,
|
||||
topk_group=self.topk_group,
|
||||
)
|
||||
out = (
|
||||
fused_experts(
|
||||
x,
|
||||
self.gate_up_proj,
|
||||
self.down_proj,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=True,
|
||||
)
|
||||
* self.routed_scaling_factor
|
||||
)
|
||||
|
||||
out = self.moe_layer(x, gating_output=router_logits)
|
||||
|
||||
if shared_output is not None:
|
||||
out = out + shared_output
|
||||
|
@ -635,7 +570,9 @@ class DeepseekV2Layer(nn.Module):
|
|||
and layer_id >= config.first_k_dense_replace
|
||||
and layer_id % config.moe_layer_freq == 0
|
||||
):
|
||||
moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
|
||||
moe_cls = (
|
||||
BlockSparseMoE if SparseMoELayer.is_supported(weights) else DenseMoE
|
||||
)
|
||||
self.mlp = moe_cls(f"{prefix}.mlp", config, weights)
|
||||
else:
|
||||
self.mlp = DeepseekV2MLP(
|
||||
|
@ -799,183 +736,3 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
|
|||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits, speculative_logits = self.lm_head(hidden_states)
|
||||
return logits, speculative_logits
|
||||
|
||||
|
||||
# Functions below are from vLLM:
|
||||
#
|
||||
# https://github.com/vllm-project/vllm/blob/f7160d946a0a07703e72d81ba9ecf3913f192605/vllm/model_executor/layers/fused_moe/fused_moe.py#L397
|
||||
#
|
||||
# Remove after we have synced our version with upstream.
|
||||
|
||||
|
||||
def grouped_topk(
|
||||
hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
num_expert_group: int = 0,
|
||||
topk_group: int = 0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
scores = torch.softmax(gating_output, dim=-1)
|
||||
num_token = scores.shape[0]
|
||||
group_scores = (
|
||||
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
|
||||
) # [n, n_group]
|
||||
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
|
||||
1
|
||||
] # [n, top_k_group]
|
||||
group_mask = torch.zeros_like(group_scores) # [n, n_group]
|
||||
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
|
||||
score_mask = (
|
||||
group_mask.unsqueeze(-1)
|
||||
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
|
||||
.reshape(num_token, -1)
|
||||
) # [n, e]
|
||||
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
||||
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
|
||||
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
def get_default_config(
|
||||
M: int,
|
||||
E: int,
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
dtype: Optional[str],
|
||||
) -> Dict[str, int]:
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
}
|
||||
if M <= E:
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def fused_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
override_config: Optional[Dict[str, Any]] = None,
|
||||
use_fp8: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# Check constraints.
|
||||
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
||||
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
||||
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
|
||||
|
||||
import triton.language as tl
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
get_moe_configs,
|
||||
invoke_fused_moe_kernel,
|
||||
moe_align_block_size,
|
||||
)
|
||||
|
||||
M, _ = hidden_states.shape
|
||||
E, N, _ = w1.shape
|
||||
|
||||
if override_config:
|
||||
config = override_config
|
||||
else:
|
||||
# First try to load optimal config from the file
|
||||
configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
|
||||
|
||||
if configs:
|
||||
# If an optimal configuration map has been found, look up the
|
||||
# optimal config
|
||||
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
|
||||
else:
|
||||
# Else use the default config
|
||||
config = get_default_config(
|
||||
M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
|
||||
)
|
||||
|
||||
intermediate_cache1 = torch.empty(
|
||||
(M, topk_ids.shape[1], N),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
intermediate_cache2 = torch.empty(
|
||||
(M * topk_ids.shape[1], N // 2),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
intermediate_cache3 = torch.empty(
|
||||
(M, topk_ids.shape[1], w2.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||
topk_ids, config["BLOCK_SIZE_M"], E
|
||||
)
|
||||
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
|
||||
|
||||
invoke_fused_moe_kernel(
|
||||
hidden_states,
|
||||
w1,
|
||||
intermediate_cache1,
|
||||
a1_scale,
|
||||
w1_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
False,
|
||||
topk_ids.shape[1],
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8=use_fp8,
|
||||
)
|
||||
|
||||
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
|
||||
|
||||
invoke_fused_moe_kernel(
|
||||
intermediate_cache2,
|
||||
w2,
|
||||
intermediate_cache3,
|
||||
a2_scale,
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
True,
|
||||
1,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8=use_fp8,
|
||||
)
|
||||
|
||||
if inplace:
|
||||
return torch.sum(
|
||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
dim=1,
|
||||
out=hidden_states,
|
||||
)
|
||||
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
|
||||
|
|
|
@ -25,8 +25,6 @@ import torch.distributed
|
|||
from torch import nn
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
if SYSTEM != "ipex":
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple
|
||||
|
@ -45,6 +43,7 @@ from text_generation_server.layers import (
|
|||
SpeculativeHead,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.moe import SparseMoELayer
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
)
|
||||
|
@ -319,40 +318,21 @@ def round_up(x: torch.Tensor, value: int):
|
|||
class BlockSparseMoE(nn.Module):
|
||||
def __init__(self, prefix, config: MixtralConfig, weights):
|
||||
super().__init__()
|
||||
self.hidden_dim = config.hidden_size
|
||||
self.ffn_dim = config.intermediate_size // weights.process_group.size()
|
||||
self.num_experts = config.num_local_experts
|
||||
self.top_k = config.num_experts_per_tok
|
||||
|
||||
act = config.hidden_act
|
||||
if "gelu" in act:
|
||||
self.act = lambda x: torch.nn.functional.gelu(
|
||||
x,
|
||||
approximate=(
|
||||
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
|
||||
),
|
||||
)
|
||||
elif "silu" in act:
|
||||
self.act = torch.nn.functional.silu
|
||||
else:
|
||||
self.act = ACT2FN[act]
|
||||
|
||||
# gating
|
||||
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
|
||||
|
||||
# merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
|
||||
w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view(
|
||||
self.num_experts, self.ffn_dim, self.hidden_dim
|
||||
)
|
||||
w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view(
|
||||
self.num_experts, self.ffn_dim, self.hidden_dim
|
||||
)
|
||||
self.w13 = torch.cat([w1, w3], dim=1)
|
||||
self.w2 = (
|
||||
_load_experts(config, f"{prefix}.experts", "w2", weights)
|
||||
.view(self.num_experts, self.ffn_dim, self.hidden_dim)
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
self.moe = SparseMoELayer(
|
||||
n_expert_group=None,
|
||||
n_experts=config.num_local_experts,
|
||||
prefix=f"{prefix}.experts",
|
||||
renormalize=True,
|
||||
topk=config.num_experts_per_tok,
|
||||
topk_group=None,
|
||||
weights=weights,
|
||||
gate_proj_name="w1",
|
||||
up_proj_name="w3",
|
||||
down_proj_name="w2",
|
||||
)
|
||||
|
||||
self.process_group = weights.process_group
|
||||
|
@ -360,15 +340,7 @@ class BlockSparseMoE(nn.Module):
|
|||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(x)
|
||||
out = fused_moe(
|
||||
x,
|
||||
self.w13,
|
||||
self.w2,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
)
|
||||
out = self.moe(x, gating_output=router_logits)
|
||||
|
||||
# Reduce sum
|
||||
if self.process_group.size() > 1:
|
||||
|
@ -475,7 +447,7 @@ class MixtralLayer(nn.Module):
|
|||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
|
||||
moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
|
||||
moe_cls = BlockSparseMoE if SparseMoELayer.is_supported(weights) else DenseMoE
|
||||
self.moe = moe_cls(f"{prefix}.block_sparse_moe", config, weights)
|
||||
|
||||
self.input_layernorm = FastRMSNorm.load(
|
||||
|
|
|
@ -120,7 +120,6 @@ class DefaultWeightsLoader(WeightsLoader):
|
|||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
|
||||
return self.weight_class(
|
||||
weights.get_packed_sharded(
|
||||
f"{prefix}.weight", dim=0, block_sizes=block_sizes
|
||||
|
@ -405,6 +404,10 @@ class Weights:
|
|||
finally:
|
||||
self.weights_loader = old_loader
|
||||
|
||||
@property
|
||||
def loader(self):
|
||||
return self.weights_loader
|
||||
|
||||
|
||||
def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue