Update vllm kernels for ROCM (#2826)

* (vllm) updated vllm rocm kernels

* revert silu

* update partition size

* remove grouped_topk

* (nit) remove log

* update moe-kernels commit
This commit is contained in:
Mohit Sharma 2024-12-18 17:14:42 +05:30 committed by GitHub
parent 7eeefa3b57
commit 8f66d323d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 95 additions and 117 deletions

View File

@ -234,6 +234,7 @@ FROM kernel-builder AS vllm-builder
WORKDIR /usr/src
COPY server/Makefile-vllm Makefile
RUN pip install setuptools_scm
# Build specific version of vllm
RUN make build-vllm-rocm
@ -267,6 +268,15 @@ COPY server/exllamav2_kernels/ .
RUN python setup.py build
FROM kernel-builder AS moe-kernels
WORKDIR /usr/src
ENV MOE_KERNELS_BRANCH=a67b35841774b2056a73806c36661134b5054edd
ENV VLLM_TARGET_DEVICE=rocm
RUN git clone https://github.com/danieldk/moe-kernels.git && \
cd moe-kernels && \
git checkout ${MOE_KERNELS_BRANCH} && \
python setup.py install
FROM install_deps AS base-copy
# Text Generation Inference base env
@ -289,6 +299,9 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311
# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from moe kernels
COPY --from=moe-kernels /usr/src/moe-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Install server
COPY proto proto
COPY server server

View File

@ -1,4 +1,4 @@
commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247
commit_rocm := de990cd12537f78f74e40b5c8ee1a62d63d734dd
build-vllm-rocm:
if [ ! -d 'vllm' ]; then \

View File

@ -215,7 +215,9 @@ def paged_reshape_and_cache(
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)
ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0
)
elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex

View File

