Update ROCM libs and improvements (#2579)
* style
* update torch
* fix issues
* fix clone
* revert mkl
* added custom PA
* style
* fix style
* style
* hide env var
* fix mixtral model
* add skinny kernel and merge fixes
* fixed style
* fix issue for sliding window models
* addressed review comments
* fix import
* improved error message
* updated default value
* remove import
* fix imports after rebase
* float16 dep
* improve dockerfile
* cleaned dockerfile
This commit is contained in:
parent e790cfc0e4
commit f9e561eced

175 Dockerfile_amd
@@ -41,7 +41,7 @@ COPY launcher launcher
RUN cargo build --profile release-opt

# Text Generation Inference base image for RoCm
FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update AS base
FROM rocm/dev-ubuntu-22.04:6.2 AS base

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
@@ -50,33 +50,34 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
curl \
git \
make \
libmsgpack-dev \
libssl-dev \
llvm-dev \
g++ \
# Needed to build VLLM & flash.
rocthrust-dev \
hipsparse-dev \
hipblas-dev \
hipblaslt-dev \
hipcub-dev \
rocblas-dev \
hiprand-dev \
hipfft-dev \
rocrand-dev \
miopen-hip-dev \
hipfft-dev \
hipcub-dev \
hipsolver-dev \
rccl-dev \
cmake \
python3.11-dev && \
python3.11-venv && \
rm -rf /var/lib/apt/lists/*

# Keep in sync with `server/pyproject.toml
ARG MAMBA_VERSION=23.1.0-1
ARG PYTORCH_VERSION='2.3.0'
ARG ROCM_VERSION='6.0.2'
ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH
ENV PATH=/opt/conda/bin:$PATH

ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"

# TGI seems to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
# Install mamba
@@ -100,41 +101,132 @@ RUN case ${TARGETPLATFORM} in \
/opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
esac && \
/opt/conda/bin/conda clean -ya

# Install flash-attention, torch dependencies
RUN pip install numpy einops ninja --no-cache-dir
RUN python3 -m pip install --upgrade pip && pip install numpy einops ninja joblib msgpack cmake --no-cache-dir && rm -rf /var/lib/apt/lists/*

RUN pip uninstall -y triton && \
git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
cd triton/python && \
pip install .
RUN conda install mkl=2021
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/

RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir

ARG _GLIBCXX_USE_CXX11_ABI="1"
ARG CMAKE_PREFIX_PATH="/opt/conda"
ARG COMMON_WORKDIR=/
WORKDIR ${COMMON_WORKDIR}


# Install HIPBLASLt
FROM base AS build_hipblaslt
ARG HIPBLASLT_BRANCH="e6da924"
RUN git clone https://github.com/ROCm/hipBLASLt.git \
&& cd hipBLASLt \
&& git checkout ${HIPBLASLT_BRANCH} \
&& SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} --legacy_hipblas_direct \
&& cd build/release \
&& make package

FROM scratch AS export_hipblaslt
ARG COMMON_WORKDIR
COPY --from=build_hipblaslt ${COMMON_WORKDIR}/hipBLASLt/build/release/*.deb /

# RCCL build stages
FROM base AS build_rccl
ARG RCCL_BRANCH="rocm-6.2.0"
RUN git clone https://github.com/ROCm/rccl \
&& cd rccl \
&& git checkout ${RCCL_BRANCH} \
&& ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
FROM scratch AS export_rccl
ARG COMMON_WORKDIR
COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb /

# Triton build stages
FROM base AS build_triton
ARG TRITON_BRANCH="e192dba"
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
RUN python3 -m pip install ninja cmake wheel pybind11 && git clone ${TRITON_REPO} \
&& cd triton \
&& git checkout ${TRITON_BRANCH} \
&& cd python \
&& python3 setup.py bdist_wheel --dist-dir=dist
FROM scratch AS export_triton
ARG COMMON_WORKDIR
COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl /

# # AMD-SMI build stages
FROM base AS build_amdsmi
RUN cd /opt/rocm/share/amd_smi \
&& pip wheel . --wheel-dir=dist
FROM scratch AS export_amdsmi
COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl /

FROM base as build_pytorch

RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
if ls /install/*.deb; then \
dpkg -i /install/*.deb \
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
fi

ARG BUILD_ENVIRONMENT=pytorch-linux-jammy-rocm6.2-py3.11
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
ARG BUILD_CAFFE2="0" \
BUILD_CAFFE2_OPS="0" \
USE_CUDA="0" \
USE_ROCM="1" \
BUILD_TEST="0" \
USE_FBGEMM="0" \
USE_NNPACK="0" \
USE_QNNPACK="0" \
USE_XNNPACK="0" \
USE_FLASH_ATTENTION="1" \
USE_MEM_EFF_ATTENTION="0"

RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
# A commit to fix the output scaling factor issue in _scaled_mm
# Not yet in 2.5.0-rc1
ARG PYTORCH_BRANCH="cedc116"
ARG PYTORCH_VISION_BRANCH="v0.19.1"
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"

# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
ENV HIP_FORCE_DEV_KERNARG=1
RUN git clone ${PYTORCH_REPO} pytorch \
&& cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \
&& pip install -r requirements.txt --no-cache-dir \
&& python tools/amd_build/build_amd.py \
&& CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist
FROM scratch as export_pytorch
ARG COMMON_WORKDIR
COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl /

# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
# However, Triton requires a tuning for each prompt length, which is prohibitive.
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
FROM base AS install_deps

FROM base AS kernel-builder
ARG COMMON_WORKDIR

# Install hipblaslt
RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
if ls /install/*.deb; then \
dpkg -i /install/*.deb \
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
fi

RUN --mount=type=bind,from=export_rccl,src=/,target=/install \
if ls /install/*.deb; then \
dpkg -i /install/*.deb \
# RCCL needs to be installed twice
&& dpkg -i /install/*.deb \
&& sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
&& sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status; \
fi

RUN --mount=type=bind,from=export_triton,src=/,target=/install \
if ls /install/*.whl; then \
# Preemptively uninstall to prevent pip same-version no-installs
pip uninstall -y triton \
&& pip install /install/*.whl; \
fi

RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \
# Preemptively uninstall to prevent pip same-version no-installs
pip uninstall -y amdsmi \
&& pip install /install/*.whl;

RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
if ls /install/*.whl; then \
# Preemptively uninstall to prevent pip same-version no-installs
pip uninstall -y torch torchvision \
&& pip install /install/*.whl; \
fi

FROM install_deps AS kernel-builder

# # Build vllm kernels
FROM kernel-builder AS vllm-builder
@@ -174,7 +266,7 @@ COPY server/exllamav2_kernels/ .

RUN python setup.py build

FROM base AS base-copy
FROM install_deps AS base-copy

# Text Generation Inference base env
ENV HF_HOME=/data \
@@ -224,6 +316,19 @@ ENTRYPOINT ["./entrypoint.sh"]
# Final image
FROM base-copy

# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
ENV HIP_FORCE_DEV_KERNARG=1

# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
# However, Triton requires a tuning for each prompt length, which is prohibitive.
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
ENV VLLM_MOE_PADDING=0
ENV ATTENTION=paged
ENV USE_PREFIX_CACHING=0
ENV ROCM_USE_SKINNY_GEMM=1

COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

@@ -31,6 +31,12 @@ Two implementations of Flash Attention are available for ROCm, the first is [ROC

By default, the Composable Kernel implementation is used. However, the Triton implementation has slightly lower latency on MI250 and MI300, but requires a warmup which can be prohibitive as it needs to be done again for each new prompt length. If needed, the FA Triton implementation can be enabled with `--env ROCM_USE_FLASH_ATTN_V2_TRITON="1"` when launching TGI's docker container.

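Editor's note: a minimal sketch of how that toggle is consumed, mirroring the `ROCM_USE_FLASH_ATTN_V2_TRITON` handling added to the ROCm attention module in this commit (the `print` is ours, for illustration only):

```python
import os

# Any of "true"/"1" selects the Triton flash-attention path; anything else
# (including the image default of 0) keeps the Composable Kernel (CK) engine.
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"
print(f"Selected flash-attention engine: {ENGINE}")
```
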
## Custom PagedAttention

For better performance on ROCm, a custom Paged Attention kernel is available and is enabled by default. To disable it and fall back to the PagedAttention v2 kernel, set the environment variable `ROCM_USE_CUSTOM_PAGED_ATTN=0`.

The custom kernel supports bf16 and fp16 data types, block size of 16, head size of 128, a maximum context length of 16k, and GQA ratios between 1 and 16. For other configurations, we use the PagedAttention v2 kernel.

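Editor's note: the choice between the custom kernel and PagedAttention v2 is made per call from the query and cache shapes. A condensed sketch of that check, adapted from the `use_custom` condition added in this commit (the standalone helper name is ours):

```python
import torch

def can_use_custom_paged_attn(query, key_cache, value_cache, max_s, enabled=True):
    # value_cache layout: [num_blocks, num_heads, head_size, block_size]
    block_size = value_cache.shape[3]
    _, num_heads, head_size = query.shape
    gqa_ratio = num_heads // key_cache.shape[1]
    return (
        enabled
        and query.dtype in (torch.half, torch.bfloat16)
        and head_size in (64, 128)
        and block_size in (16, 32)
        and 1 <= gqa_ratio <= 16
        and max_s <= 32768
    )
```
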
## Unsupported features

The following features are currently not supported in the ROCm version of TGI, and the support may be extended in the future:

@@ -1,5 +1,5 @@
flash_att_v2_commit_cuda := v2.6.1
flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4

build-flash-attention-v2-cuda:
pip install -U packaging wheel
@@ -11,7 +11,7 @@ install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
build-flash-attention-v2-rocm:
if [ ! -d 'flash-attention-v2' ]; then \
pip install -U packaging ninja --no-cache-dir && \
git clone https://github.com/ROCm/flash-attention.git flash-attention-v2 && \
git clone https://github.com/mht-sharma/flash-attention.git flash-attention-v2 && \
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \
git submodule update --init --recursive && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \
fi

@@ -1,5 +1,5 @@
commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247
build-vllm-cuda:
if [ ! -d 'vllm' ]; then \
pip install -U ninja packaging --no-cache-dir && \
@@ -13,7 +13,7 @@ install-vllm-cuda: build-vllm-cuda
build-vllm-rocm:
if [ ! -d 'vllm' ]; then \
pip install -U ninja packaging --no-cache-dir && \
git clone https://github.com/fxmarty/rocm-vllm.git vllm; \
git clone https://github.com/mht-sharma/vllm.git vllm; \
fi
cd vllm && git fetch && git checkout $(commit_rocm) && \
PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build

@@ -1,5 +1,17 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import torch

extra_cuda_cflags = []
extra_cflags = []
if torch.version.hip:
extra_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
extra_cuda_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]

extra_compile_args = {
"cxx": extra_cflags,
"nvcc": extra_cuda_cflags,
}

setup(
name="exllama_kernels",
@@ -13,6 +25,7 @@ setup(
"exllama_kernels/cuda_func/q4_matmul.cu",
"exllama_kernels/cuda_func/q4_matrix.cu",
],
extra_compile_args=extra_compile_args,
)
],
cmdclass={"build_ext": BuildExtension},

@@ -3,11 +3,13 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import torch

extra_cuda_cflags = ["-lineinfo", "-O3"]

extra_cflags = []
if torch.version.hip:
extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF"]
extra_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF", "-DLEGACY_HIPBLAS_DIRECT=ON"]

extra_compile_args = {
"cxx": extra_cflags,
"nvcc": extra_cuda_cflags,
}

@@ -1,4 +1,5 @@
from dataclasses import dataclass
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
import torch
from typing import Optional
@@ -65,5 +66,7 @@ else:
max_k: int

def clamp(self, max):
if SYSTEM == "rocm":
return self
raise NotImplementedError("Not implemented seqlen for paged")
return Seqlen(torch.clamp(self.input_lengths, max=max))

@@ -1,4 +1,5 @@
import os
from typing import Optional
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
@@ -8,16 +9,28 @@ from loguru import logger

major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512

_PARTITION_SIZE_V1V2 = 512
_PARTITION_SIZE_CUSTOM = 256

use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"


PREFILL_IN_KV_CACHE = False

use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
try:
from vllm._C import cache_ops
if use_rocm_custom_paged_attn:
from vllm._custom_C import paged_attention_custom
except ImportError as e:
log_master(
logger.info,
f"Custom Paged Attention not available. Complete error: {e}",
)
use_rocm_custom_paged_attn = False

try:
import vllm._custom_ops as ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
@@ -36,9 +49,7 @@ def reshape_and_cache(
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)


def paged_attention(
@@ -48,8 +59,9 @@ def paged_attention(
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
input_lengths: Seqlen,
seqlen: Seqlen,
max_s: int,
softcap: Optional[float] = None,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
@@ -68,11 +80,31 @@ def paged_attention(
# limitations under the License.
#

if softcap is not None:
raise RuntimeError("Paged attention doesn't support softcapping")

# value_cache => [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape

num_kv_heads = key_cache.shape[1]
gqa_ratio = num_heads // num_kv_heads
use_custom = (
use_rocm_custom_paged_attn
and (query.dtype == torch.half or query.dtype == torch.bfloat16)
and (head_size == 128 or head_size == 64)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_s <= 32768
)

if not use_custom:
_PARTITION_SIZE = _PARTITION_SIZE_V1V2
else:
_PARTITION_SIZE = _PARTITION_SIZE_CUSTOM

max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = input_lengths.input_lengths
input_lengths = seqlen.input_lengths

out = torch.empty_like(query)

@@ -81,9 +113,13 @@
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
from vllm._C import ops
import vllm._custom_ops as ops

use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
use_v1 = (
max_s <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512)
and not use_custom
)
if use_v1:
ops.paged_attention_v1(
out,
@@ -115,6 +151,7 @@ def paged_attention(
)
max_logits = torch.empty_like(exp_sums)

if not use_custom:
ops.paged_attention_v2(
out,
exp_sums,
@@ -133,6 +170,25 @@ def paged_attention(
"auto",
1.0,
)
else:
paged_attention_custom(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
)

return out


@@ -175,13 +231,14 @@ if ENGINE == "ck":

def attention(
q,
k,
v,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: float = 0.0,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
@@ -191,46 +248,57 @@ if ENGINE == "ck":
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
return flash_attn_2_cuda.varlen_fwd(
q,
k,
v,
key_cache,
value_cache,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
None,
None,
None,
None,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
window_size_left,
0,
softcap,
False,
None,
)
)[0]

elif ENGINE == "triton":
from .flash_attn_triton import triton_attention

def attention(
q,
k,
v,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: Optional[float] = None,
):
if softcap is not None:
raise NotImplementedError("softcap is only available with CK flash attn")

out = torch.empty_like(q)

# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
output, _ = triton_attention(
q,
k,
v,
key_cache,
value_cache,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_k,
causal,
softmax_scale,
)

@@ -1,12 +1,21 @@
import torch
from text_generation_server.utils.import_utils import SYSTEM
from torch.nn import functional as F
import os

if SYSTEM == "rocm":
ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in (
"true",
"1",
)

if ROCM_USE_SKINNY_GEMM:
try:
from vllm import _custom_C
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
raise ImportError(
f"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error: {e}"
)


class FastLinear(torch.nn.Module):
@@ -48,6 +57,14 @@ class FastLinearROCm(torch.nn.Module):
else:
self.bias = None

self.cu_count = torch.cuda.get_device_properties(
device="cuda"
).multi_processor_count
self.use_skinny_gemm = (
ROCM_USE_SKINNY_GEMM
and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName
)

@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_tensor(f"{prefix}.weight")
@@ -61,7 +78,11 @@ class FastLinearROCm(torch.nn.Module):
weight = self.weight
bias = self.bias

if SYSTEM == "rocm" and inp.numel() // inp.shape[-1] == 1:
if (
self.use_skinny_gemm
and inp.dtype == torch.float16
and inp.shape[-1] % 8 == 0
):
batched = False
inp_shape = inp.shape

@@ -69,13 +90,16 @@ class FastLinearROCm(torch.nn.Module):
inp = inp.view(-1, inp_shape[-1])
batched = True

m, k = weight.shape[0], inp_shape[1]
m, n, k = weight.shape[0], inp_shape[0], inp_shape[1]
if m > 8 and n <= 4:
out = torch.empty(
inp_shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
)
_custom_C.wvSpltK(weight, inp, out, n, self.cu_count)
elif m % 4 == 0 and n == 1 and k <= 8192:
out = torch.empty(
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
)
if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
_custom_C.LLMM1(weight, inp, out, 8)
elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
_custom_C.LLMM1(weight, inp, out, 4)
else:
out = F.linear(inp, weight)

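Editor's note: the forward path above picks between two custom ROCm kernels and the stock GEMM based on the problem shape. A standalone sketch of that dispatch (the helper name and string return values are ours; it assumes fp16 input, `inp.shape[-1] % 8 == 0`, and skinny GEMM enabled, as the code above requires):

```python
def rocm_gemm_dispatch(m: int, n: int, k: int) -> str:
    # m: weight rows (output features), n: rows of the flattened input, k: input features.
    if m > 8 and n <= 4:
        return "_custom_C.wvSpltK"
    if m % 4 == 0 and n == 1 and k <= 8192:
        if (k == 8192 and m in (1280, 7168)) or (k == 3584 and m == 8192):
            return "_custom_C.LLMM1 (last arg 8)"
        if k % 8 == 0:
            return "_custom_C.LLMM1 (last arg 4)"
    return "torch.nn.functional.linear"  # default PyTorch GEMM
```
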
@@ -19,7 +19,10 @@ from text_generation_server.utils.weights import (
Weights,
)

if SYSTEM != "ipex":
if SYSTEM == "rocm":
from .fused_moe_rocm import grouped_topk
from vllm.model_executor.layers.fused_moe import fused_topk
elif SYSTEM != "ipex":
from moe_kernels.fused_moe import fused_topk, grouped_topk


@@ -0,0 +1,52 @@
# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import torch
import torch.distributed


# TODO: Remove the functions once moe_kernel are built for ROCM
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
scores = torch.softmax(gating_output, dim=-1)
num_token = scores.shape[0]
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
1
] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
.reshape(num_token, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)

if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

return topk_weights, topk_ids

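Editor's note: a quick usage sketch of the `grouped_topk` fallback added above, assuming the function from that file is in scope; the shapes and values are illustrative assumptions, not part of the commit (2 tokens, 8 experts in 4 groups, top-2 routing):

```python
import torch

hidden_states = torch.randn(2, 16)   # not used by the routing math itself
gating_output = torch.randn(2, 8)    # router logits: [num_tokens, num_experts]

topk_weights, topk_ids = grouped_topk(
    hidden_states,
    gating_output,
    topk=2,
    renormalize=True,
    num_expert_group=4,
    topk_group=2,
)
print(topk_weights.shape, topk_ids.shape)  # torch.Size([2, 2]) torch.Size([2, 2])
```
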
@@ -6,7 +6,9 @@ import torch.nn as nn
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import UnquantizedWeight, Weights

if SYSTEM != "ipex":
if SYSTEM == "rocm":
from vllm.model_executor.layers.fused_moe import fused_moe
elif SYSTEM != "ipex":
from moe_kernels.fused_moe import fused_moe


@@ -52,6 +54,17 @@ class UnquantizedSparseMoELayer(nn.Module):
)

def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
if SYSTEM == "rocm":
return fused_moe(
x,
self.gate_up_proj,
self.down_proj,
gating_output,
self.topk,
renormalize=self.renormalize,
inplace=True,
)

return fused_moe(
x,
w1=self.gate_up_proj,

@@ -390,6 +390,7 @@ class DeepseekV2MLP(nn.Module):
if (
SYSTEM == "rocm"
and self.hidden_act == "silu"
and hidden_states.dtype == torch.float16
and hidden_states.shape[0] == 1
and not self.quantize
):

@@ -316,12 +316,17 @@ class LlamaMLP(nn.Module):
# TODO: This is a hotfix to be removed & properly refactored.
self.quantize = config.quantize

self.hidden_size = config.hidden_size

def forward(self, hidden_states, adapter_data):
if (
SYSTEM == "rocm"
and self.hidden_act == "silu"
and hidden_states.dtype == torch.float16
and hidden_states.shape[0] == 1
and not self.quantize
and self.hidden_size
!= 16384  # TODO: Temporary workaround for `LLMM_Silu` kernel not working with LLama3.1 405B; needs refactoring once fixed.
):
out = torch.empty(
hidden_states.shape[0],

@@ -303,6 +303,7 @@ class MistralMLP(nn.Module):
if (
SYSTEM == "rocm"
and self.hidden_act == "silu"
and hidden_states.dtype == torch.float16
and hidden_states.shape[0] == 1
and not self.quantize
):

@@ -1125,12 +1125,12 @@ class FlashCausalLM(Model):
else:
self.kv_cache = [
(
torch.empty(
torch.zeros(
(num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
dtype=dtype,
device=device,
),
torch.empty(
torch.zeros(
(num_blocks, num_heads, head_size, BLOCK_SIZE),
dtype=dtype,
device=device,
@@ -1320,8 +1320,7 @@ class FlashCausalLM(Model):
elif CUDA_GRAPHS is not None:
tuning_sequences = CUDA_GRAPHS
else:
# For seqlen = 1, we dispatch to LLMM1 kernel.
tuning_sequences = [2, 3, 4, 5, 6, 7]
tuning_sequences = [1, 2, 3, 4, 5, 6, 7]

tunableop_filepath = os.path.join(
HUGGINGFACE_HUB_CACHE,
@@ -1330,7 +1329,11 @@ class FlashCausalLM(Model):

log_master(
logger.info,
f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
f"PyTorch TunableOp is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
)

torch.cuda.tunable.set_filename(
tunableop_filepath, insert_device_ordinal=False
)

if os.path.isfile(tunableop_filepath):
@@ -1346,6 +1349,7 @@ class FlashCausalLM(Model):
log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
self.tunableop_warmup(seqlen)
torch.cuda.tunable.write_file(tunableop_filepath)
if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
torch.cuda.tunable.tuning_enable(False)
else:
log_master(

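Editor's note: a condensed sketch of the TunableOp warmup flow in the hunks above, using the `torch.cuda.tunable` API as it appears in the diff; the wrapper function is ours, and the `read_file` call is an assumption about the branch truncated after `os.path.isfile`:

```python
import os
import torch

def warmup_tunableop(tuning_sequences, tunableop_filepath, run_warmup):
    # Persist tuned GEMMs to a file, warm up the target sequence lengths, then
    # freeze tuning unless PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=1 keeps it on.
    torch.cuda.tunable.set_filename(tunableop_filepath, insert_device_ordinal=False)
    if os.path.isfile(tunableop_filepath):
        torch.cuda.tunable.read_file(tunableop_filepath)  # assumption: reuse prior results
    for seqlen in tuning_sequences:
        run_warmup(seqlen)  # e.g. self.tunableop_warmup(seqlen) in FlashCausalLM
    torch.cuda.tunable.write_file(tunableop_filepath)
    if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
        torch.cuda.tunable.tuning_enable(False)
```
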
@@ -1382,6 +1386,7 @@ class FlashCausalLM(Model):
cu_seqlen_prefill = torch.tensor(
[0, seqlen], device=self.device, dtype=torch.int32
)
max_s = seqlen
seqlen = Seqlen(
input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor,
@@ -1399,7 +1404,7 @@ class FlashCausalLM(Model):
block_tables=None,
seqlen=seqlen,
slots=slots,
max_s=seqlen,
max_s=max_s,
lm_head_indices=None,
prefill_cache_indices=None,
)