diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index d415f369..ce1cdc33 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -45,7 +45,7 @@ jobs:
             export dockerfile="Dockerfile"
             export label_extension=""
             export docker_devices=""
-            export runs_on="aws-g6-12xlarge-plus-priv"
+            export runs_on="aws-g6-12xl-plus-priv-cache"
             export platform=""
             ;;
           rocm)
diff --git a/Dockerfile b/Dockerfile
index 37ced781..dc2c1d9f 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -258,7 +258,7 @@ COPY server/Makefile server/Makefile
 RUN cd server && \
     make gen-server && \
     pip install -r requirements_cuda.txt && \
-    pip install ".[bnb, accelerate, marlin, quantize, peft, outlines]" --no-cache-dir && \
+    pip install ".[bnb, accelerate, marlin, moe, quantize, peft, outlines]" --no-cache-dir && \
     pip install nvidia-nccl-cu12==2.22.3
 
 ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
diff --git a/integration-tests/models/__snapshots__/test_flash_mixtral/test_flash_mixtral_all_params.json b/integration-tests/models/__snapshots__/test_flash_mixtral/test_flash_mixtral_all_params.json
index 00da1fed..e25086c6 100644
--- a/integration-tests/models/__snapshots__/test_flash_mixtral/test_flash_mixtral_all_params.json
+++ b/integration-tests/models/__snapshots__/test_flash_mixtral/test_flash_mixtral_all_params.json
@@ -16,17 +16,17 @@
       },
       {
         "id": 28804,
-        "logprob": -7.4335938,
+        "logprob": -7.4375,
         "text": "?"
       },
       {
         "id": 13,
-        "logprob": -0.8017578,
+        "logprob": -0.8046875,
         "text": "\n"
       },
       {
         "id": 13,
-        "logprob": -0.32958984,
+        "logprob": -0.33032227,
         "text": "\n"
       }
     ],
@@ -64,7 +64,7 @@
       },
       {
         "id": 369,
-        "logprob": -0.06585693,
+        "logprob": 0.0,
         "special": false,
         "text": " that"
       },
diff --git a/nix/server.nix b/nix/server.nix
index cfdb3f01..5921da7f 100644
--- a/nix/server.nix
+++ b/nix/server.nix
@@ -21,6 +21,7 @@
   loguru,
   mamba-ssm,
   marlin-kernels,
+  moe-kernels,
   opentelemetry-api,
   opentelemetry-exporter-otlp,
   opentelemetry-instrumentation-grpc,
@@ -88,6 +89,7 @@ buildPythonPackage {
     loguru
     mamba-ssm
     marlin-kernels
+    moe-kernels
    opentelemetry-api
     opentelemetry-exporter-otlp
     opentelemetry-instrumentation-grpc
diff --git a/server/poetry.lock b/server/poetry.lock
index ce5b8a6c..eeb22204 100644
--- a/server/poetry.lock
+++ b/server/poetry.lock
@@ -1242,6 +1242,82 @@ files = [
     {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
 ]
 
+[[package]]
+name = "moe-kernels"
+version = "0.2.2"
+description = "MoE kernels"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "moe_kernels-0.2.2+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:d268d818932ddcbca9bc71021dc63b008aae832827a7c0484cf206bd59cfc9ab"},
+]
+
+[package.dependencies]
+nvidia-ml-py = "*"
+torch = "*"
+triton = "*"
+
+[package.source]
+type = "url"
+url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
+
+[[package]]
+name = "moe-kernels"
+version = "0.2.2"
+description = "MoE kernels"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "moe_kernels-0.2.2+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:614bbc3f41b707b0c40372f0bb00e218ad0842d306f90bef28ce8e98e7fcb7cb"},
+]
+
+[package.dependencies]
+nvidia-ml-py = "*"
+torch = "*"
+triton = "*"
+
+[package.source]
+type = "url"
+url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
+
+[[package]]
+name = "moe-kernels"
+version = "0.2.2"
+description = "MoE kernels"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "moe_kernels-0.2.2+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:c2f48ed541353be03157d4015270dff797f7b7b8a664babdcbdf7414867d5abd"},
+]
+
+[package.dependencies]
+nvidia-ml-py = "*"
+torch = "*"
+triton = "*"
+
+[package.source]
+type = "url"
+url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
+
+[[package]]
+name = "moe-kernels"
+version = "0.2.2"
+description = "MoE kernels"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "moe_kernels-0.2.2+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:d5f0339b73426c422872f7ff060433df6cd8e881451baf85ee7454e0e905f9d8"},
+]
+
+[package.dependencies]
+nvidia-ml-py = "*"
+torch = "*"
+triton = "*"
+
+[package.source]
+type = "url"
+url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
+
 [[package]]
 name = "mpmath"
 version = "1.3.0"
@@ -1600,6 +1676,17 @@ files = [
 [package.dependencies]
 nvidia-nvjitlink-cu12 = "*"
 
+[[package]]
+name = "nvidia-ml-py"
+version = "12.560.30"
+description = "Python Bindings for the NVIDIA Management Library"
+optional = true
+python-versions = "*"
+files = [
+    {file = "nvidia-ml-py-12.560.30.tar.gz", hash = "sha256:f0254dc7400647680a072ee02509bfd46102b60bdfeca321576d4d4817e7fe97"},
+    {file = "nvidia_ml_py-12.560.30-py3-none-any.whl", hash = "sha256:fea371c94d63e38a611c17bbb85fe400e9c8ddb9e8684a9cd0e47786a4bc3c73"},
+]
+
 [[package]]
 name = "nvidia-nccl-cu12"
 version = "2.20.5"
@@ -3638,6 +3725,7 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
 accelerate = ["accelerate"]
 bnb = ["bitsandbytes"]
 marlin = ["marlin-kernels", "marlin-kernels", "marlin-kernels", "marlin-kernels"]
+moe = ["moe-kernels", "moe-kernels", "moe-kernels", "moe-kernels"]
 outlines = ["outlines"]
 peft = ["peft"]
 quantize = ["accelerate", "datasets", "texttable"]
diff --git a/server/pyproject.toml b/server/pyproject.toml
index 57deb1b8..6eee1e72 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -46,6 +46,12 @@ marlin-kernels = [
     { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
     { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
 ]
+moe-kernels = [
+    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
+    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
+    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
+    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
+]
 rich = "^13.7.1"
 
 [tool.poetry.extras]
@@ -53,6 +59,7 @@ torch = ["torch"]
 accelerate = ["accelerate"]
 bnb = ["bitsandbytes"]
 marlin = ["marlin-kernels"]
+moe = ["moe-kernels"]
peft = ["peft"] quantize = ["texttable", "datasets", "accelerate"] outlines = ["outlines"] diff --git a/server/text_generation_server/layers/moe/__init__.py b/server/text_generation_server/layers/moe/__init__.py new file mode 100644 index 00000000..e32003ae --- /dev/null +++ b/server/text_generation_server/layers/moe/__init__.py @@ -0,0 +1,76 @@ +from typing import Optional + +import torch +import torch.nn as nn +from text_generation_server.layers.fp8 import HybridFP8UnquantLoader +from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer +from text_generation_server.utils.weights import ( + DefaultWeightsLoader, + UnquantizedWeight, + Weights, +) + + +class SparseMoELayer(nn.Module): + """ + Layer for MoE that uses fused kernels to only apply the active experts + for each token (rather than applying all experts and selecting the + outputs of active experts). + """ + + def __init__( + self, + *, + n_expert_group: Optional[int], + n_experts: int, + prefix: str, + renormalize: bool, + topk: int, + topk_group: Optional[int], + weights: Weights, + gate_proj_name: str = "gate_proj", + up_proj_name: str = "up_proj", + down_proj_name: str = "down_proj", + ): + super().__init__() + + if ( + isinstance(weights.loader, DefaultWeightsLoader) + and isinstance(weights.loader.weight_class, UnquantizedWeight) + ) or isinstance(weights.loader, HybridFP8UnquantLoader): + cls = UnquantizedSparseMoELayer + # Once we wire up GPTQ-Marlin MoE: + # elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym: + # cls = GPTQMarlinSparseMoELayer + else: + raise ValueError( + f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights" + ) + + self.moe = cls( + n_expert_group=n_expert_group, + n_experts=n_experts, + prefix=prefix, + renormalize=renormalize, + topk=topk, + topk_group=topk_group, + weights=weights, + gate_proj_name=gate_proj_name, + up_proj_name=up_proj_name, + down_proj_name=down_proj_name, + ) + + def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: + return self.moe(x, gating_output=gating_output) + + @staticmethod + def is_supported(weights: Weights) -> bool: + return ( + ( + isinstance(weights.loader, DefaultWeightsLoader) + and isinstance(weights.loader.weight_class, UnquantizedWeight) + ) + or isinstance(weights.loader, HybridFP8UnquantLoader) + # Once we wire up GPTQ-Marlin MoE: + # or isinstance(weights.loader, GPTQMarlinWeightsLoader) + ) diff --git a/server/text_generation_server/layers/moe/unquantized.py b/server/text_generation_server/layers/moe/unquantized.py new file mode 100644 index 00000000..8f1d9b3f --- /dev/null +++ b/server/text_generation_server/layers/moe/unquantized.py @@ -0,0 +1,125 @@ +from typing import Optional + +import torch +import torch.nn as nn + +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.weights import UnquantizedWeight, Weights + +if SYSTEM != "ipex": + from moe_kernels.fused_moe import fused_moe + + +class UnquantizedSparseMoELayer(nn.Module): + def __init__( + self, + *, + n_expert_group: Optional[int], + n_experts: int, + prefix: str, + renormalize: bool, + topk: int, + topk_group: Optional[int], + weights: Weights, + gate_proj_name: str = "gate_proj", + up_proj_name: str = "up_proj", + down_proj_name: str = "down_proj", + ): + super().__init__() + + assert (n_expert_group is None) == ( + topk_group is None + ), "n_expert_group and topk_group must both be None or have some value" + + 
diff --git a/server/text_generation_server/layers/moe/unquantized.py b/server/text_generation_server/layers/moe/unquantized.py
new file mode 100644
index 00000000..8f1d9b3f
--- /dev/null
+++ b/server/text_generation_server/layers/moe/unquantized.py
@@ -0,0 +1,125 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.weights import UnquantizedWeight, Weights
+
+if SYSTEM != "ipex":
+    from moe_kernels.fused_moe import fused_moe
+
+
+class UnquantizedSparseMoELayer(nn.Module):
+    def __init__(
+        self,
+        *,
+        n_expert_group: Optional[int],
+        n_experts: int,
+        prefix: str,
+        renormalize: bool,
+        topk: int,
+        topk_group: Optional[int],
+        weights: Weights,
+        gate_proj_name: str = "gate_proj",
+        up_proj_name: str = "up_proj",
+        down_proj_name: str = "down_proj",
+    ):
+        super().__init__()
+
+        assert (n_expert_group is None) == (
+            topk_group is None
+        ), "n_expert_group and topk_group must both be None or have some value"
+
+        self.n_expert_group = n_expert_group
+        self.topk = topk
+        self.topk_group = topk_group
+        self.renormalize = renormalize
+
+        self.gate_up_proj = _load_expert_multi_weights_col(
+            prefix=prefix,
+            n_experts=n_experts,
+            gate_proj_name=gate_proj_name,
+            up_proj_name=up_proj_name,
+            weights=weights,
+        )
+
+        self.down_proj = _load_expert_weights_row(
+            prefix=prefix,
+            n_experts=n_experts,
+            name=down_proj_name,
+            weights=weights,
+        )
+
+    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
+        return fused_moe(
+            x,
+            w1=self.gate_up_proj,
+            w2=self.down_proj,
+            gating_output=gating_output,
+            topk=self.topk,
+            renormalize=self.renormalize,
+            inplace=True,
+            use_grouped_topk=self.n_expert_group is not None,
+            num_expert_group=self.n_expert_group,
+            topk_group=self.topk_group,
+        )
+
+
+def _load_expert_multi_weights_col(
+    *,
+    prefix: str,
+    n_experts: int,
+    gate_proj_name: str,
+    up_proj_name: str,
+    weights: Weights,
+) -> torch.Tensor:
+    all_weight = None
+    for i in range(n_experts):
+        weight = weights.get_multi_weights_col(
+            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
+        )
+
+        assert isinstance(weight, UnquantizedWeight)
+
+        if all_weight is None:
+            all_weight = torch.empty(
+                (n_experts,) + weight.weight.shape,
+                dtype=weight.weight.dtype,
+                device=weight.weight.device,
+            )
+
+        all_weight[i] = weight.weight
+
+    assert all_weight is not None
+
+    return all_weight
+
+
+def _load_expert_weights_row(
+    *,
+    prefix: str,
+    n_experts: int,
+    name: str,
+    weights: Weights,
+) -> torch.Tensor:
+    all_weight = None
+    for i in range(n_experts):
+        weight = weights.get_weights_row(
+            f"{prefix}.{i}.{name}",
+        )
+
+        assert isinstance(weight, UnquantizedWeight)
+
+        if all_weight is None:
+            all_weight = torch.empty(
+                (n_experts,) + weight.weight.shape,
+                dtype=weight.weight.dtype,
+                device=weight.weight.device,
+            )
+
+        all_weight[i] = weight.weight
+
+    assert all_weight is not None
+
+    return all_weight
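
The stacked tensors built above follow the layout the fused kernel expects: gate_up_proj is (n_experts, 2 * intermediate_size, hidden_size) with the gate rows first and the up rows second, and down_proj is (n_experts, hidden_size, intermediate_size) under the usual row sharding in this repo. As a rough, unoptimized sketch of what fused_moe computes in the plain top-k case (SiLU gating, no expert groups, sharding ignored); this is for intuition only and is not the Triton kernel:

import torch

def moe_reference(x, gate_up_proj, down_proj, gating_output, topk, renormalize):
    # Naive per-token reference, not the fused kernel.
    # x: (num_tokens, hidden), gating_output: (num_tokens, n_experts)
    scores = torch.softmax(gating_output, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    inter = gate_up_proj.shape[1] // 2
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):
        for weight, e in zip(topk_weights[t], topk_ids[t]):
            gate = x[t] @ gate_up_proj[e, :inter].T   # (intermediate,)
            up = x[t] @ gate_up_proj[e, inter:].T     # (intermediate,)
            # down_proj[e]: (hidden, intermediate)
            out[t] += weight * (down_proj[e] @ (torch.nn.functional.silu(gate) * up))
    return out
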
diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
index f62dfe66..12be08cd 100644
--- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
@@ -13,8 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
+from moe_kernels.fused_moe import grouped_topk
 import torch
 import torch.distributed
 from text_generation_server.layers import (
@@ -32,6 +33,7 @@ from text_generation_server.layers.attention import (
     Seqlen,
 )
 from text_generation_server.layers.layernorm import FastRMSNorm
+from text_generation_server.layers.moe import SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import Weights
@@ -152,44 +154,6 @@ class DeepseekV2Config(PretrainedConfig):
         )
 
 
-def _load_experts(config, prefix: str, mat: str, weights: Weights):
-    if config.quantize is not None:
-        raise NotImplementedError(
-            "Deepseek V2 does not support weight quantization yet."
-        )
-
-    assert mat in ["gate_proj", "up_proj", "down_proj"]
-
-    world_size = weights.process_group.size()
-    rank = weights.process_group.rank()
-
-    assert (
-        config.moe_intermediate_size % world_size == 0
-    ), f"The chosen size {config.moe_intermediate_size} is not compatible with sharding on {world_size} shards"
-
-    block_size = config.moe_intermediate_size // world_size
-    start = rank * block_size
-    stop = (rank + 1) * block_size
-
-    tensor = torch.empty(
-        (config.n_routed_experts * block_size, config.hidden_size),
-        dtype=weights.dtype,
-        device=weights.device,
-    )
-
-    for i in range(config.n_routed_experts):
-        slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
-
-        if mat == "down_proj":
-            expert_slice = slice_[:, start:stop].t().contiguous()
-        else:
-            expert_slice = slice_[start:stop]
-        tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
-            dtype=weights.dtype
-        ).to(device=weights.device)
-    return tensor
-
-
 class DeepseekV2Attention(torch.nn.Module):
     def __init__(
         self,
@@ -452,33 +416,21 @@ class BlockSparseMoE(nn.Module):
         self.moe_intermediate_size = (
             config.moe_intermediate_size // weights.process_group.size()
         )
-        self.n_routed_experts = config.n_routed_experts
-        self.n_expert_group = config.n_group
-        self.topk_group = config.topk_group
-        self.top_k = config.num_experts_per_tok
-        self.norm_topk_prob = config.norm_topk_prob
         self.routed_scaling_factor = config.routed_scaling_factor
 
-        gate_proj = _load_experts(
-            config, f"{prefix}.experts", "gate_proj", weights
-        ).view(self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim)
-
-        up_proj = _load_experts(config, f"{prefix}.experts", "up_proj", weights).view(
-            self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim
-        )
-
-        self.gate_up_proj = torch.cat([gate_proj, up_proj], dim=1)
-
-        self.down_proj = (
-            _load_experts(config, f"{prefix}.experts", "down_proj", weights)
-            .view(self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim)
-            .transpose(1, 2)
-            .contiguous()
-        )
-
         # Gating
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
 
+        self.moe_layer = SparseMoELayer(
+            prefix=f"{prefix}.experts",
+            n_experts=config.n_routed_experts,
+            n_expert_group=config.n_group,
+            renormalize=config.norm_topk_prob,
+            topk=config.num_experts_per_tok,
+            topk_group=config.topk_group,
+            weights=weights,
+        )
+
         if config.n_shared_experts is not None:
             self.shared_experts = DeepseekV2MLP(
                 prefix=f"{prefix}.shared_experts",
@@ -499,25 +451,8 @@
             shared_output = None
 
         router_logits = self.gate(x)
-        topk_weights, topk_ids = grouped_topk(
-            x,
-            router_logits,
-            self.top_k,
-            renormalize=self.norm_topk_prob,
-            num_expert_group=self.n_expert_group,
-            topk_group=self.topk_group,
-        )
-        out = (
-            fused_experts(
-                x,
-                self.gate_up_proj,
-                self.down_proj,
-                topk_weights,
-                topk_ids,
-                inplace=True,
-            )
-            * self.routed_scaling_factor
-        )
+
+        out = self.moe_layer(x, gating_output=router_logits)
 
         if shared_output is not None:
             out = out + shared_output
@@ -635,7 +570,9 @@ class DeepseekV2Layer(nn.Module):
             and layer_id >= config.first_k_dense_replace
            and layer_id % config.moe_layer_freq == 0
         ):
-            moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
+            moe_cls = (
+                BlockSparseMoE if SparseMoELayer.is_supported(weights) else DenseMoE
+            )
             self.mlp = moe_cls(f"{prefix}.mlp", config, weights)
         else:
             self.mlp = DeepseekV2MLP(
@@ -799,183 +736,3 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
             hidden_states = hidden_states[lm_head_indices]
         logits, speculative_logits = self.lm_head(hidden_states)
         return logits, speculative_logits
-
-
-# Functions below are from vLLM:
-#
-# https://github.com/vllm-project/vllm/blob/f7160d946a0a07703e72d81ba9ecf3913f192605/vllm/model_executor/layers/fused_moe/fused_moe.py#L397
-#
-# Remove after we have synced our version with upstream.
-
-
-def grouped_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    scores = torch.softmax(gating_output, dim=-1)
-    num_token = scores.shape[0]
-    group_scores = (
-        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
-    )  # [n, n_group]
-    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
-        1
-    ]  # [n, top_k_group]
-    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-    score_mask = (
-        group_mask.unsqueeze(-1)
-        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
-        .reshape(num_token, -1)
-    )  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    return topk_weights, topk_ids
-
-
-def get_default_config(
-    M: int,
-    E: int,
-    N: int,
-    K: int,
-    topk: int,
-    dtype: Optional[str],
-) -> Dict[str, int]:
-    config = {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 64,
-        "BLOCK_SIZE_K": 32,
-        "GROUP_SIZE_M": 8,
-    }
-    if M <= E:
-        config = {
-            "BLOCK_SIZE_M": 16,
-            "BLOCK_SIZE_N": 32,
-            "BLOCK_SIZE_K": 64,
-            "GROUP_SIZE_M": 1,
-        }
-    return config
-
-
-def fused_experts(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    inplace: bool = False,
-    override_config: Optional[Dict[str, Any]] = None,
-    use_fp8: bool = False,
-    w1_scale: Optional[torch.Tensor] = None,
-    w2_scale: Optional[torch.Tensor] = None,
-    a1_scale: Optional[torch.Tensor] = None,
-    a2_scale: Optional[torch.Tensor] = None,
-):
-    # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
-    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
-    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
-    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
-
-    import triton.language as tl
-    from vllm import _custom_ops as ops
-    from vllm.model_executor.layers.fused_moe.fused_moe import (
-        get_moe_configs,
-        invoke_fused_moe_kernel,
-        moe_align_block_size,
-    )
-
-    M, _ = hidden_states.shape
-    E, N, _ = w1.shape
-
-    if override_config:
-        config = override_config
-    else:
-        # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
-
-        if configs:
-            # If an optimal configuration map has been found, look up the
-            # optimal config
-            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
-        else:
-            # Else use the default config
-            config = get_default_config(
-                M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
-            )
-
-    intermediate_cache1 = torch.empty(
-        (M, topk_ids.shape[1], N),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-    intermediate_cache2 = torch.empty(
-        (M * topk_ids.shape[1], N // 2),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-    intermediate_cache3 = torch.empty(
-        (M, topk_ids.shape[1], w2.shape[1]),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-
-    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, config["BLOCK_SIZE_M"], E
-    )
-    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
-
-    invoke_fused_moe_kernel(
-        hidden_states,
-        w1,
-        intermediate_cache1,
-        a1_scale,
-        w1_scale,
-        topk_weights,
-        topk_ids,
-        sorted_token_ids,
-        expert_ids,
-        num_tokens_post_padded,
-        False,
-        topk_ids.shape[1],
-        config,
-        compute_type=compute_type,
-        use_fp8=use_fp8,
-    )
-
-    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
-
-    invoke_fused_moe_kernel(
-        intermediate_cache2,
-        w2,
-        intermediate_cache3,
-        a2_scale,
-        w2_scale,
-        topk_weights,
-        topk_ids,
-        sorted_token_ids,
-        expert_ids,
-        num_tokens_post_padded,
-        True,
-        1,
-        config,
-        compute_type=compute_type,
-        use_fp8=use_fp8,
-    )
-
-    if inplace:
-        return torch.sum(
-            intermediate_cache3.view(*intermediate_cache3.shape),
-            dim=1,
-            out=hidden_states,
-        )
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
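
The vendored grouped_topk and fused_experts copies deleted above now come from the moe_kernels package instead (this diff adds "from moe_kernels.fused_moe import grouped_topk" at the top of the file). For readers unfamiliar with DeepSeek-V2 routing: grouped top-k first keeps the topk_group best expert groups per token, scoring each group by its strongest expert, and only then runs ordinary top-k over the experts inside those groups. A compact restatement of the selection logic from the deleted code, for reference only:

import torch

def grouped_topk_sketch(scores, n_expert_group, topk_group, topk):
    # scores: (num_tokens, n_experts) softmaxed router scores,
    # with experts laid out contiguously per group.
    num_tokens, n_experts = scores.shape
    # Score each group by its best expert and keep the top `topk_group` groups.
    group_scores = scores.view(num_tokens, n_expert_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    # Zero out experts from non-selected groups, then do ordinary top-k.
    expert_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_tokens, n_expert_group, n_experts // n_expert_group)
        .reshape(num_tokens, -1)
    )
    masked = scores.masked_fill(~expert_mask.bool(), 0.0)
    return torch.topk(masked, k=topk, dim=-1, sorted=False)
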
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index c36e97f6..2fda718b 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -25,8 +25,6 @@ import torch.distributed
 from torch import nn
 from text_generation_server.utils.import_utils import SYSTEM
 
-if SYSTEM != "ipex":
-    from vllm.model_executor.layers.fused_moe import fused_moe
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
@@ -45,6 +43,7 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.moe import SparseMoELayer
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
@@ -319,40 +318,21 @@ def round_up(x: torch.Tensor, value: int):
 class BlockSparseMoE(nn.Module):
     def __init__(self, prefix, config: MixtralConfig, weights):
         super().__init__()
-        self.hidden_dim = config.hidden_size
-        self.ffn_dim = config.intermediate_size // weights.process_group.size()
-        self.num_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
-
-        act = config.hidden_act
-        if "gelu" in act:
-            self.act = lambda x: torch.nn.functional.gelu(
-                x,
-                approximate=(
-                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
-                ),
-            )
-        elif "silu" in act:
-            self.act = torch.nn.functional.silu
-        else:
-            self.act = ACT2FN[act]
 
         # gating
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
 
-        # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
-        w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view(
-            self.num_experts, self.ffn_dim, self.hidden_dim
-        )
-        w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view(
-            self.num_experts, self.ffn_dim, self.hidden_dim
-        )
-        self.w13 = torch.cat([w1, w3], dim=1)
-        self.w2 = (
-            _load_experts(config, f"{prefix}.experts", "w2", weights)
-            .view(self.num_experts, self.ffn_dim, self.hidden_dim)
-            .transpose(1, 2)
-            .contiguous()
+        self.moe = SparseMoELayer(
+            n_expert_group=None,
+            n_experts=config.num_local_experts,
+            prefix=f"{prefix}.experts",
+            renormalize=True,
+            topk=config.num_experts_per_tok,
+            topk_group=None,
+            weights=weights,
+            gate_proj_name="w1",
+            up_proj_name="w3",
+            down_proj_name="w2",
         )
 
         self.process_group = weights.process_group
@@ -360,15 +340,7 @@
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(x)
-        out = fused_moe(
-            x,
-            self.w13,
-            self.w2,
-            router_logits,
-            self.top_k,
-            renormalize=True,
-            inplace=True,
-        )
+        out = self.moe(x, gating_output=router_logits)
 
         # Reduce sum
         if self.process_group.size() > 1:
@@ -475,7 +447,7 @@ class MixtralLayer(nn.Module):
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
 
-        moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
+        moe_cls = BlockSparseMoE if SparseMoELayer.is_supported(weights) else DenseMoE
         self.moe = moe_cls(f"{prefix}.block_sparse_moe", config, weights)
 
         self.input_layernorm = FastRMSNorm.load(
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 108ced48..75e01f7c 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -120,7 +120,6 @@ class DefaultWeightsLoader(WeightsLoader):
         prefix: str,
         block_sizes: Union[int, List[int]],
     ):
-
         return self.weight_class(
             weights.get_packed_sharded(
                 f"{prefix}.weight", dim=0, block_sizes=block_sizes
@@ -405,6 +404,10 @@
         finally:
             self.weights_loader = old_loader
 
+    @property
+    def loader(self):
+        return self.weights_loader
+
 
 def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
     """