Move to moe-kernels package and switch to common MoE layer
This change introduces the new `moe-kernels` package:

- Add `moe-kernels` as a dependency.
- Introduce a `SparseMoELayer` module that can be used by MoE models.
- Port over Mixtral and Deepseek.
parent 7774655297 · commit 5726a9ca81
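For reference, a minimal usage sketch of the new layer (not part of the diff below; it mirrors how the Mixtral port constructs the layer, and assumes `config`, `weights`, `prefix`, `x`, and `router_logits` already exist as in the model code):

```python
from text_generation_server.layers.moe import SparseMoELayer

# Build the fused sparse MoE layer only when the weights loader supports it
# (unquantized or FP8-unquantized weights); otherwise fall back to a dense MoE.
if SparseMoELayer.is_supported(weights):
    moe = SparseMoELayer(
        n_expert_group=None,                 # grouped top-k is only used by Deepseek V2
        n_experts=config.num_local_experts,
        prefix=f"{prefix}.experts",
        renormalize=True,
        topk=config.num_experts_per_tok,
        topk_group=None,
        weights=weights,
        gate_proj_name="w1",                 # Mixtral's expert projection names
        up_proj_name="w3",
        down_proj_name="w2",
    )

    # The model computes router logits with its gate; the layer then applies only
    # the selected experts per token through the fused moe-kernels path.
    out = moe(x, gating_output=router_logits)
```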
@@ -258,7 +258,7 @@ COPY server/Makefile server/Makefile
 RUN cd server && \
     make gen-server && \
     pip install -r requirements_cuda.txt && \
-    pip install ".[bnb, accelerate, marlin, quantize, peft, outlines]" --no-cache-dir && \
+    pip install ".[bnb, accelerate, marlin, moe, quantize, peft, outlines]" --no-cache-dir && \
     pip install nvidia-nccl-cu12==2.22.3
 
 ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
@@ -16,17 +16,17 @@
       },
       {
         "id": 28804,
-        "logprob": -7.4335938,
+        "logprob": -7.4375,
         "text": "?"
       },
       {
         "id": 13,
-        "logprob": -0.8017578,
+        "logprob": -0.8046875,
         "text": "\n"
       },
       {
         "id": 13,
-        "logprob": -0.32958984,
+        "logprob": -0.33032227,
         "text": "\n"
       }
     ],
@@ -64,7 +64,7 @@
       },
       {
         "id": 369,
-        "logprob": -0.06585693,
+        "logprob": 0.0,
         "special": false,
         "text": " that"
       },
@@ -21,6 +21,7 @@
   loguru,
   mamba-ssm,
   marlin-kernels,
+  moe-kernels,
   opentelemetry-api,
   opentelemetry-exporter-otlp,
   opentelemetry-instrumentation-grpc,
@@ -88,6 +89,7 @@ buildPythonPackage {
     loguru
     mamba-ssm
     marlin-kernels
+    moe-kernels
     opentelemetry-api
     opentelemetry-exporter-otlp
     opentelemetry-instrumentation-grpc
@@ -1242,6 +1242,82 @@ files = [
     {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
 ]
 
+[[package]]
+name = "moe-kernels"
+version = "0.2.2"
+description = "MoE kernels"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "moe_kernels-0.2.2+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:d268d818932ddcbca9bc71021dc63b008aae832827a7c0484cf206bd59cfc9ab"},
+]
+
+[package.dependencies]
+nvidia-ml-py = "*"
+torch = "*"
+triton = "*"
+
+[package.source]
+type = "url"
+url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
+
+[[package]]
+name = "moe-kernels"
+version = "0.2.2"
+description = "MoE kernels"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "moe_kernels-0.2.2+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:614bbc3f41b707b0c40372f0bb00e218ad0842d306f90bef28ce8e98e7fcb7cb"},
+]
+
+[package.dependencies]
+nvidia-ml-py = "*"
+torch = "*"
+triton = "*"
+
+[package.source]
+type = "url"
+url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
+
+[[package]]
+name = "moe-kernels"
+version = "0.2.2"
+description = "MoE kernels"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "moe_kernels-0.2.2+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:c2f48ed541353be03157d4015270dff797f7b7b8a664babdcbdf7414867d5abd"},
+]
+
+[package.dependencies]
+nvidia-ml-py = "*"
+torch = "*"
+triton = "*"
+
+[package.source]
+type = "url"
+url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
+
+[[package]]
+name = "moe-kernels"
+version = "0.2.2"
+description = "MoE kernels"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "moe_kernels-0.2.2+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:d5f0339b73426c422872f7ff060433df6cd8e881451baf85ee7454e0e905f9d8"},
+]
+
+[package.dependencies]
+nvidia-ml-py = "*"
+torch = "*"
+triton = "*"
+
+[package.source]
+type = "url"
+url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
+
 [[package]]
 name = "mpmath"
 version = "1.3.0"
@@ -1600,6 +1676,17 @@ files = [
 [package.dependencies]
 nvidia-nvjitlink-cu12 = "*"
 
+[[package]]
+name = "nvidia-ml-py"
+version = "12.560.30"
+description = "Python Bindings for the NVIDIA Management Library"
+optional = true
+python-versions = "*"
+files = [
+    {file = "nvidia-ml-py-12.560.30.tar.gz", hash = "sha256:f0254dc7400647680a072ee02509bfd46102b60bdfeca321576d4d4817e7fe97"},
+    {file = "nvidia_ml_py-12.560.30-py3-none-any.whl", hash = "sha256:fea371c94d63e38a611c17bbb85fe400e9c8ddb9e8684a9cd0e47786a4bc3c73"},
+]
+
 [[package]]
 name = "nvidia-nccl-cu12"
 version = "2.20.5"
@@ -3638,6 +3725,7 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
 accelerate = ["accelerate"]
 bnb = ["bitsandbytes"]
 marlin = ["marlin-kernels", "marlin-kernels", "marlin-kernels", "marlin-kernels"]
+moe = ["moe-kernels", "moe-kernels", "moe-kernels", "moe-kernels"]
 outlines = ["outlines"]
 peft = ["peft"]
 quantize = ["accelerate", "datasets", "texttable"]
@@ -46,6 +46,12 @@ marlin-kernels = [
     { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
     { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
 ]
+moe-kernels = [
+    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
+    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
+    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
+    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
+]
 rich = "^13.7.1"
 
 [tool.poetry.extras]
@@ -53,6 +59,7 @@ torch = ["torch"]
 accelerate = ["accelerate"]
 bnb = ["bitsandbytes"]
 marlin = ["marlin-kernels"]
+moe = ["moe-kernels"]
 peft = ["peft"]
 quantize = ["texttable", "datasets", "accelerate"]
 outlines = ["outlines"]
@@ -0,0 +1,76 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
+from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
+from text_generation_server.utils.weights import (
+    DefaultWeightsLoader,
+    UnquantizedWeight,
+    Weights,
+)
+
+
+class SparseMoELayer(nn.Module):
+    """
+    Layer for MoE that uses fused kernels to only apply the active experts
+    for each token (rather than applying all experts and selecting the
+    outputs of active experts).
+    """
+
+    def __init__(
+        self,
+        *,
+        n_expert_group: Optional[int],
+        n_experts: int,
+        prefix: str,
+        renormalize: bool,
+        topk: int,
+        topk_group: Optional[int],
+        weights: Weights,
+        gate_proj_name: str = "gate_proj",
+        up_proj_name: str = "up_proj",
+        down_proj_name: str = "down_proj",
+    ):
+        super().__init__()
+
+        if (
+            isinstance(weights.loader, DefaultWeightsLoader)
+            and isinstance(weights.loader.weight_class, UnquantizedWeight)
+        ) or isinstance(weights.loader, HybridFP8UnquantLoader):
+            cls = UnquantizedSparseMoELayer
+        # Once we wire up GPTQ-Marlin MoE:
+        # elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym:
+        #     cls = GPTQMarlinSparseMoELayer
+        else:
+            raise ValueError(
+                f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights"
+            )
+
+        self.moe = cls(
+            n_expert_group=n_expert_group,
+            n_experts=n_experts,
+            prefix=prefix,
+            renormalize=renormalize,
+            topk=topk,
+            topk_group=topk_group,
+            weights=weights,
+            gate_proj_name=gate_proj_name,
+            up_proj_name=up_proj_name,
+            down_proj_name=down_proj_name,
+        )
+
+    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
+        return self.moe(x, gating_output=gating_output)
+
+    @staticmethod
+    def is_supported(weights: Weights) -> bool:
+        return (
+            (
+                isinstance(weights.loader, DefaultWeightsLoader)
+                and isinstance(weights.loader.weight_class, UnquantizedWeight)
+            )
+            or isinstance(weights.loader, HybridFP8UnquantLoader)
+            # Once we wire up GPTQ-Marlin MoE:
+            # or isinstance(weights.loader, GPTQMarlinWeightsLoader)
+        )
@@ -0,0 +1,125 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.weights import UnquantizedWeight, Weights
+
+if SYSTEM != "ipex":
+    from moe_kernels.fused_moe import fused_moe
+
+
+class UnquantizedSparseMoELayer(nn.Module):
+    def __init__(
+        self,
+        *,
+        n_expert_group: Optional[int],
+        n_experts: int,
+        prefix: str,
+        renormalize: bool,
+        topk: int,
+        topk_group: Optional[int],
+        weights: Weights,
+        gate_proj_name: str = "gate_proj",
+        up_proj_name: str = "up_proj",
+        down_proj_name: str = "down_proj",
+    ):
+        super().__init__()
+
+        assert (n_expert_group is None) == (
+            topk_group is None
+        ), "n_expert_group and topk_group must both be None or have some value"
+
+        self.n_expert_group = n_expert_group
+        self.topk = topk
+        self.topk_group = topk_group
+        self.renormalize = renormalize
+
+        self.gate_up_proj = _load_expert_multi_weights_col(
+            prefix=prefix,
+            n_experts=n_experts,
+            gate_proj_name=gate_proj_name,
+            up_proj_name=up_proj_name,
+            weights=weights,
+        )
+
+        self.down_proj = _load_expert_weights_row(
+            prefix=prefix,
+            n_experts=n_experts,
+            name=down_proj_name,
+            weights=weights,
+        )
+
+    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
+        return fused_moe(
+            x,
+            w1=self.gate_up_proj,
+            w2=self.down_proj,
+            gating_output=gating_output,
+            topk=self.topk,
+            renormalize=self.renormalize,
+            inplace=True,
+            use_grouped_topk=self.n_expert_group is not None,
+            num_expert_group=self.n_expert_group,
+            topk_group=self.topk_group,
+        )
+
+
+def _load_expert_multi_weights_col(
+    *,
+    prefix: str,
+    n_experts: int,
+    gate_proj_name: str,
+    up_proj_name: str,
+    weights: Weights,
+) -> torch.Tensor:
+    all_weight = None
+    for i in range(n_experts):
+        weight = weights.get_multi_weights_col(
+            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
+        )
+
+        assert isinstance(weight, UnquantizedWeight)
+
+        if all_weight is None:
+            all_weight = torch.empty(
+                (n_experts,) + weight.weight.shape,
+                dtype=weight.weight.dtype,
+                device=weight.weight.device,
+            )
+
+        all_weight[i] = weight.weight
+
+    assert all_weight is not None
+
+    return all_weight
+
+
+def _load_expert_weights_row(
+    *,
+    prefix: str,
+    n_experts: int,
+    name: str,
+    weights: Weights,
+) -> torch.Tensor:
+    all_weight = None
+    for i in range(n_experts):
+        weight = weights.get_weights_row(
+            f"{prefix}.{i}.{name}",
+        )
+
+        assert isinstance(weight, UnquantizedWeight)
+
+        if all_weight is None:
+            all_weight = torch.empty(
+                (n_experts,) + weight.weight.shape,
+                dtype=weight.weight.dtype,
+                device=weight.weight.device,
+            )
+
+        all_weight[i] = weight.weight
+
+    assert all_weight is not None
+
+    return all_weight
@@ -13,8 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
+from moe_kernels.fused_moe import grouped_topk
 import torch
 import torch.distributed
 from text_generation_server.layers import (
@@ -32,6 +33,7 @@ from text_generation_server.layers.attention import (
     Seqlen,
 )
 from text_generation_server.layers.layernorm import FastRMSNorm
+from text_generation_server.layers.moe import SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import Weights
@@ -152,44 +154,6 @@ class DeepseekV2Config(PretrainedConfig):
         )
 
 
-def _load_experts(config, prefix: str, mat: str, weights: Weights):
-    if config.quantize is not None:
-        raise NotImplementedError(
-            "Deepseek V2 does not support weight quantization yet."
-        )
-
-    assert mat in ["gate_proj", "up_proj", "down_proj"]
-
-    world_size = weights.process_group.size()
-    rank = weights.process_group.rank()
-
-    assert (
-        config.moe_intermediate_size % world_size == 0
-    ), f"The chosen size {config.moe_intermediate_size} is not compatible with sharding on {world_size} shards"
-
-    block_size = config.moe_intermediate_size // world_size
-    start = rank * block_size
-    stop = (rank + 1) * block_size
-
-    tensor = torch.empty(
-        (config.n_routed_experts * block_size, config.hidden_size),
-        dtype=weights.dtype,
-        device=weights.device,
-    )
-
-    for i in range(config.n_routed_experts):
-        slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
-
-        if mat == "down_proj":
-            expert_slice = slice_[:, start:stop].t().contiguous()
-        else:
-            expert_slice = slice_[start:stop]
-        tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
-            dtype=weights.dtype
-        ).to(device=weights.device)
-    return tensor
-
-
 class DeepseekV2Attention(torch.nn.Module):
     def __init__(
         self,
@@ -452,33 +416,21 @@ class BlockSparseMoE(nn.Module):
         self.moe_intermediate_size = (
             config.moe_intermediate_size // weights.process_group.size()
         )
-        self.n_routed_experts = config.n_routed_experts
-        self.n_expert_group = config.n_group
-        self.topk_group = config.topk_group
-        self.top_k = config.num_experts_per_tok
-        self.norm_topk_prob = config.norm_topk_prob
         self.routed_scaling_factor = config.routed_scaling_factor
 
-        gate_proj = _load_experts(
-            config, f"{prefix}.experts", "gate_proj", weights
-        ).view(self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim)
-
-        up_proj = _load_experts(config, f"{prefix}.experts", "up_proj", weights).view(
-            self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim
-        )
-
-        self.gate_up_proj = torch.cat([gate_proj, up_proj], dim=1)
-
-        self.down_proj = (
-            _load_experts(config, f"{prefix}.experts", "down_proj", weights)
-            .view(self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim)
-            .transpose(1, 2)
-            .contiguous()
-        )
-
         # Gating
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
 
+        self.moe_layer = SparseMoELayer(
+            prefix=f"{prefix}.experts",
+            n_experts=config.n_routed_experts,
+            n_expert_group=config.n_group,
+            renormalize=config.norm_topk_prob,
+            topk=config.num_experts_per_tok,
+            topk_group=config.topk_group,
+            weights=weights,
+        )
+
         if config.n_shared_experts is not None:
             self.shared_experts = DeepseekV2MLP(
                 prefix=f"{prefix}.shared_experts",
@@ -499,25 +451,8 @@ class BlockSparseMoE(nn.Module):
             shared_output = None
 
         router_logits = self.gate(x)
-        topk_weights, topk_ids = grouped_topk(
-            x,
-            router_logits,
-            self.top_k,
-            renormalize=self.norm_topk_prob,
-            num_expert_group=self.n_expert_group,
-            topk_group=self.topk_group,
-        )
-        out = (
-            fused_experts(
-                x,
-                self.gate_up_proj,
-                self.down_proj,
-                topk_weights,
-                topk_ids,
-                inplace=True,
-            )
-            * self.routed_scaling_factor
-        )
+        out = self.moe_layer(x, gating_output=router_logits)
 
         if shared_output is not None:
             out = out + shared_output
@@ -635,7 +570,9 @@ class DeepseekV2Layer(nn.Module):
             and layer_id >= config.first_k_dense_replace
             and layer_id % config.moe_layer_freq == 0
         ):
-            moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
+            moe_cls = (
+                BlockSparseMoE if SparseMoELayer.is_supported(weights) else DenseMoE
+            )
             self.mlp = moe_cls(f"{prefix}.mlp", config, weights)
         else:
             self.mlp = DeepseekV2MLP(
@@ -799,183 +736,3 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
             hidden_states = hidden_states[lm_head_indices]
         logits, speculative_logits = self.lm_head(hidden_states)
         return logits, speculative_logits
-
-
-# Functions below are from vLLM:
-#
-# https://github.com/vllm-project/vllm/blob/f7160d946a0a07703e72d81ba9ecf3913f192605/vllm/model_executor/layers/fused_moe/fused_moe.py#L397
-#
-# Remove after we have synced our version with upstream.
-
-
-def grouped_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    scores = torch.softmax(gating_output, dim=-1)
-    num_token = scores.shape[0]
-    group_scores = (
-        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
-    )  # [n, n_group]
-    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
-        1
-    ]  # [n, top_k_group]
-    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-    score_mask = (
-        group_mask.unsqueeze(-1)
-        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
-        .reshape(num_token, -1)
-    )  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    return topk_weights, topk_ids
-
-
-def get_default_config(
-    M: int,
-    E: int,
-    N: int,
-    K: int,
-    topk: int,
-    dtype: Optional[str],
-) -> Dict[str, int]:
-    config = {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 64,
-        "BLOCK_SIZE_K": 32,
-        "GROUP_SIZE_M": 8,
-    }
-    if M <= E:
-        config = {
-            "BLOCK_SIZE_M": 16,
-            "BLOCK_SIZE_N": 32,
-            "BLOCK_SIZE_K": 64,
-            "GROUP_SIZE_M": 1,
-        }
-    return config
-
-
-def fused_experts(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    inplace: bool = False,
-    override_config: Optional[Dict[str, Any]] = None,
-    use_fp8: bool = False,
-    w1_scale: Optional[torch.Tensor] = None,
-    w2_scale: Optional[torch.Tensor] = None,
-    a1_scale: Optional[torch.Tensor] = None,
-    a2_scale: Optional[torch.Tensor] = None,
-):
-    # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
-    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
-    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
-    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
-
-    import triton.language as tl
-    from vllm import _custom_ops as ops
-    from vllm.model_executor.layers.fused_moe.fused_moe import (
-        get_moe_configs,
-        invoke_fused_moe_kernel,
-        moe_align_block_size,
-    )
-
-    M, _ = hidden_states.shape
-    E, N, _ = w1.shape
-
-    if override_config:
-        config = override_config
-    else:
-        # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
-
-        if configs:
-            # If an optimal configuration map has been found, look up the
-            # optimal config
-            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
-        else:
-            # Else use the default config
-            config = get_default_config(
-                M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
-            )
-
-    intermediate_cache1 = torch.empty(
-        (M, topk_ids.shape[1], N),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-    intermediate_cache2 = torch.empty(
-        (M * topk_ids.shape[1], N // 2),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-    intermediate_cache3 = torch.empty(
-        (M, topk_ids.shape[1], w2.shape[1]),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-
-    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, config["BLOCK_SIZE_M"], E
-    )
-    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
-
-    invoke_fused_moe_kernel(
-        hidden_states,
-        w1,
-        intermediate_cache1,
-        a1_scale,
-        w1_scale,
-        topk_weights,
-        topk_ids,
-        sorted_token_ids,
-        expert_ids,
-        num_tokens_post_padded,
-        False,
-        topk_ids.shape[1],
-        config,
-        compute_type=compute_type,
-        use_fp8=use_fp8,
-    )
-
-    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
-
-    invoke_fused_moe_kernel(
-        intermediate_cache2,
-        w2,
-        intermediate_cache3,
-        a2_scale,
-        w2_scale,
-        topk_weights,
-        topk_ids,
-        sorted_token_ids,
-        expert_ids,
-        num_tokens_post_padded,
-        True,
-        1,
-        config,
-        compute_type=compute_type,
-        use_fp8=use_fp8,
-    )
-
-    if inplace:
-        return torch.sum(
-            intermediate_cache3.view(*intermediate_cache3.shape),
-            dim=1,
-            out=hidden_states,
-        )
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
@@ -25,8 +25,6 @@ import torch.distributed
 from torch import nn
 from text_generation_server.utils.import_utils import SYSTEM
 
-if SYSTEM != "ipex":
-    from vllm.model_executor.layers.fused_moe import fused_moe
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
@@ -45,6 +43,7 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.moe import SparseMoELayer
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
@@ -319,40 +318,21 @@ def round_up(x: torch.Tensor, value: int):
 class BlockSparseMoE(nn.Module):
     def __init__(self, prefix, config: MixtralConfig, weights):
         super().__init__()
-        self.hidden_dim = config.hidden_size
-        self.ffn_dim = config.intermediate_size // weights.process_group.size()
-        self.num_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
-
-        act = config.hidden_act
-        if "gelu" in act:
-            self.act = lambda x: torch.nn.functional.gelu(
-                x,
-                approximate=(
-                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
-                ),
-            )
-        elif "silu" in act:
-            self.act = torch.nn.functional.silu
-        else:
-            self.act = ACT2FN[act]
-
         # gating
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
 
-        # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
-        w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view(
-            self.num_experts, self.ffn_dim, self.hidden_dim
-        )
-        w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view(
-            self.num_experts, self.ffn_dim, self.hidden_dim
-        )
-        self.w13 = torch.cat([w1, w3], dim=1)
-        self.w2 = (
-            _load_experts(config, f"{prefix}.experts", "w2", weights)
-            .view(self.num_experts, self.ffn_dim, self.hidden_dim)
-            .transpose(1, 2)
-            .contiguous()
+        self.moe = SparseMoELayer(
+            n_expert_group=None,
+            n_experts=config.num_local_experts,
+            prefix=f"{prefix}.experts",
+            renormalize=True,
+            topk=config.num_experts_per_tok,
+            topk_group=None,
+            weights=weights,
+            gate_proj_name="w1",
+            up_proj_name="w3",
+            down_proj_name="w2",
         )
 
         self.process_group = weights.process_group
@@ -360,15 +340,7 @@ class BlockSparseMoE(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(x)
-        out = fused_moe(
-            x,
-            self.w13,
-            self.w2,
-            router_logits,
-            self.top_k,
-            renormalize=True,
-            inplace=True,
-        )
+        out = self.moe(x, gating_output=router_logits)
 
         # Reduce sum
         if self.process_group.size() > 1:
@@ -475,7 +447,7 @@ class MixtralLayer(nn.Module):
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
 
-        moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
+        moe_cls = BlockSparseMoE if SparseMoELayer.is_supported(weights) else DenseMoE
         self.moe = moe_cls(f"{prefix}.block_sparse_moe", config, weights)
 
         self.input_layernorm = FastRMSNorm.load(
@@ -120,7 +120,6 @@ class DefaultWeightsLoader(WeightsLoader):
         prefix: str,
         block_sizes: Union[int, List[int]],
     ):
-
         return self.weight_class(
             weights.get_packed_sharded(
                 f"{prefix}.weight", dim=0, block_sizes=block_sizes
@@ -405,6 +404,10 @@ class Weights:
         finally:
             self.weights_loader = old_loader
 
+    @property
+    def loader(self):
+        return self.weights_loader
+
 
 def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
     """