@ -6,26 +6,42 @@ from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils.log import log_master
from loguru import logger
import vllm._custom_ops as ops
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE_V1V2 = 512
_PARTITION_SIZE_V1V2 = 1024
_PARTITION_SIZE_CUSTOM = 256
_GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
_ON_MI250_MI300 = any(
arch in _GPU_ARCH for arch in ["gfx90a", "gfx940", "gfx941", "gfx942"]
)
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"
use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
try:
if use_rocm_custom_paged_attn:
from vllm._custom_C import paged_attention_custom
except ImportError as e:
log_master(
logger.info,
f"Custom Paged Attention not available. Complete error: {e}",
def _use_rocm_custom_paged_attention(
qtype: torch.dtype,
head_size: int,
block_size: int,
gqa_ratio: int,
max_seq_len: int,
) -> bool:
# rocm custom page attention not support on navi (gfx1*)
return (
use_rocm_custom_paged_attn
and _ON_MI250_MI300
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_seq_len <= 131072
)
use_rocm_custom_paged_attn = False
def paged_attention(
@ -66,13 +82,8 @@ def paged_attention(
num_kv_heads = kv_cache.key.shape[1]
gqa_ratio = num_heads // num_kv_heads
use_custom = (
use_rocm_custom_paged_attn
and (query.dtype == torch.half or query.dtype == torch.bfloat16)
and (head_size == 128 or head_size == 64)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_s <= 32768
use_custom = _use_rocm_custom_paged_attention(
query.dtype, head_size, block_size, gqa_ratio, max_s
)
if not use_custom:
@ -90,8 +101,6 @@ def paged_attention(
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
import vllm._custom_ops as ops
use_v1 = (
max_s <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512)
@ -103,7 +112,7 @@ def paged_attention(
query,
kv_cache.key,
kv_cache.value,
kv_head_mapping,
num_kv_heads,
softmax_scale,
block_tables,
input_lengths,
@ -112,6 +121,7 @@ def paged_attention(
None,
"auto",
1.0,
1.0,
)
else:
# Run PagedAttention V2.
@ -137,7 +147,7 @@ def paged_attention(
query,
kv_cache.key,
kv_cache.value,
kv_head_mapping,
num_kv_heads,
softmax_scale,
block_tables,
input_lengths,
@ -146,9 +156,10 @@ def paged_attention(
None,
"auto",
1.0,
1.0,
)
else:
paged_attention_custom(
ops.paged_attention_rocm(
out,
exp_sums,
max_logits,
@ -164,6 +175,10 @@ def paged_attention(
max_s,
None,
"auto",
1.0,
1.0,
None,
_PARTITION_SIZE,
)
return out

View File

@ -72,7 +72,7 @@ if SYSTEM == "cuda":
return normed_hidden_states, residual
elif SYSTEM == "rocm":
from vllm._C import ops
import vllm._custom_ops as ops
class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
@ -121,6 +121,27 @@ class FastRMSNorm(nn.Module):
residual is not None,
)
return out, residual if residual is not None else hidden_states
elif SYSTEM == "rocm":
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
ops.fused_add_rms_norm(
hidden_states,
residual,
self.weight.data,
self.variance_epsilon,
)
return hidden_states, residual
residual = hidden_states
out = torch.empty_like(hidden_states)
ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
return out, residual
elif hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
@ -164,20 +185,6 @@ class FastRMSNorm(nn.Module):
res = hidden_states
return normed_hidden_states, res
elif SYSTEM == "rocm":
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
hidden_states += residual
residual = hidden_states
out = torch.empty_like(hidden_states)
ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
return out, residual
else:
raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."

View File

@ -11,10 +11,10 @@ if SYSTEM == "rocm":
if ROCM_USE_SKINNY_GEMM:
try:
from vllm import _custom_C
import vllm._custom_ops as ops
except Exception as e:
raise ImportError(
f"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error: {e}"
f"Could not load `vllm._custom_ops` for ROCm skinny gemm. Full error: {e}"
)
@ -95,12 +95,12 @@ class FastLinearROCm(torch.nn.Module):
out = torch.empty(
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
)
_custom_C.wvSpltK(weight, inp, out, n, self.cu_count)
ops.wvSpltK(weight, inp, out, n, self.cu_count)
elif m % 4 == 0 and n == 1 and k <= 8192:
out = torch.empty(
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
)
_custom_C.LLMM1(weight, inp, out, 4)
ops.LLMM1(weight, inp, out, 4)
else:
out = F.linear(inp, weight)

View File

@ -24,10 +24,7 @@ from text_generation_server.utils.weights import (
UnquantizedWeight,
)
if SYSTEM == "rocm":
from .fused_moe_rocm import grouped_topk
from vllm.model_executor.layers.fused_moe import fused_topk
elif SYSTEM == "ipex":
if SYSTEM == "ipex":
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
else:
from moe_kernels.fused_moe import fused_topk, grouped_topk

View File

@ -1,52 +0,0 @@
# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
import torch.distributed
# TODO: Remove the functions once moe_kernel are built for ROCM
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
scores = torch.softmax(gating_output, dim=-1)
num_token = scores.shape[0]
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
1
] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
.reshape(num_token, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids

View File

@ -6,9 +6,7 @@ import torch.nn as nn
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import UnquantizedWeight, Weights
if SYSTEM == "rocm":
from vllm.model_executor.layers.fused_moe import fused_moe
elif SYSTEM == "ipex":
if SYSTEM == "ipex":
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
else:
from moe_kernels.fused_moe import fused_moe

View File

@ -7,7 +7,7 @@ from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "cuda":
import rotary_emb
elif SYSTEM == "rocm":
from vllm._C import ops
import vllm._custom_ops as ops
elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex

View File

@ -75,7 +75,7 @@ class CohereRotary(PositionRotaryEmbedding):
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif SYSTEM == "rocm":
from vllm._C import ops
import vllm._custom_ops as ops
# NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773

View File

@ -23,9 +23,7 @@ from typing import Optional, List, Tuple, Any
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "rocm":
from vllm.model_executor.layers.fused_moe import fused_moe
elif SYSTEM == "ipex":
if SYSTEM == "ipex":
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
else:
from moe_kernels.fused_moe import fused_moe

View File

@ -43,9 +43,9 @@ from text_generation_server.utils.weights import Weights
if SYSTEM == "rocm":
try:
from vllm import _custom_C
import vllm._custom_ops as ops
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
raise ImportError(f"Could not load `vllm._custom_ops`. Full error: {e}")
class DeepseekV2Config(PretrainedConfig):
@ -408,7 +408,7 @@ class DeepseekV2MLP(nn.Module):
dtype=hidden_states.dtype,
device="cuda",
)
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
ops.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
return self.down_proj(out, reduce=reduce)
else:
gate_up_states = self.gate_up_proj(hidden_states)

View File

@ -91,7 +91,7 @@ class GPTJRotary(PositionRotaryEmbedding):
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif SYSTEM == "rocm":
from vllm._C import ops
import vllm._custom_ops as ops
# NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773

View File

@ -64,9 +64,9 @@ if SYSTEM != "ipex":
if SYSTEM == "rocm":
try:
from vllm import _custom_C
import vllm._custom_ops as ops
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
raise ImportError(f"Could not load `vllm._custom_ops`. Full error: {e}")
def load_attention(config, prefix: str, weights, layer_id):
@ -392,7 +392,7 @@ class LlamaMLP(nn.Module):
dtype=hidden_states.dtype,
device="cuda",
)
_custom_C.LLMM_Silu(
ops.LLMM_Silu(
self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8
)
return self.down_proj(out, adapter_data)

View File

@ -49,9 +49,9 @@ from text_generation_server.layers.layernorm import (
if SYSTEM == "rocm":
try:
from vllm import _custom_C
import vllm._custom_ops as ops
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
raise ImportError(f"Could not load `vllm._custom_ops`. Full error: {e}")
class MistralConfig(PretrainedConfig):
@ -318,7 +318,7 @@ class MistralMLP(nn.Module):
dtype=hidden_states.dtype,
device="cuda",
)
_custom_C.LLMM_Silu(
ops.LLMM_Silu(
self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8
)
return self.down_proj(out, adapter_data)

View File

@ -52,7 +52,7 @@ from loguru import logger
if SYSTEM == "cuda":
import dropout_layer_norm
elif SYSTEM == "rocm":
from vllm._C import ops
import vllm._custom_ops as ops
else:
dropout_layer_norm = None