Update vllm kernels for ROCM (#2826)
* (vllm) updated vllm rocm kernels
* revert silu
* update partition size
* remove grouped_topk
* (nit) remove log
* update moe-kernels commit
parent 7eeefa3b57
commit 8f66d323d0
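
The recurring change across the modeling files below is the switch from the old private ROCm extension modules (`vllm._C`, `vllm._custom_C`) to the unified `vllm._custom_ops` namespace. A minimal sketch of the guarded import as it now appears in the ROCm code paths (assuming a ROCm build of vLLM that ships `vllm._custom_ops`):

    # Sketch only: mirrors the guarded import pattern used in the diff below.
    try:
        import vllm._custom_ops as ops  # replaces `from vllm import _custom_C` / `from vllm._C import ops`
    except Exception as e:
        raise ImportError(f"Could not load `vllm._custom_ops`. Full error: {e}")

    # The same module now exposes the kernels previously reached through `_custom_C`,
    # e.g. ops.LLMM_Silu, ops.LLMM1, ops.wvSpltK, ops.paged_attention_rocm.
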
@@ -234,6 +234,7 @@ FROM kernel-builder AS vllm-builder
 WORKDIR /usr/src

 COPY server/Makefile-vllm Makefile
+RUN pip install setuptools_scm

 # Build specific version of vllm
 RUN make build-vllm-rocm
@@ -267,6 +268,15 @@ COPY server/exllamav2_kernels/ .

 RUN python setup.py build

+FROM kernel-builder AS moe-kernels
+WORKDIR /usr/src
+ENV MOE_KERNELS_BRANCH=a67b35841774b2056a73806c36661134b5054edd
+ENV VLLM_TARGET_DEVICE=rocm
+RUN git clone https://github.com/danieldk/moe-kernels.git && \
+    cd moe-kernels && \
+    git checkout ${MOE_KERNELS_BRANCH} && \
+    python setup.py install
+
 FROM install_deps AS base-copy

 # Text Generation Inference base env
@@ -289,6 +299,9 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311
 # Copy build artifacts from exllamav2 kernels builder
 COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages

+# Copy build artifacts from moe kernels
+COPY --from=moe-kernels /usr/src/moe-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
+
 # Install server
 COPY proto proto
 COPY server server

@@ -1,4 +1,4 @@
-commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247
+commit_rocm := de990cd12537f78f74e40b5c8ee1a62d63d734dd

 build-vllm-rocm:
 	if [ ! -d 'vllm' ]; then \

@@ -215,7 +215,9 @@ def paged_reshape_and_cache(
             raise ImportError(
                 f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
             )
-        ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
+        ops.reshape_and_cache(
+            key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0
+        )
     elif SYSTEM == "ipex":
         import intel_extension_for_pytorch as ipex

@@ -6,26 +6,42 @@ from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import Seqlen
 from text_generation_server.utils.log import log_master
 from loguru import logger
+import vllm._custom_ops as ops

 major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5

-_PARTITION_SIZE_V1V2 = 512
+_PARTITION_SIZE_V1V2 = 1024
 _PARTITION_SIZE_CUSTOM = 256

+_GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
+_ON_MI250_MI300 = any(
+    arch in _GPU_ARCH for arch in ["gfx90a", "gfx940", "gfx941", "gfx942"]
+)
+
 use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
 ENGINE = "triton" if use_triton else "ck"

 use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
-try:
-    if use_rocm_custom_paged_attn:
-        from vllm._custom_C import paged_attention_custom
-except ImportError as e:
-    log_master(
-        logger.info,
-        f"Custom Paged Attention not available. Complete error: {e}",
+
+
+def _use_rocm_custom_paged_attention(
+    qtype: torch.dtype,
+    head_size: int,
+    block_size: int,
+    gqa_ratio: int,
+    max_seq_len: int,
+) -> bool:
+    # rocm custom page attention not support on navi (gfx1*)
+    return (
+        use_rocm_custom_paged_attn
+        and _ON_MI250_MI300
+        and (qtype == torch.half or qtype == torch.bfloat16)
+        and (head_size == 64 or head_size == 128)
+        and (block_size == 16 or block_size == 32)
+        and (gqa_ratio >= 1 and gqa_ratio <= 16)
+        and max_seq_len <= 131072
     )
-    use_rocm_custom_paged_attn = False


 def paged_attention(
@@ -66,13 +82,8 @@ def paged_attention(

     num_kv_heads = kv_cache.key.shape[1]
     gqa_ratio = num_heads // num_kv_heads
-    use_custom = (
-        use_rocm_custom_paged_attn
-        and (query.dtype == torch.half or query.dtype == torch.bfloat16)
-        and (head_size == 128 or head_size == 64)
-        and (block_size == 16 or block_size == 32)
-        and (gqa_ratio >= 1 and gqa_ratio <= 16)
-        and max_s <= 32768
+    use_custom = _use_rocm_custom_paged_attention(
+        query.dtype, head_size, block_size, gqa_ratio, max_s
     )

     if not use_custom:
@@ -90,8 +101,6 @@ def paged_attention(
         # V1 to avoid the overhead of reduction. Also, if the number of
         # sequences or heads is large, we use V1 since there is enough work
         # to parallelize.
-        import vllm._custom_ops as ops
-
         use_v1 = (
             max_s <= 8192
             and (max_num_partitions == 1 or num_seqs * num_heads > 512)
@@ -103,7 +112,7 @@ def paged_attention(
                 query,
                 kv_cache.key,
                 kv_cache.value,
-                kv_head_mapping,
+                num_kv_heads,
                 softmax_scale,
                 block_tables,
                 input_lengths,
@@ -112,6 +121,7 @@ def paged_attention(
                 None,
                 "auto",
                 1.0,
+                1.0,
             )
         else:
             # Run PagedAttention V2.
@@ -137,7 +147,7 @@ def paged_attention(
                 query,
                 kv_cache.key,
                 kv_cache.value,
-                kv_head_mapping,
+                num_kv_heads,
                 softmax_scale,
                 block_tables,
                 input_lengths,
@@ -146,9 +156,10 @@ def paged_attention(
                 None,
                 "auto",
                 1.0,
+                1.0,
             )
     else:
-        paged_attention_custom(
+        ops.paged_attention_rocm(
             out,
             exp_sums,
             max_logits,
@@ -164,6 +175,10 @@ def paged_attention(
             max_s,
             None,
             "auto",
+            1.0,
+            1.0,
+            None,
+            _PARTITION_SIZE,
         )

     return out

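As a quick reference, a standalone restatement of the new kernel-selection gate added above (a sketch; the real `_use_rocm_custom_paged_attention` additionally honors the `ROCM_USE_CUSTOM_PAGED_ATTN` env var and the detected `gcnArchName`):

    import torch

    # Hypothetical standalone helper mirroring the conditions in the diff above.
    def would_use_custom_paged_attention(
        qtype, head_size, block_size, gqa_ratio, max_seq_len, on_mi250_mi300=True
    ):
        return (
            on_mi250_mi300
            and (qtype == torch.half or qtype == torch.bfloat16)
            and (head_size == 64 or head_size == 128)
            and (block_size == 16 or block_size == 32)
            and (1 <= gqa_ratio <= 16)
            and max_seq_len <= 131072
        )

    print(would_use_custom_paged_attention(torch.float16, 128, 16, 8, 4096))  # True
    print(would_use_custom_paged_attention(torch.float32, 128, 16, 8, 4096))  # False: fp32 not covered
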
@@ -72,7 +72,7 @@ if SYSTEM == "cuda":
             return normed_hidden_states, residual

 elif SYSTEM == "rocm":
-    from vllm._C import ops
+    import vllm._custom_ops as ops

     class FastLayerNorm(nn.LayerNorm):
         def forward(self, hidden_states, residual=None):
@@ -121,6 +121,27 @@ class FastRMSNorm(nn.Module):
                 residual is not None,
             )
             return out, residual if residual is not None else hidden_states
+        elif SYSTEM == "rocm":
+            # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
+            if residual is not None:
+                ops.fused_add_rms_norm(
+                    hidden_states,
+                    residual,
+                    self.weight.data,
+                    self.variance_epsilon,
+                )
+                return hidden_states, residual
+
+            residual = hidden_states
+
+            out = torch.empty_like(hidden_states)
+            ops.rms_norm(
+                out,
+                hidden_states,
+                self.weight.data,
+                self.variance_epsilon,
+            )
+            return out, residual
         elif hidden_states.shape[-1] > 8192:
             if residual is not None:
                 hidden_states += residual
@@ -164,20 +185,6 @@ class FastRMSNorm(nn.Module):
                 res = hidden_states

             return normed_hidden_states, res
-        elif SYSTEM == "rocm":
-            # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
-            if residual is not None:
-                hidden_states += residual
-                residual = hidden_states
-
-            out = torch.empty_like(hidden_states)
-            ops.rms_norm(
-                out,
-                hidden_states,
-                self.weight.data,
-                self.variance_epsilon,
-            )
-            return out, residual
         else:
             raise ValueError(
                 "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."

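The RoCm branch of `FastRMSNorm.forward` now runs before the large-hidden-size fallback and calls the vLLM kernels directly. A minimal sketch of the two entry points used above, assuming a ROCm vLLM build exposing `vllm._custom_ops` and a GPU device:

    import torch
    import vllm._custom_ops as ops  # assumes a ROCm build of vLLM

    hidden = torch.randn(4, 1024, dtype=torch.float16, device="cuda")
    residual = torch.randn_like(hidden)
    weight = torch.ones(1024, dtype=torch.float16, device="cuda")
    eps = 1e-6

    # In-place fused residual-add + RMSNorm (the `residual is not None` path above).
    ops.fused_add_rms_norm(hidden, residual, weight, eps)

    # Plain RMSNorm into a preallocated output (the no-residual path above).
    out = torch.empty_like(hidden)
    ops.rms_norm(out, hidden, weight, eps)
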
@@ -11,10 +11,10 @@ if SYSTEM == "rocm":

     if ROCM_USE_SKINNY_GEMM:
         try:
-            from vllm import _custom_C
+            import vllm._custom_ops as ops
         except Exception as e:
             raise ImportError(
-                f"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error: {e}"
+                f"Could not load `vllm._custom_ops` for ROCm skinny gemm. Full error: {e}"
             )


@@ -95,12 +95,12 @@ class FastLinearROCm(torch.nn.Module):
             out = torch.empty(
                 inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
             )
-            _custom_C.wvSpltK(weight, inp, out, n, self.cu_count)
+            ops.wvSpltK(weight, inp, out, n, self.cu_count)
         elif m % 4 == 0 and n == 1 and k <= 8192:
             out = torch.empty(
                 inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
             )
-            _custom_C.LLMM1(weight, inp, out, 4)
+            ops.LLMM1(weight, inp, out, 4)
         else:
             out = F.linear(inp, weight)

@@ -24,10 +24,7 @@ from text_generation_server.utils.weights import (
     UnquantizedWeight,
 )

-if SYSTEM == "rocm":
-    from .fused_moe_rocm import grouped_topk
-    from vllm.model_executor.layers.fused_moe import fused_topk
-elif SYSTEM == "ipex":
+if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 else:
     from moe_kernels.fused_moe import fused_topk, grouped_topk

@@ -1,52 +0,0 @@
-# coding=utf-8
-# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Tuple
-
-import torch
-import torch.distributed
-
-
-# TODO: Remove the functions once moe_kernel are built for ROCM
-def grouped_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    scores = torch.softmax(gating_output, dim=-1)
-    num_token = scores.shape[0]
-    group_scores = (
-        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
-    )  # [n, n_group]
-    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
-        1
-    ]  # [n, top_k_group]
-    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-    score_mask = (
-        group_mask.unsqueeze(-1)
-        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
-        .reshape(num_token, -1)
-    )  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    return topk_weights, topk_ids

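With the file above deleted, ROCm falls through to the default branch and gets `grouped_topk`/`fused_topk` from the moe-kernels package. A minimal sketch of the replacement call, assuming moe-kernels keeps the same signature as the removed pure-PyTorch helper:

    import torch
    from moe_kernels.fused_moe import grouped_topk  # assumes the moe-kernels package is installed

    num_tokens, num_experts, num_expert_group = 4, 64, 8
    hidden_states = torch.randn(num_tokens, 4096, dtype=torch.float16, device="cuda")
    gating_output = torch.randn(num_tokens, num_experts, dtype=torch.float32, device="cuda")

    # Same arguments as the deleted helper: pick top-k experts within the top expert groups.
    topk_weights, topk_ids = grouped_topk(
        hidden_states,
        gating_output,
        topk=2,
        renormalize=True,
        num_expert_group=num_expert_group,
        topk_group=2,
    )
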
@@ -6,9 +6,7 @@ import torch.nn as nn
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import UnquantizedWeight, Weights

-if SYSTEM == "rocm":
-    from vllm.model_executor.layers.fused_moe import fused_moe
-elif SYSTEM == "ipex":
+if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 else:
     from moe_kernels.fused_moe import fused_moe

@@ -7,7 +7,7 @@ from text_generation_server.utils.import_utils import SYSTEM
 if SYSTEM == "cuda":
     import rotary_emb
 elif SYSTEM == "rocm":
-    from vllm._C import ops
+    import vllm._custom_ops as ops
 elif SYSTEM == "ipex":
     import intel_extension_for_pytorch as ipex

@@ -75,7 +75,7 @@ class CohereRotary(PositionRotaryEmbedding):

             rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
         elif SYSTEM == "rocm":
-            from vllm._C import ops
+            import vllm._custom_ops as ops

             # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
             # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773

@@ -23,9 +23,7 @@ from typing import Optional, List, Tuple, Any
 from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.utils.import_utils import SYSTEM

-if SYSTEM == "rocm":
-    from vllm.model_executor.layers.fused_moe import fused_moe
-elif SYSTEM == "ipex":
+if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 else:
     from moe_kernels.fused_moe import fused_moe

@@ -43,9 +43,9 @@ from text_generation_server.utils.weights import Weights

 if SYSTEM == "rocm":
     try:
-        from vllm import _custom_C
+        import vllm._custom_ops as ops
     except Exception as e:
-        raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
+        raise ImportError(f"Could not load `vllm._custom_ops`. Full error: {e}")


 class DeepseekV2Config(PretrainedConfig):
@@ -408,7 +408,7 @@ class DeepseekV2MLP(nn.Module):
                 dtype=hidden_states.dtype,
                 device="cuda",
             )
-            _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
+            ops.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
             return self.down_proj(out, reduce=reduce)
         else:
             gate_up_states = self.gate_up_proj(hidden_states)

@@ -91,7 +91,7 @@ class GPTJRotary(PositionRotaryEmbedding):

             rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
         elif SYSTEM == "rocm":
-            from vllm._C import ops
+            import vllm._custom_ops as ops

             # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
             # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773

@@ -64,9 +64,9 @@ if SYSTEM != "ipex":

 if SYSTEM == "rocm":
     try:
-        from vllm import _custom_C
+        import vllm._custom_ops as ops
     except Exception as e:
-        raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
+        raise ImportError(f"Could not load `vllm._custom_ops`. Full error: {e}")


 def load_attention(config, prefix: str, weights, layer_id):
@@ -392,7 +392,7 @@ class LlamaMLP(nn.Module):
                 dtype=hidden_states.dtype,
                 device="cuda",
             )
-            _custom_C.LLMM_Silu(
+            ops.LLMM_Silu(
                 self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8
             )
             return self.down_proj(out, adapter_data)

@@ -49,9 +49,9 @@ from text_generation_server.layers.layernorm import (

 if SYSTEM == "rocm":
     try:
-        from vllm import _custom_C
+        import vllm._custom_ops as ops
     except Exception as e:
-        raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
+        raise ImportError(f"Could not load `vllm._custom_ops`. Full error: {e}")


 class MistralConfig(PretrainedConfig):
@@ -318,7 +318,7 @@ class MistralMLP(nn.Module):
                 dtype=hidden_states.dtype,
                 device="cuda",
             )
-            _custom_C.LLMM_Silu(
+            ops.LLMM_Silu(
                 self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8
             )
             return self.down_proj(out, adapter_data)

@@ -52,7 +52,7 @@ from loguru import logger
 if SYSTEM == "cuda":
     import dropout_layer_norm
 elif SYSTEM == "rocm":
-    from vllm._C import ops
+    import vllm._custom_ops as ops
 else:
     dropout_layer_norm = None