diff --git a/server/text_generation_server/layers/__init__.py b/server/text_generation_server/layers/__init__.py
index 0000ca91..e8282b16 100644
--- a/server/text_generation_server/layers/__init__.py
+++ b/server/text_generation_server/layers/__init__.py
@@ -19,6 +19,8 @@ from text_generation_server.layers.lora import (
     TensorParallelAdapterRowLinear,
 )
 
+from text_generation_server.layers.moe.fused_moe_rocm import grouped_topk
+
 __all__ = [
     "get_linear",
     "FastLinear",
@@ -31,4 +33,5 @@ __all__ = [
     "TensorParallelAdapterRowLinear",
     "load_layer_norm",
     "load_conv2d",
+    "grouped_topk",
 ]
diff --git a/server/text_generation_server/layers/moe/fused_moe_rocm.py b/server/text_generation_server/layers/moe/fused_moe_rocm.py
new file mode 100644
index 00000000..ab30ff53
--- /dev/null
+++ b/server/text_generation_server/layers/moe/fused_moe_rocm.py
@@ -0,0 +1,193 @@
+# coding=utf-8
+# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Dict, Any
+
+import torch
+import torch.distributed
+
+
+# TODO: Remove these functions once moe_kernels is built for ROCm.
+def grouped_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    scores = torch.softmax(gating_output, dim=-1)
+    num_token = scores.shape[0]
+    group_scores = (
+        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
+    )  # [n, n_group]
+    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
+        1
+    ]  # [n, top_k_group]
+    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+    score_mask = (
+        group_mask.unsqueeze(-1)
+        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
+        .reshape(num_token, -1)
+    )  # [n, e]
+    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights, topk_ids
+
+
+# Note: only M and E affect the returned config; the other parameters are kept
+# for signature compatibility with the upstream kernel config lookup.
+def get_default_config(
+    M: int,
+    E: int,
+    N: int,
+    K: int,
+    topk: int,
+    dtype: Optional[str],
+) -> Dict[str, int]:
+    config = {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 32,
+        "GROUP_SIZE_M": 8,
+    }
+    if M <= E:
+        config = {
+            "BLOCK_SIZE_M": 16,
+            "BLOCK_SIZE_N": 32,
+            "BLOCK_SIZE_K": 64,
+            "GROUP_SIZE_M": 1,
+        }
+    return config
+
+
+def fused_experts(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    inplace: bool = False,
+    override_config: Optional[Dict[str, Any]] = None,
+    use_fp8: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+):
+    # Check constraints.
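+    # Shape conventions: hidden_states is [M, K]; w1 is [E, N, K], where N stacks
+    # the fused gate and up projections; w2 is [E, K, N // 2]. The asserts below
+    # enforce the layout, contiguity, and dtype requirements of the Triton kernels.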
+    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
+    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
+    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
+    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
+    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
+
+    import triton.language as tl
+    from vllm import _custom_ops as ops
+    from vllm.model_executor.layers.fused_moe.fused_moe import (
+        get_moe_configs,
+        invoke_fused_moe_kernel,
+        moe_align_block_size,
+    )
+
+    M, _ = hidden_states.shape
+    E, N, _ = w1.shape
+
+    if override_config:
+        config = override_config
+    else:
+        # First try to load an optimal config from the file
+        configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
+
+        if configs:
+            # If an optimal configuration map has been found, look up the
+            # optimal config
+            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+        else:
+            # Else use the default config
+            config = get_default_config(
+                M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
+            )
+
+    intermediate_cache1 = torch.empty(
+        (M, topk_ids.shape[1], N),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache2 = torch.empty(
+        (M * topk_ids.shape[1], N // 2),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache3 = torch.empty(
+        (M, topk_ids.shape[1], w2.shape[1]),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+
+    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
+        topk_ids, config["BLOCK_SIZE_M"], E
+    )
+    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
+
+    invoke_fused_moe_kernel(
+        hidden_states,
+        w1,
+        intermediate_cache1,
+        a1_scale,
+        w1_scale,
+        topk_weights,
+        topk_ids,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        False,
+        topk_ids.shape[1],
+        config,
+        compute_type=compute_type,
+        use_fp8=use_fp8,
+    )
+
+    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+
+    invoke_fused_moe_kernel(
+        intermediate_cache2,
+        w2,
+        intermediate_cache3,
+        a2_scale,
+        w2_scale,
+        topk_weights,
+        topk_ids,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        True,
+        1,
+        config,
+        compute_type=compute_type,
+        use_fp8=use_fp8,
+    )
+
+    if inplace:
+        return torch.sum(
+            intermediate_cache3.view(*intermediate_cache3.shape),
+            dim=1,
+            out=hidden_states,
+        )
+    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
diff --git a/server/text_generation_server/layers/moe/unquantized.py b/server/text_generation_server/layers/moe/unquantized.py
index 8f1d9b3f..d9d62c0e 100644
--- a/server/text_generation_server/layers/moe/unquantized.py
+++ b/server/text_generation_server/layers/moe/unquantized.py
@@ -6,7 +6,9 @@ import torch.nn as nn
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import UnquantizedWeight, Weights
 
-if SYSTEM != "ipex":
+if SYSTEM == "rocm":
+    from vllm.model_executor.layers.fused_moe import fused_moe
+elif SYSTEM != "ipex":
     from moe_kernels.fused_moe import fused_moe
 
 
@@ -52,6 +54,17 @@ class UnquantizedSparseMoELayer(nn.Module):
         )
 
     def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
+        if SYSTEM == "rocm":
+            return fused_moe(
+                x,
+                self.gate_up_proj,
+                self.down_proj,
+                gating_output,
+                self.topk,
+                renormalize=self.renormalize,
+                inplace=True,
+            )
+
         return fused_moe(
             x,
             w1=self.gate_up_proj,
diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
index 08a8d258..fccafb01 100644
--- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
@@ -16,7 +16,13 @@ from typing import List, Optional, Tuple
 
 from text_generation_server.models.globals import PAGED_KV
-from moe_kernels.fused_moe import grouped_topk
+from text_generation_server.utils.import_utils import SYSTEM
+
+if SYSTEM == "rocm":
+    from text_generation_server.layers import grouped_topk
+else:
+    from moe_kernels.fused_moe import grouped_topk
+
 import torch
 import torch.distributed
 from text_generation_server.layers import (
@@ -36,7 +42,6 @@ from text_generation_server.layers.attention import (
 from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.moe import SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import Weights
 from torch import nn
 from transformers.activations import ACT2FN
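For reviewers, a minimal usage sketch of the new routing helper. The shapes and values are illustrative only (4 tokens, hidden size 16, 8 experts split into 4 groups, top-2 routing); since grouped_topk is pure PyTorch, the sketch runs without a GPU. The import path follows the new export added to layers/__init__.py above.

import torch

from text_generation_server.layers import grouped_topk

# Illustrative inputs: 4 tokens, hidden size 16, 8 experts in 4 groups.
hidden_states = torch.randn(4, 16, dtype=torch.float16)
gating_output = torch.randn(4, 8, dtype=torch.float32)

# DeepSeek-V2-style grouped routing: keep the 2 best expert groups per token,
# then take the top-2 experts among the scores of the surviving groups.
topk_weights, topk_ids = grouped_topk(
    hidden_states,
    gating_output,
    topk=2,
    renormalize=True,
    num_expert_group=4,
    topk_group=2,
)

assert topk_weights.shape == (4, 2) and topk_ids.shape == (4, 2)
# With renormalize=True, each token's routing weights sum to 1.
assert torch.allclose(topk_weights.sum(dim=-1), torch.ones(4))

On non-ROCm systems the same call signature is served by moe_kernels.fused_moe's grouped_topk, so callers do not branch beyond the import.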