diff --git a/Dockerfile_amd b/Dockerfile_amd index a79aae48..0b059f8c 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -41,7 +41,7 @@ COPY launcher launcher RUN cargo build --profile release-opt # Text Generation Inference base image for RoCm -FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update AS base +FROM rocm/dev-ubuntu-22.04:6.2 AS base RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ build-essential \ @@ -50,33 +50,34 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins curl \ git \ make \ + libmsgpack-dev \ libssl-dev \ + llvm-dev \ g++ \ # Needed to build VLLM & flash. rocthrust-dev \ hipsparse-dev \ hipblas-dev \ - hipblaslt-dev \ + hipcub-dev \ rocblas-dev \ hiprand-dev \ + hipfft-dev \ rocrand-dev \ miopen-hip-dev \ - hipfft-dev \ - hipcub-dev \ hipsolver-dev \ rccl-dev \ cmake \ - python3.11-dev && \ + python3.11-venv && \ rm -rf /var/lib/apt/lists/* # Keep in sync with `server/pyproject.toml ARG MAMBA_VERSION=23.1.0-1 -ARG PYTORCH_VERSION='2.3.0' -ARG ROCM_VERSION='6.0.2' ARG PYTHON_VERSION='3.11.10' # Automatically set by buildx ARG TARGETPLATFORM -ENV PATH /opt/conda/bin:$PATH +ENV PATH=/opt/conda/bin:$PATH + +ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942" # TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda. # Install mamba @@ -100,41 +101,132 @@ RUN case ${TARGETPLATFORM} in \ /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \ esac && \ /opt/conda/bin/conda clean -ya + # Install flash-attention, torch dependencies -RUN pip install numpy einops ninja --no-cache-dir +RUN python3 -m pip install --upgrade pip && pip install numpy einops ninja joblib msgpack cmake --no-cache-dir && rm -rf /var/lib/apt/lists/* -RUN pip uninstall -y triton && \ - git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \ - cd triton/python && \ - pip install . 
+RUN conda install mkl=2021 +ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/ -RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir -ARG _GLIBCXX_USE_CXX11_ABI="1" -ARG CMAKE_PREFIX_PATH="/opt/conda" +ARG COMMON_WORKDIR=/ +WORKDIR ${COMMON_WORKDIR} + + +# Install HIPBLASLt +FROM base AS build_hipblaslt +ARG HIPBLASLT_BRANCH="e6da924" +RUN git clone https://github.com/ROCm/hipBLASLt.git \ + && cd hipBLASLt \ + && git checkout ${HIPBLASLT_BRANCH} \ + && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} --legacy_hipblas_direct \ + && cd build/release \ + && make package + +FROM scratch AS export_hipblaslt +ARG COMMON_WORKDIR +COPY --from=build_hipblaslt ${COMMON_WORKDIR}/hipBLASLt/build/release/*.deb / + +# RCCL build stages +FROM base AS build_rccl +ARG RCCL_BRANCH="rocm-6.2.0" +RUN git clone https://github.com/ROCm/rccl \ + && cd rccl \ + && git checkout ${RCCL_BRANCH} \ + && ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH} +FROM scratch AS export_rccl +ARG COMMON_WORKDIR +COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb / + +# Triton build stages +FROM base AS build_triton +ARG TRITON_BRANCH="e192dba" +ARG TRITON_REPO="https://github.com/triton-lang/triton.git" +RUN python3 -m pip install ninja cmake wheel pybind11 && git clone ${TRITON_REPO} \ + && cd triton \ + && git checkout ${TRITON_BRANCH} \ + && cd python \ + && python3 setup.py bdist_wheel --dist-dir=dist +FROM scratch AS export_triton +ARG COMMON_WORKDIR +COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl / + +# # AMD-SMI build stages +FROM base AS build_amdsmi +RUN cd /opt/rocm/share/amd_smi \ + && pip wheel . 
--wheel-dir=dist +FROM scratch AS export_amdsmi +COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl / + + +FROM base as build_pytorch + +RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \ + if ls /install/*.deb; then \ + dpkg -i /install/*.deb \ + && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \ + && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \ + fi + +ARG BUILD_ENVIRONMENT=pytorch-linux-jammy-rocm6.2-py3.11 ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942" -ARG BUILD_CAFFE2="0" \ - BUILD_CAFFE2_OPS="0" \ - USE_CUDA="0" \ - USE_ROCM="1" \ - BUILD_TEST="0" \ - USE_FBGEMM="0" \ - USE_NNPACK="0" \ - USE_QNNPACK="0" \ - USE_XNNPACK="0" \ - USE_FLASH_ATTENTION="1" \ - USE_MEM_EFF_ATTENTION="0" -RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install +# A commit to fix the output scaling factor issue in _scaled_mm +# Not yet in 2.5.0-rc1 +ARG PYTORCH_BRANCH="cedc116" +ARG PYTORCH_VISION_BRANCH="v0.19.1" +ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git" -# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm -ENV HIP_FORCE_DEV_KERNARG=1 +RUN git clone ${PYTORCH_REPO} pytorch \ + && cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \ + && pip install -r requirements.txt --no-cache-dir \ + && python tools/amd_build/build_amd.py \ + && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist +FROM scratch as export_pytorch +ARG COMMON_WORKDIR +COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl / -# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK. -# However, Triton requires a tunning for each prompt length, which is prohibitive. -ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0 +FROM base AS install_deps -FROM base AS kernel-builder +ARG COMMON_WORKDIR + +# Install hipblaslt +RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \ + if ls /install/*.deb; then \ + dpkg -i /install/*.deb \ + && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \ + && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \ + fi + +RUN --mount=type=bind,from=export_rccl,src=/,target=/install \ + if ls /install/*.deb; then \ + dpkg -i /install/*.deb \ + # RCCL needs to be installed twice + && dpkg -i /install/*.deb \ + && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \ + && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status; \ + fi + +RUN --mount=type=bind,from=export_triton,src=/,target=/install \ + if ls /install/*.whl; then \ + # Preemptively uninstall to prevent pip same-version no-installs + pip uninstall -y triton \ + && pip install /install/*.whl; \ + fi + +RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \ + # Preemptively uninstall to prevent pip same-version no-installs + pip uninstall -y amdsmi \ + && pip install /install/*.whl; + +RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \ + if ls /install/*.whl; then \ + # Preemptively uninstall to prevent pip same-version no-installs + pip uninstall -y torch torchvision \ + && pip install /install/*.whl; \ + fi + +FROM install_deps AS kernel-builder # # Build vllm kernels FROM kernel-builder AS vllm-builder @@ -174,7 +266,7 @@ COPY server/exllamav2_kernels/ . 
RUN python setup.py build

-FROM base AS base-copy
+FROM install_deps AS base-copy

# Text Generation Inference base env
ENV HF_HOME=/data \
@@ -224,6 +316,19 @@ ENTRYPOINT ["./entrypoint.sh"]

# Final image
FROM base-copy

+# Set as recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
+ENV HIP_FORCE_DEV_KERNARG=1
+
+# On MI250 and MI300, performance of flash with Triton FA is slightly better than CK.
+# However, Triton requires tuning for each prompt length, which is prohibitive.
+ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
+ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
+ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
+ENV VLLM_MOE_PADDING=0
+ENV ATTENTION=paged
+ENV USE_PREFIX_CACHING=0
+ENV ROCM_USE_SKINNY_GEMM=1
+
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

diff --git a/docs/source/installation_amd.md b/docs/source/installation_amd.md
index 544beffc..6806bac9 100644
--- a/docs/source/installation_amd.md
+++ b/docs/source/installation_amd.md
@@ -31,6 +31,12 @@ Two implementations of Flash Attention are available for ROCm, the first is [ROC

By default, the Composable Kernel implementation is used. However, the Triton implementation has slightly lower latency on MI250 and MI300, but requires a warmup which can be prohibitive as it needs to be done again for each new prompt length. If needed, FA Triton implementation can be enabled with `--env ROCM_USE_FLASH_ATTN_V2_TRITON="1"` when launching TGI's docker container.

+## Custom PagedAttention
+
+For better performance on ROCm, a custom Paged Attention kernel is available and is enabled by default. To disable it and fall back to the PagedAttention v2 kernel, set the environment variable `ROCM_USE_CUSTOM_PAGED_ATTN=0`.
+
+The custom kernel supports bf16 and fp16 data types, block size of 16, head size of 128, a maximum context length of 16k, and GQA ratios between 1 and 16. For other configurations, we use the PagedAttention v2 kernel.
+
## Unsupported features

The following features are currently not supported in the ROCm version of TGI, and the support may be extended in the future:
diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2
index dbddd0f4..a9cdf782 100644
--- a/server/Makefile-flash-att-v2
+++ b/server/Makefile-flash-att-v2
@@ -1,5 +1,5 @@
flash_att_v2_commit_cuda := v2.6.1
-flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
+flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4

build-flash-attention-v2-cuda:
	pip install -U packaging wheel
@@ -11,7 +11,7 @@ install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
build-flash-attention-v2-rocm:
	if [ ! -d 'flash-attention-v2' ]; then \
		pip install -U packaging ninja --no-cache-dir && \
-		git clone https://github.com/ROCm/flash-attention.git flash-attention-v2 && \
+		git clone https://github.com/mht-sharma/flash-attention.git flash-attention-v2 && \
		cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \
		git submodule update --init --recursive && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \
	fi
diff --git a/server/Makefile-vllm b/server/Makefile-vllm
index f1f80529..18dcc4a0 100644
--- a/server/Makefile-vllm
+++ b/server/Makefile-vllm
@@ -1,5 +1,5 @@
commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
-commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
+commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247

build-vllm-cuda:
	if [ !
-d 'vllm' ]; then \ pip install -U ninja packaging --no-cache-dir && \ @@ -13,7 +13,7 @@ install-vllm-cuda: build-vllm-cuda build-vllm-rocm: if [ ! -d 'vllm' ]; then \ pip install -U ninja packaging --no-cache-dir && \ - git clone https://github.com/fxmarty/rocm-vllm.git vllm; \ + git clone https://github.com/mht-sharma/vllm.git vllm; \ fi cd vllm && git fetch && git checkout $(commit_rocm) && \ PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build diff --git a/server/exllama_kernels/setup.py b/server/exllama_kernels/setup.py index 987d181e..cc307bf0 100644 --- a/server/exllama_kernels/setup.py +++ b/server/exllama_kernels/setup.py @@ -1,5 +1,17 @@ from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import torch + +extra_cuda_cflags = [] +extra_cflags = [] +if torch.version.hip: + extra_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"] + extra_cuda_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"] + +extra_compile_args = { + "cxx": extra_cflags, + "nvcc": extra_cuda_cflags, +} setup( name="exllama_kernels", @@ -13,6 +25,7 @@ setup( "exllama_kernels/cuda_func/q4_matmul.cu", "exllama_kernels/cuda_func/q4_matrix.cu", ], + extra_compile_args=extra_compile_args, ) ], cmdclass={"build_ext": BuildExtension}, diff --git a/server/exllamav2_kernels/setup.py b/server/exllamav2_kernels/setup.py index 4a16b546..56ffa973 100644 --- a/server/exllamav2_kernels/setup.py +++ b/server/exllamav2_kernels/setup.py @@ -3,11 +3,13 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension import torch extra_cuda_cflags = ["-lineinfo", "-O3"] - +extra_cflags = [] if torch.version.hip: - extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF"] + extra_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"] + extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF", "-DLEGACY_HIPBLAS_DIRECT=ON"] extra_compile_args = { + "cxx": extra_cflags, "nvcc": extra_cuda_cflags, } diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py index 855f4dfc..d6e512c0 100644 --- a/server/text_generation_server/layers/attention/common.py +++ b/server/text_generation_server/layers/attention/common.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models.globals import ATTENTION import torch from typing import Optional @@ -65,5 +66,7 @@ else: max_k: int def clamp(self, max): + if SYSTEM == "rocm": + return self raise NotImplementedError("Not implemented seqlen for paged") return Seqlen(torch.clamp(self.input_lengths, max=max)) diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 9f24ac98..646a763d 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -1,4 +1,5 @@ import os +from typing import Optional import torch from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models.globals import ATTENTION @@ -8,16 +9,28 @@ from loguru import logger major, minor = torch.cuda.get_device_capability() is_sm75 = major == 7 and minor == 5 -_PARTITION_SIZE = 512 + +_PARTITION_SIZE_V1V2 = 512 +_PARTITION_SIZE_CUSTOM = 256 use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"} ENGINE = "triton" if use_triton else "ck" - PREFILL_IN_KV_CACHE = False +use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0" try: - from vllm._C import cache_ops + if 
use_rocm_custom_paged_attn: + from vllm._custom_C import paged_attention_custom +except ImportError as e: + log_master( + logger.info, + f"Custom Paged Attention not available. Complete error: {e}", + ) + use_rocm_custom_paged_attn = False + +try: + import vllm._custom_ops as ops except Exception as e: raise ImportError( f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" @@ -36,9 +49,7 @@ def reshape_and_cache( key_cache.view(-1, shape[-2], shape[-1])[slots] = key value_cache.view(-1, shape[-2], shape[-1])[slots] = value else: - cache_ops.reshape_and_cache( - key, value, key_cache, value_cache, slots, "auto", 1.0 - ) + ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0) def paged_attention( @@ -48,8 +59,9 @@ def paged_attention( kv_head_mapping: torch.Tensor, softmax_scale: float, block_tables: torch.Tensor, - input_lengths: Seqlen, + seqlen: Seqlen, max_s: int, + softcap: Optional[float] = None, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py # Copyright 2023 The vLLM team. All rights @@ -68,11 +80,31 @@ def paged_attention( # limitations under the License. # + if softcap is not None: + raise RuntimeError("Paged attention doesn't support softcapping") + # value_cache => [num_blocks, num_heads, head_size, block_size] block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape + + num_kv_heads = key_cache.shape[1] + gqa_ratio = num_heads // num_kv_heads + use_custom = ( + use_rocm_custom_paged_attn + and (query.dtype == torch.half or query.dtype == torch.bfloat16) + and (head_size == 128 or head_size == 64) + and (block_size == 16 or block_size == 32) + and (gqa_ratio >= 1 and gqa_ratio <= 16) + and max_s <= 32768 + ) + + if not use_custom: + _PARTITION_SIZE = _PARTITION_SIZE_V1V2 + else: + _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM + max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE - input_lengths = input_lengths.input_lengths + input_lengths = seqlen.input_lengths out = torch.empty_like(query) @@ -81,9 +113,13 @@ def paged_attention( # V1 to avoid the overhead of reduction. Also, if the number of # sequences or heads is large, we use V1 since there is enough work # to parallelize. 
- from vllm._C import ops + import vllm._custom_ops as ops - use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) + use_v1 = ( + max_s <= 8192 + and (max_num_partitions == 1 or num_seqs * num_heads > 512) + and not use_custom + ) if use_v1: ops.paged_attention_v1( out, @@ -115,24 +151,44 @@ def paged_attention( ) max_logits = torch.empty_like(exp_sums) - ops.paged_attention_v2( - out, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, - ) + if not use_custom: + ops.paged_attention_v2( + out, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) + else: + paged_attention_custom( + out, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + ) + return out @@ -175,13 +231,14 @@ if ENGINE == "ck": def attention( q, - k, - v, - cu_seqlens, - max_s, - softmax_scale, - window_size_left=-1, - causal=True, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + seqlen: Seqlen, + block_tables: torch.Tensor, + softmax_scale: float, + window_size_left: int = -1, + causal: bool = True, + softcap: float = 0.0, ): if window_size_left <= 0 and window_size_left != -1: raise ValueError("`window_size_left` must be > 0 or -1") @@ -191,46 +248,57 @@ if ENGINE == "ck": # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. return flash_attn_2_cuda.varlen_fwd( q, - k, - v, + key_cache, + value_cache, out, - cu_seqlens, - cu_seqlens, - max_s, - max_s, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_q, + None, + None, + None, + None, + seqlen.max_q, + seqlen.max_k, 0.0, softmax_scale, False, causal, + window_size_left, + 0, + softcap, False, None, - ) + )[0] elif ENGINE == "triton": from .flash_attn_triton import triton_attention def attention( q, - k, - v, - cu_seqlens, - max_s, - softmax_scale, - window_size_left=-1, - causal=True, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + seqlen: Seqlen, + block_tables: torch.Tensor, + softmax_scale: float, + window_size_left: int = -1, + causal: bool = True, + softcap: Optional[float] = None, ): + if softcap is not None: + raise NotImplementedError("softcap is only available with CK flash attn") + out = torch.empty_like(q) # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. output, _ = triton_attention( q, - k, - v, + key_cache, + value_cache, out, - cu_seqlens, - cu_seqlens, - max_s, - max_s, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_q, + seqlen.max_q, + seqlen.max_k, causal, softmax_scale, ) diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index 12d7f83a..08306d57 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -1,12 +1,21 @@ import torch from text_generation_server.utils.import_utils import SYSTEM from torch.nn import functional as F +import os if SYSTEM == "rocm": - try: - from vllm import _custom_C - except Exception as e: - raise ImportError(f"Could not load `vllm._custom_C`. 
Full error: {e}") + ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in ( + "true", + "1", + ) + + if ROCM_USE_SKINNY_GEMM: + try: + from vllm import _custom_C + except Exception as e: + raise ImportError( + f"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error: {e}" + ) class FastLinear(torch.nn.Module): @@ -48,6 +57,14 @@ class FastLinearROCm(torch.nn.Module): else: self.bias = None + self.cu_count = torch.cuda.get_device_properties( + device="cuda" + ).multi_processor_count + self.use_skinny_gemm = ( + ROCM_USE_SKINNY_GEMM + and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName + ) + @classmethod def load(cls, config, prefix: str, weights, bias: bool): weight = weights.get_tensor(f"{prefix}.weight") @@ -61,7 +78,11 @@ class FastLinearROCm(torch.nn.Module): weight = self.weight bias = self.bias - if SYSTEM == "rocm" and inp.numel() // inp.shape[-1] == 1: + if ( + self.use_skinny_gemm + and inp.dtype == torch.float16 + and inp.shape[-1] % 8 == 0 + ): batched = False inp_shape = inp.shape @@ -69,13 +90,16 @@ class FastLinearROCm(torch.nn.Module): inp = inp.view(-1, inp_shape[-1]) batched = True - m, k = weight.shape[0], inp_shape[1] - out = torch.empty( - inp_shape[0], weight.shape[0], dtype=inp.dtype, device="cuda" - ) - if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192): - _custom_C.LLMM1(weight, inp, out, 8) - elif k <= 8192 and k % 8 == 0 and m % 4 == 0: + m, n, k = weight.shape[0], inp_shape[0], inp_shape[1] + if m > 8 and n <= 4: + out = torch.empty( + inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device + ) + _custom_C.wvSpltK(weight, inp, out, n, self.cu_count) + elif m % 4 == 0 and n == 1 and k <= 8192: + out = torch.empty( + inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device + ) _custom_C.LLMM1(weight, inp, out, 4) else: out = F.linear(inp, weight) diff --git a/server/text_generation_server/layers/moe/__init__.py b/server/text_generation_server/layers/moe/__init__.py index 3171af90..7e8ac2c8 100644 --- a/server/text_generation_server/layers/moe/__init__.py +++ b/server/text_generation_server/layers/moe/__init__.py @@ -19,7 +19,10 @@ from text_generation_server.utils.weights import ( Weights, ) -if SYSTEM != "ipex": +if SYSTEM == "rocm": + from .fused_moe_rocm import grouped_topk + from vllm.model_executor.layers.fused_moe import fused_topk +elif SYSTEM != "ipex": from moe_kernels.fused_moe import fused_topk, grouped_topk diff --git a/server/text_generation_server/layers/moe/fused_moe_rocm.py b/server/text_generation_server/layers/moe/fused_moe_rocm.py new file mode 100644 index 00000000..68accb99 --- /dev/null +++ b/server/text_generation_server/layers/moe/fused_moe_rocm.py @@ -0,0 +1,52 @@ +# coding=utf-8 +# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Tuple + +import torch +import torch.distributed + + +# TODO: Remove the functions once moe_kernel are built for ROCM +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor]: + scores = torch.softmax(gating_output, dim=-1) + num_token = scores.shape[0] + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) + .reshape(num_token, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights, topk_ids diff --git a/server/text_generation_server/layers/moe/unquantized.py b/server/text_generation_server/layers/moe/unquantized.py index 8f1d9b3f..d9d62c0e 100644 --- a/server/text_generation_server/layers/moe/unquantized.py +++ b/server/text_generation_server/layers/moe/unquantized.py @@ -6,7 +6,9 @@ import torch.nn as nn from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.weights import UnquantizedWeight, Weights -if SYSTEM != "ipex": +if SYSTEM == "rocm": + from vllm.model_executor.layers.fused_moe import fused_moe +elif SYSTEM != "ipex": from moe_kernels.fused_moe import fused_moe @@ -52,6 +54,17 @@ class UnquantizedSparseMoELayer(nn.Module): ) def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: + if SYSTEM == "rocm": + return fused_moe( + x, + self.gate_up_proj, + self.down_proj, + gating_output, + self.topk, + renormalize=self.renormalize, + inplace=True, + ) + return fused_moe( x, w1=self.gate_up_proj, diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index ac191ec3..88c2cf80 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -390,6 +390,7 @@ class DeepseekV2MLP(nn.Module): if ( SYSTEM == "rocm" and self.hidden_act == "silu" + and hidden_states.dtype == torch.float16 and hidden_states.shape[0] == 1 and not self.quantize ): diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 758e39aa..df48c6f7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -316,12 +316,17 @@ class LlamaMLP(nn.Module): # TODO: This is a hotfix to be removed & properly refactored. 
self.quantize = config.quantize + self.hidden_size = config.hidden_size + def forward(self, hidden_states, adapter_data): if ( SYSTEM == "rocm" and self.hidden_act == "silu" + and hidden_states.dtype == torch.float16 and hidden_states.shape[0] == 1 and not self.quantize + and self.hidden_size + != 16384 # TODO: Temporary workaround for `LLMM_Silu` kernel not working with LLama3.1 405B; needs refactoring once fixed. ): out = torch.empty( hidden_states.shape[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 3e16d371..341a2352 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -303,6 +303,7 @@ class MistralMLP(nn.Module): if ( SYSTEM == "rocm" and self.hidden_act == "silu" + and hidden_states.dtype == torch.float16 and hidden_states.shape[0] == 1 and not self.quantize ): diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 57582ebc..bc9d44a0 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1125,12 +1125,12 @@ class FlashCausalLM(Model): else: self.kv_cache = [ ( - torch.empty( + torch.zeros( (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x), dtype=dtype, device=device, ), - torch.empty( + torch.zeros( (num_blocks, num_heads, head_size, BLOCK_SIZE), dtype=dtype, device=device, @@ -1320,8 +1320,7 @@ class FlashCausalLM(Model): elif CUDA_GRAPHS is not None: tuning_sequences = CUDA_GRAPHS else: - # For seqlen = 1, we dispatch to LLMM1 kernel. - tuning_sequences = [2, 3, 4, 5, 6, 7] + tuning_sequences = [1, 2, 3, 4, 5, 6, 7] tunableop_filepath = os.path.join( HUGGINGFACE_HUB_CACHE, @@ -1330,7 +1329,11 @@ class FlashCausalLM(Model): log_master( logger.info, - f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.", + f"PyTorch TunableOp is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. 
To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.", + ) + + torch.cuda.tunable.set_filename( + tunableop_filepath, insert_device_ordinal=False ) if os.path.isfile(tunableop_filepath): @@ -1346,7 +1349,8 @@ class FlashCausalLM(Model): log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}") self.tunableop_warmup(seqlen) torch.cuda.tunable.write_file(tunableop_filepath) - torch.cuda.tunable.tuning_enable(False) + if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1": + torch.cuda.tunable.tuning_enable(False) else: log_master( logger.info, @@ -1382,6 +1386,7 @@ class FlashCausalLM(Model): cu_seqlen_prefill = torch.tensor( [0, seqlen], device=self.device, dtype=torch.int32 ) + max_s = seqlen seqlen = Seqlen( input_lengths=input_lengths, prefix_lengths=prefix_lens_tensor, @@ -1399,7 +1404,7 @@ class FlashCausalLM(Model): block_tables=None, seqlen=seqlen, slots=slots, - max_s=seqlen, + max_s=max_s, lm_head_indices=None, prefill_cache_indices=None, )
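
For reference, the dispatch between the custom ROCm paged-attention kernel and PagedAttention v2 added in `server/text_generation_server/layers/attention/rocm.py` can be summarized in one predicate. The helper below is a minimal sketch for illustration only (`custom_paged_attn_eligible` is not a function in this PR); it mirrors the gating conditions and the `ROCM_USE_CUSTOM_PAGED_ATTN` toggle:

```python
# Illustrative only: mirrors the eligibility check in
# server/text_generation_server/layers/attention/rocm.py (not part of the PR).
import os

import torch


def custom_paged_attn_eligible(
    dtype: torch.dtype,
    head_size: int,
    block_size: int,
    num_heads: int,
    num_kv_heads: int,
    max_seq_len: int,
) -> bool:
    # ROCM_USE_CUSTOM_PAGED_ATTN=0 falls back to PagedAttention v2 unconditionally.
    if os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") == "0":
        return False
    gqa_ratio = num_heads // num_kv_heads
    return (
        dtype in (torch.float16, torch.bfloat16)
        and head_size in (64, 128)
        and block_size in (16, 32)
        and 1 <= gqa_ratio <= 16
        and max_seq_len <= 32768
    )


# Example: a Llama-style config (32 query heads, 8 KV heads, head size 128, block size 16).
print(custom_paged_attn_eligible(torch.float16, 128, 16, 32, 8, 4096))  # True
```

Note that this code gate is slightly wider than the documentation added in `docs/source/installation_amd.md` (head size 64 or 128, block size 16 or 32, contexts up to 32k).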
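
The reworked `FastLinearROCm.forward` now chooses between `wvSpltK`, `LLMM1`, and `torch.nn.functional.linear` based on the GEMM shape. A rough sketch of that decision, assuming fp16 inputs; the helper name `pick_rocm_gemm_kernel` is hypothetical and for illustration only:

```python
# Sketch of the shape-based dispatch in server/text_generation_server/layers/linear.py.
# In the PR this path is additionally gated on ROCM_USE_SKINNY_GEMM and skipped on
# gfx1* (RDNA) devices; those checks are omitted here for brevity.
import torch


def pick_rocm_gemm_kernel(weight: torch.Tensor, inp: torch.Tensor) -> str:
    # weight: [out_features, in_features], inp: [num_tokens, in_features]
    if inp.dtype != torch.float16 or inp.shape[-1] % 8 != 0:
        return "F.linear"
    m, n, k = weight.shape[0], inp.shape[0], inp.shape[1]
    if m > 8 and n <= 4:
        return "_custom_C.wvSpltK"  # skinny GEMM path, up to 4 tokens
    if m % 4 == 0 and n == 1 and k <= 8192:
        return "_custom_C.LLMM1"  # single-token path
    return "F.linear"


w = torch.empty(4096, 4096, dtype=torch.float16)
x = torch.empty(1, 4096, dtype=torch.float16)
print(pick_rocm_gemm_kernel(w, x))  # _custom_C.wvSpltK
```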
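
The new `server/text_generation_server/layers/moe/fused_moe_rocm.py` ships a pure-PyTorch `grouped_topk` as a stop-gap until `moe_kernels` is built for ROCm. A small usage sketch; the shapes below are made up for illustration:

```python
# Usage sketch for the pure-PyTorch grouped_topk fallback (shapes are illustrative).
import torch

from text_generation_server.layers.moe.fused_moe_rocm import grouped_topk

num_tokens, num_experts, num_expert_group, topk, topk_group = 4, 64, 8, 6, 3
hidden_states = torch.randn(num_tokens, 4096)  # unused by the fallback, kept for API parity
gating_output = torch.randn(num_tokens, num_experts)

topk_weights, topk_ids = grouped_topk(
    hidden_states,
    gating_output,
    topk,
    renormalize=True,
    num_expert_group=num_expert_group,
    topk_group=topk_group,
)
print(topk_weights.shape, topk_ids.shape)  # torch.Size([4, 6]) torch.Size([4, 6])
# With renormalize=True the selected expert weights sum to 1 for every token.
assert torch.allclose(topk_weights.sum(dim=-1), torch.ones(num_tokens))
```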
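
The TunableOp changes in `flash_causal_lm.py` are spread over two hunks; below is a condensed sketch of the resulting warmup flow. The function name `tunableop_warmup_flow` and the `read_file` handling of an existing tuning file are assumptions for illustration; the PR itself only shows the `set_filename`, `write_file`, and `tuning_enable` calls:

```python
# Condensed sketch of the TunableOp warmup flow after this PR (names and structure
# simplified; reuse of an existing tuning file via read_file is an assumption).
import os

import torch


def tunableop_warmup_flow(tunableop_filepath: str, tuning_sequences, warmup_fn):
    # Record tunings into an explicit per-model file instead of the default name.
    torch.cuda.tunable.set_filename(tunableop_filepath, insert_device_ordinal=False)
    if os.path.isfile(tunableop_filepath):
        # Assumption: previously recorded GEMM tunings are loaded and reused.
        torch.cuda.tunable.read_file(tunableop_filepath)
    for seqlen in tuning_sequences:  # now includes seqlen 1
        warmup_fn(seqlen)
    torch.cuda.tunable.write_file(tunableop_filepath)
    # New toggle: keep tuning enabled after warmup only if explicitly requested.
    if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
        torch.cuda.tunable.tuning_enable(False)
```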