From b2b5df0e946cfcbb7c679d94c0642bc7a6ad8f5e Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Mon, 27 Nov 2023 14:08:12 +0100 Subject: [PATCH] Add RoCm support (#1243) This PR adds support for AMD Instinct MI210 & MI250 GPUs, with paged attention and FAv2 support. Remaining items to discuss, on top of possible others: * Should we have a `ghcr.io/huggingface/text-generation-inference:1.1.0+rocm` hosted image, or is it too early? * Should we set up a CI on MI210/MI250? I don't have access to the runners of TGI though. * Are we comfortable with those changes being directly in TGI, or do we need a fork? --------- Co-authored-by: Felix Marty Co-authored-by: OlivierDehaene Co-authored-by: Your Name --- .github/workflows/build.yaml | 93 ++++++++++- Dockerfile | 8 +- Dockerfile_amd | 153 ++++++++++++++++++ README.md | 6 +- docs/source/quicktour.md | 2 + docs/source/supported_models.md | 6 +- server/Makefile | 5 +- server/Makefile-flash-att | 2 +- server/Makefile-flash-att-v2 | 25 ++- server/Makefile-vllm | 15 +- server/custom_kernels/setup.py | 9 +- server/requirements_common.txt | 46 ++++++ ...requirements.txt => requirements_cuda.txt} | 0 server/requirements_rocm.txt | 46 ++++++ .../custom_modeling/flash_llama_modeling.py | 37 +++-- .../custom_modeling/flash_mistral_modeling.py | 34 +++- .../custom_modeling/flash_neox_modeling.py | 3 +- .../custom_modeling/flash_rw_modeling.py | 6 +- .../custom_modeling/idefics_modeling.py | 49 ++++-- .../utils/flash_attn.py | 43 ++++- .../utils/import_utils.py | 4 + server/text_generation_server/utils/layers.py | 65 ++++++-- 22 files changed, 575 insertions(+), 82 deletions(-) create mode 100644 Dockerfile_amd create mode 100644 server/requirements_common.txt rename server/{requirements.txt => requirements_cuda.txt} (100%) create mode 100644 server/requirements_rocm.txt create mode 100644 server/text_generation_server/utils/import_utils.py diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 11a95f4b..395a0b6a 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -59,7 +59,7 @@ jobs: build-and-push-image: concurrency: - group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }} + group: ${{ github.workflow }}-build-and-push-image-${{ github.head_ref || github.run_id }} cancel-in-progress: true needs: start-runner # required to start the main job when the runner is ready runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner @@ -146,6 +146,95 @@ jobs: cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=min cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=min + build-and-push-image-rocm: + concurrency: + group: ${{ github.workflow }}-build-and-push-image-rocm-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + needs: start-runner # required to start the main job when the runner is ready + runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner + permissions: + contents: write + packages: write + # This is used to complete the identity challenge + # with sigstore/fulcio when running outside of PRs. 
+ id-token: write + security-events: write + steps: + - name: Checkout repository + uses: actions/checkout@v3 + - name: Initialize Docker Buildx + uses: docker/setup-buildx-action@v2.0.0 + with: + install: true + - name: Inject slug/short variables + uses: rlespinasse/github-slug-action@v4.4.1 + - name: Tailscale + uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966 + with: + authkey: ${{ secrets.TAILSCALE_AUTHKEY }} + - name: Login to GitHub Container Registry + if: github.event_name != 'pull_request' + uses: docker/login-action@v2 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Login to internal Container Registry + uses: docker/login-action@v2.1.0 + with: + username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }} + password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }} + registry: registry.internal.huggingface.tech + - name: Login to Azure Container Registry + if: github.event_name != 'pull_request' + uses: docker/login-action@v2.1.0 + with: + username: ${{ secrets.AZURE_DOCKER_USERNAME }} + password: ${{ secrets.AZURE_DOCKER_PASSWORD }} + registry: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io + # If pull request + - name: Extract metadata (tags, labels) for Docker + if: ${{ github.event_name == 'pull_request' }} + id: meta-pr + uses: docker/metadata-action@v4.3.0 + with: + images: | + registry.internal.huggingface.tech/api-inference/community/text-generation-inference + tags: | + type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}-rocm + # If main, release or tag + - name: Extract metadata (tags, labels) for Docker + if: ${{ github.event_name != 'pull_request' }} + id: meta + uses: docker/metadata-action@v4.3.0 + with: + flavor: | + latest=false + images: | + registry.internal.huggingface.tech/api-inference/community/text-generation-inference + ghcr.io/huggingface/text-generation-inference + db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference + tags: | + type=semver,pattern={{version}}-rocm + type=semver,pattern={{major}}.{{minor}}-rocm + type=raw,value=latest-rocm,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }} + type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}-rocm + - name: Build and push Docker image + id: build-and-push + uses: docker/build-push-action@v4 + with: + context: . 
+ file: Dockerfile_amd + push: true + platforms: 'linux/amd64' + build-args: | + GIT_SHA=${{ env.GITHUB_SHA }} + DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}-rocm + tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }} + labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }} + cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-rocm,mode=min + cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-rocm,mode=min + integration-tests: concurrency: group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }} @@ -153,6 +242,7 @@ jobs: needs: - start-runner - build-and-push-image # Wait for the docker image to be built + - build-and-push-image-rocm runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner env: DOCKER_VOLUME: /cache @@ -187,6 +277,7 @@ jobs: needs: - start-runner - build-and-push-image + - build-and-push-image-rocm - integration-tests runs-on: ubuntu-latest env: diff --git a/Dockerfile b/Dockerfile index d45aaec5..02540f81 100644 --- a/Dockerfile +++ b/Dockerfile @@ -106,7 +106,7 @@ WORKDIR /usr/src COPY server/Makefile-flash-att-v2 Makefile # Build specific version of flash attention v2 -RUN make build-flash-attention-v2 +RUN make build-flash-attention-v2-cuda # Build Transformers exllama kernels FROM kernel-builder as exllama-kernels-builder @@ -152,7 +152,7 @@ WORKDIR /usr/src COPY server/Makefile-vllm Makefile # Build specific version of vllm -RUN make build-vllm +RUN make build-vllm-cuda # Text Generation Inference base image FROM nvidia/cuda:12.1.0-base-ubuntu20.04 as base @@ -209,7 +209,7 @@ COPY server server COPY server/Makefile server/Makefile RUN cd server && \ make gen-server && \ - pip install -r requirements.txt && \ + pip install -r requirements_cuda.txt && \ pip install ".[bnb, accelerate, quantize, peft]" --no-cache-dir # Install benchmarker @@ -224,7 +224,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins g++ \ && rm -rf /var/lib/apt/lists/* -# AWS Sagemaker compatbile image +# AWS Sagemaker compatible image FROM base as sagemaker COPY sagemaker-entrypoint.sh entrypoint.sh diff --git a/Dockerfile_amd b/Dockerfile_amd new file mode 100644 index 00000000..dd331a5d --- /dev/null +++ b/Dockerfile_amd @@ -0,0 +1,153 @@ +# Rust builder +FROM lukemathwalker/cargo-chef:latest-rust-1.71 AS chef +WORKDIR /usr/src + +ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse + +FROM chef as planner +COPY Cargo.toml Cargo.toml +COPY rust-toolchain.toml rust-toolchain.toml +COPY proto proto +COPY benchmark benchmark +COPY router router +COPY launcher launcher +RUN cargo chef prepare --recipe-path recipe.json + +FROM chef AS builder + +ARG GIT_SHA +ARG DOCKER_LABEL + +RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ + unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ + unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ + rm -f $PROTOC_ZIP + +COPY --from=planner /usr/src/recipe.json recipe.json +RUN cargo chef cook --release --recipe-path recipe.json + +COPY Cargo.toml Cargo.toml +COPY rust-toolchain.toml rust-toolchain.toml +COPY proto proto +COPY benchmark benchmark +COPY router router +COPY launcher launcher +RUN cargo build --release + +# Text Generation Inference base image for RoCm +FROM rocm/dev-ubuntu-20.04:5.7 as base + +RUN 
apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + build-essential \ + ca-certificates \ + ccache \ + curl \ + git \ + make \ + libssl-dev \ + g++ \ + # Needed to build VLLM & flash. + rocthrust-dev \ + hipsparse-dev \ + hipblas-dev && \ + rm -rf /var/lib/apt/lists/* + +# Keep in sync with `server/pyproject.toml +ARG MAMBA_VERSION=23.1.0-1 +ARG PYTORCH_VERSION='2.2.0.dev0' +ARG ROCM_VERSION='5.7' +ARG PYTHON_VERSION='3.10.10' +# Automatically set by buildx +ARG TARGETPLATFORM +ENV PATH /opt/conda/bin:$PATH + +# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda. +# Install mamba +# translating Docker's TARGETPLATFORM into mamba arches +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") MAMBA_ARCH=aarch64 ;; \ + *) MAMBA_ARCH=x86_64 ;; \ + esac && \ + curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" +RUN chmod +x ~/mambaforge.sh && \ + bash ~/mambaforge.sh -b -p /opt/conda && \ + mamba init && \ + rm ~/mambaforge.sh + +# Install PyTorch nightly (2.2.0.dev2023) compiled against RoCm 5.7, as VLLM can not be compiled with RoCm 5.6. +RUN pip install --pre torch==2.2.0.dev20231106 --index-url https://download.pytorch.org/whl/nightly/rocm5.7 + +FROM base AS kernel-builder + +# Build vllm kernels +FROM kernel-builder AS vllm-builder +WORKDIR /usr/src + +COPY server/Makefile-vllm Makefile + +# Build specific version of vllm +RUN make build-vllm-rocm + +# Build Flash Attention v2 kernels +FROM kernel-builder AS flash-att-v2-builder +WORKDIR /usr/src + +COPY server/Makefile-flash-att-v2 Makefile + +# Build specific version of flash attention v2 +RUN make build-flash-attention-v2-rocm + +# Build Transformers CUDA kernels (gpt-neox and bloom) +FROM kernel-builder as custom-kernels-builder +WORKDIR /usr/src +COPY server/custom_kernels/ . 
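+# gfx90a is the GPU architecture of the AMD Instinct MI210/MI250; restricting PYTORCH_ROCM_ARCH keeps the hipified custom-kernel build to the supported targets.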
+RUN PYTORCH_ROCM_ARCH=gfx90a python setup.py build + +FROM base as base-copy + +# Text Generation Inference base env +ENV HUGGINGFACE_HUB_CACHE=/data \ + HF_HUB_ENABLE_HF_TRANSFER=1 \ + PORT=80 + +# Copy builds artifacts from vllm builder +COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages + +# Copy build artifacts from flash attention v2 builder +COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages + +# Copy build artifacts from custom kernels builder +COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages + +# Install flash-attention dependencies +RUN pip install einops --no-cache-dir + +# Install server +COPY proto proto +COPY server server +COPY server/Makefile server/Makefile +RUN cd server && \ + make gen-server && \ + pip install -r requirements_rocm.txt && \ + pip install ".[accelerate, peft]" --no-cache-dir + +# Install benchmarker +COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark +# Install router +COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router +# Install launcher +COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher + +# AWS Sagemaker compatible image +FROM base-copy as sagemaker +COPY sagemaker-entrypoint.sh entrypoint.sh +RUN chmod +x entrypoint.sh + +ENTRYPOINT ["./entrypoint.sh"] + +# Final image +FROM base-copy + +ENTRYPOINT ["text-generation-launcher"] +CMD ["--json-output"] diff --git a/README.md b/README.md index e4f90d53..0fa7a538 100644 --- a/README.md +++ b/README.md @@ -74,7 +74,9 @@ curl 127.0.0.1:8080/generate \ -H 'Content-Type: application/json' ``` -**Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar. +**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar. + +**Note:** TGI supports AMD Instinct MI210 and MI250 [to some extent](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.1.0+rocm --model-id $model` instead of the command above. To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli): ``` @@ -189,7 +191,7 @@ sudo apt-get install libssl-dev gcc -y ### CUDA Kernels -The custom CUDA kernels are only tested on NVIDIA A100s. 
If you have any installation or runtime issues, you can remove +The custom CUDA kernels are only tested on NVIDIA A100, AMD MI210 and AMD MI250. If you have any installation or runtime issues, you can remove the kernels by using the `DISABLE_CUSTOM_KERNELS=True` environment variable. Be aware that the official Docker image has them enabled by default. diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md index efcaae28..b0a77599 100644 --- a/docs/source/quicktour.md +++ b/docs/source/quicktour.md @@ -15,6 +15,8 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) . We also recommend using NVIDIA drivers with CUDA version 11.8 or higher. +To use TGI on RoCm-enabled AMD GPUs (only MI210 and MI250 are tested), please use the image `ghcr.io/huggingface/text-generation-inference:1.1.1+rocm` instead. For details about the usage on RoCm, please refer to the [Supported Hardware section](./supported_models#supported-hardware) and [AMD documentation](https://rocm.docs.amd.com/en/latest/deploy/docker.html). + Once TGI is running, you can use the `generate` endpoint by doing requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. Below you can see a simple snippet to query the endpoint. diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md index 8b4c33b1..d7d45b70 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -39,9 +39,9 @@ text-generation-launcher --model-id ## Supported Hardware -TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 11.8+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other hardware, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed. +TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 11.8+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed. + +TGI also has support of RoCm-enabled AMD Instinct MI210 and MI250 GPUs, with paged attention and flash attention v2 support. The following features are missing from the RoCm version of TGI: quantization and flash [layer norm kernel](https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm). 
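For reference, a minimal sketch of how the server tells CUDA and RoCm builds of PyTorch apart, mirroring the `import_utils.py` helper added in this PR; the printed feature summary is illustrative only:

```python
import torch

# RoCm builds of PyTorch report a HIP version, CUDA builds report a CUDA version.
IS_ROCM_SYSTEM = torch.version.hip is not None
IS_CUDA_SYSTEM = torch.version.cuda is not None

if IS_CUDA_SYSTEM:
    print("CUDA system: custom kernels, paged/flash attention and quantization available.")
elif IS_ROCM_SYSTEM:
    print("RoCm system: paged attention and flash attention v2 available; quantization and the flash layer norm kernel are not.")
else:
    print("Neither CUDA nor RoCm detected, optimized kernels are unavailable.")
```

The rest of the patch branches on these two flags, for example to pick `dropout_layer_norm` on CUDA versus vLLM's `layernorm_ops` on RoCm, and flash-attn's rotary kernel versus vLLM's `pos_encoding_ops`.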
TGI is also supported on the following AI hardware accelerators: - *Habana first-gen Gaudi and Gaudi2:* check out this [example](https://github.com/huggingface/optimum-habana/tree/main/text-generation-inference) how to serve models with TGI on Gaudi and Gaudi2 with [Optimum Habana](https://huggingface.co/docs/optimum/habana/index) - - diff --git a/server/Makefile b/server/Makefile index 92958d02..2810a528 100644 --- a/server/Makefile +++ b/server/Makefile @@ -18,11 +18,12 @@ gen-server: install: gen-server pip install pip --upgrade - pip install -r requirements.txt + pip install -r requirements_cuda.txt pip install -e ".[bnb, accelerate, quantize, peft]" run-dev: SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded export-requirements: - poetry export -o requirements.txt -E bnb --without-hashes + poetry export -o requirements_cuda.txt --extras bnb --without-hashes + poetry export -o requirements_rocm.txt --without-hashes diff --git a/server/Makefile-flash-att b/server/Makefile-flash-att index bc1d37ef..b4b2e40c 100644 --- a/server/Makefile-flash-att +++ b/server/Makefile-flash-att @@ -2,7 +2,7 @@ flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec flash-attention: # Clone flash attention - pip install packaging + pip install -U packaging ninja --no-cache-dir git clone https://github.com/HazyResearch/flash-attention.git build-flash-attention: flash-attention diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 index 583437b2..8b9f289d 100644 --- a/server/Makefile-flash-att-v2 +++ b/server/Makefile-flash-att-v2 @@ -1,13 +1,26 @@ flash_att_v2_commit := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3 +build-flash-attention-v2-cuda: FLASH_ATTN_V2_COMMIT=02ac572f3ffc4f402e4183aaa6824b45859d3ed3 +build-flash-attention-v2-cuda: FLASH_REPOSITORY=https://github.com/HazyResearch/flash-attention.git +build-flash-attention-v2-cuda: BRANCH=main +build-flash-attention-v2-cuda: PYTORCH_ROCM_ARCH="" +build-flash-attention-v2-cuda: build-flash-attention-v2 + +build-flash-attention-v2-rocm: FLASH_ATTN_V2_COMMIT=8736558c287ff2ef28b24878e42828c595ac3e69 +build-flash-attention-v2-rocm: FLASH_REPOSITORY=https://github.com/fxmarty/flash-attention-rocm +build-flash-attention-v2-rocm: BRANCH=remove-offload-arch-native +build-flash-attention-v2-rocm: PYTORCH_ROCM_ARCH=gfx90a +build-flash-attention-v2-rocm: build-flash-attention-v2 + flash-attention-v2: - # Clone flash attention - pip install packaging - git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2 + # Clone flash attention + pip install -U packaging ninja --no-cache-dir + git clone --single-branch --branch $(BRANCH) $(FLASH_REPOSITORY) flash-attention-v2 build-flash-attention-v2: flash-attention-v2 - cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit) - cd flash-attention-v2 && python setup.py build + cd flash-attention-v2 && git fetch && git checkout $(FLASH_ATTN_V2_COMMIT) + cd flash-attention-v2 && git submodule update --init --recursive + cd flash-attention-v2 && PYTORCH_ROCM_ARCH=$(PYTORCH_ROCM_ARCH) python setup.py build install-flash-attention-v2: build-flash-attention-v2 - cd flash-attention-v2 && python setup.py install \ No newline at end of file + cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install \ No newline at end of file diff --git a/server/Makefile-vllm b/server/Makefile-vllm index c601e452..ddb648ea 100644 --- a/server/Makefile-vllm +++ 
b/server/Makefile-vllm @@ -1,11 +1,20 @@ -vllm_commit := f8a1e39fae05ca610be8d5a78be9d40f5274e5fc +build-vllm-cuda: REPOSITORY=https://github.com/vllm-project/vllm.git +build-vllm-cuda: VLLM_COMMIT=f8a1e39fae05ca610be8d5a78be9d40f5274e5fc +build-vllm-cuda: BRANCH=main +build-vllm-cuda: build-vllm + +build-vllm-rocm: REPOSITORY=https://github.com/fxmarty/vllm-public.git +build-vllm-rocm: VLLM_COMMIT=ad9b7c4095ef54419a0533d254f2ad84bd2dfcae +build-vllm-rocm: BRANCH=rotary-no-positions-split-cos-sin +build-vllm-rocm: build-vllm vllm: # Clone vllm - git clone https://github.com/vllm-project/vllm.git + pip install -U ninja packaging --no-cache-dir + git clone --single-branch --branch $(BRANCH) $(REPOSITORY) vllm build-vllm: vllm - cd vllm && git fetch && git checkout $(vllm_commit) + cd vllm && git fetch && git checkout $(VLLM_COMMIT) cd vllm && python setup.py build install-vllm: build-vllm diff --git a/server/custom_kernels/setup.py b/server/custom_kernels/setup.py index 43b8ee4e..69f6b72a 100644 --- a/server/custom_kernels/setup.py +++ b/server/custom_kernels/setup.py @@ -1,5 +1,10 @@ from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import torch + +extra_compile_args = ["-std=c++17"] +if not torch.version.hip: + extra_compile_args.append("-arch=compute_80") setup( name="custom_kernels", @@ -7,12 +12,12 @@ setup( CUDAExtension( name="custom_kernels.fused_bloom_attention_cuda", sources=["custom_kernels/fused_bloom_attention_cuda.cu"], - extra_compile_args=["-arch=compute_80", "-std=c++17"], + extra_compile_args=extra_compile_args, ), CUDAExtension( name="custom_kernels.fused_attention_cuda", sources=["custom_kernels/fused_attention_cuda.cu"], - extra_compile_args=["-arch=compute_80", "-std=c++17"], + extra_compile_args=extra_compile_args, ), ], cmdclass={"build_ext": BuildExtension}, diff --git a/server/requirements_common.txt b/server/requirements_common.txt new file mode 100644 index 00000000..5a321834 --- /dev/null +++ b/server/requirements_common.txt @@ -0,0 +1,46 @@ +backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13" +certifi==2023.11.17 ; python_version >= "3.9" and python_version < "3.13" +charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13" +click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" +colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") +deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" +einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" +filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13" +fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13" +googleapis-common-protos==1.61.0 ; python_version >= "3.9" and python_version < "3.13" +grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" +grpcio-reflection==1.59.3 ; python_version >= "3.9" and python_version < "3.13" +grpcio-status==1.59.3 ; python_version >= "3.9" and python_version < "3.13" +grpcio==1.59.3 ; python_version >= "3.9" and python_version < "3.13" +hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13" +huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13" +idna==3.4 ; python_version >= "3.9" and python_version < "3.13" +loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" +numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-api==1.15.0 ; 
python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" +packaging==23.2 ; python_version >= "3.9" and python_version < "3.13" +pillow==10.1.0 ; python_version >= "3.9" and python_version < "3.13" +protobuf==4.25.1 ; python_version >= "3.9" and python_version < "3.13" +pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" +regex==2023.10.3 ; python_version >= "3.9" and python_version < "3.13" +requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" +safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13" +scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13" +sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" +setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13" +tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13" +tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13" +typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" +typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13" +urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13" +win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" +wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/requirements.txt b/server/requirements_cuda.txt similarity index 100% rename from server/requirements.txt rename to server/requirements_cuda.txt diff --git a/server/requirements_rocm.txt b/server/requirements_rocm.txt new file mode 100644 index 00000000..5a321834 --- /dev/null +++ b/server/requirements_rocm.txt @@ -0,0 +1,46 @@ +backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13" +certifi==2023.11.17 ; python_version >= "3.9" and python_version < "3.13" +charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13" +click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" +colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") +deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" +einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" +filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13" +fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13" +googleapis-common-protos==1.61.0 ; python_version >= "3.9" and python_version < "3.13" +grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" +grpcio-reflection==1.59.3 ; python_version >= "3.9" and python_version < "3.13" +grpcio-status==1.59.3 ; python_version >= "3.9" and python_version < "3.13" +grpcio==1.59.3 ; python_version 
>= "3.9" and python_version < "3.13" +hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13" +huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13" +idna==3.4 ; python_version >= "3.9" and python_version < "3.13" +loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" +numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" +packaging==23.2 ; python_version >= "3.9" and python_version < "3.13" +pillow==10.1.0 ; python_version >= "3.9" and python_version < "3.13" +protobuf==4.25.1 ; python_version >= "3.9" and python_version < "3.13" +pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" +regex==2023.10.3 ; python_version >= "3.9" and python_version < "3.13" +requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" +safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13" +scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13" +sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" +setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13" +tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13" +tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13" +typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" +typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13" +urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13" +win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" +wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 69608e1c..4aeb447d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -26,9 +26,6 @@ from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -# Flash attention imports -import dropout_layer_norm - from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.utils.layers import ( TensorParallelRowLinear, @@ -38,6 +35,12 @@ from text_generation_server.utils.layers import ( TensorParallelHead, get_linear, ) +from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM + +if IS_CUDA_SYSTEM: + import dropout_layer_norm +elif IS_ROCM_SYSTEM: + from 
vllm import layernorm_ops class LlamaConfig(PretrainedConfig): @@ -120,7 +123,7 @@ class LlamaRMSNorm(nn.Module): hidden_states = hidden_states.to(self.weight.dtype) return self.weight * hidden_states, residual - else: + elif IS_CUDA_SYSTEM: # faster post attention rms norm normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( hidden_states, @@ -143,6 +146,22 @@ class LlamaRMSNorm(nn.Module): res = hidden_states return normed_hidden_states, res + elif IS_ROCM_SYSTEM: + # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not. + if residual is not None: + hidden_states += residual + residual = hidden_states + + out = torch.empty_like(hidden_states) + layernorm_ops.rms_norm( + out, + hidden_states, + self.weight.data, + self.variance_epsilon, + ) + return out, residual + else: + raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.") def load_attention(config, prefix, weights): @@ -204,9 +223,6 @@ class FlashLlamaAttention(torch.nn.Module): self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads - # self.rotary_emb = PositionRotaryEmbedding.load( - # config=config, prefix=f"{prefix}.rotary_emb", weights=weights - # ) self.rotary_emb = PositionRotaryEmbedding.static( config=config, dim=self.head_size, @@ -261,9 +277,8 @@ class FlashLlamaAttention(torch.nn.Module): ) query = query.view(-1, self.num_heads, self.head_size) kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) - - self.rotary_emb(query, cos, sin) - self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) + + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) paged_attention.reshape_and_cache( kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots @@ -297,7 +312,7 @@ class FlashLlamaAttention(torch.nn.Module): input_lengths, max_s, ) - + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 2d731406..959949f0 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -26,11 +26,8 @@ from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -# Flash attention imports -import dropout_layer_norm - from text_generation_server.utils import paged_attention, flash_attn -from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2 +from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2_ROCM, HAS_FLASH_ATTN_V2_CUDA from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -39,8 +36,14 @@ from text_generation_server.utils.layers import ( TensorParallelHead, get_linear, ) +from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM -if not HAS_FLASH_ATTN_V2: +if IS_CUDA_SYSTEM: + import dropout_layer_norm +elif IS_ROCM_SYSTEM: + from vllm import layernorm_ops + +if not HAS_FLASH_ATTN_V2_CUDA and not HAS_FLASH_ATTN_V2_ROCM: raise ImportError("Mistral model requires flash attn v2") @@ -126,7 +129,7 @@ class MistralRMSNorm(nn.Module): hidden_states = 
hidden_states.to(self.weight.dtype) return self.weight * hidden_states, residual - else: + elif IS_CUDA_SYSTEM: # faster post attention rms norm normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( hidden_states, @@ -149,6 +152,22 @@ class MistralRMSNorm(nn.Module): res = hidden_states return normed_hidden_states, res + elif IS_ROCM_SYSTEM: + # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not. + if residual is not None: + hidden_states += residual + residual = hidden_states + + out = torch.empty_like(hidden_states) + layernorm_ops.rms_norm( + out, + hidden_states, + self.weight.data, + self.variance_epsilon, + ) + return out, residual + else: + raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.") def load_attention(config, prefix, weights): @@ -261,8 +280,7 @@ class MistralAttention(torch.nn.Module): query = query.view(-1, self.num_heads, self.head_size) kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) - self.rotary_emb(query, cos, sin) - self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) if prefill_cache_indices is not None: kv_to_cache = kv[prefill_cache_indices] diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index af4ba96b..eea5f787 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -135,8 +135,7 @@ class FlashNeoxAttention(torch.nn.Module): qkv = qkv.view(-1, 3, self.num_heads, self.head_size) # Inplace rotary - self.rotary_emb(qkv[:, 0], cos, sin) - self.rotary_emb(qkv[:, 1], cos, sin) + self.rotary_emb(qkv[:, 0], qkv[:, 1], cos, sin) paged_attention.reshape_and_cache( qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 00f953a6..6a530f3c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -185,8 +185,7 @@ class FlashRWAttention(torch.nn.Module): kv = kv.view(-1, 2, self.num_heads_kv, self.head_size) # Inplace rotary - self.rotary_emb(query, cos, sin) - self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) paged_attention.reshape_and_cache( kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots @@ -301,8 +300,7 @@ class FlashRWLargeAttention(torch.nn.Module): query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size) # Inplace rotary - self.rotary_emb(query, cos, sin) - self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin) + self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin) paged_attention.reshape_and_cache( kv[:, :, 0].contiguous(), diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py index 1ffe6276..946f7683 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py +++ 
b/server/text_generation_server/models/custom_modeling/idefics_modeling.py @@ -55,8 +55,12 @@ from text_generation_server.utils.layers import ( PositionRotaryEmbedding, FastLinear, ) -import dropout_layer_norm +from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM +if IS_CUDA_SYSTEM: + import dropout_layer_norm +elif IS_ROCM_SYSTEM: + from vllm import layernorm_ops @dataclass class BaseModelOutputWithPastImage(BaseModelOutputWithPast): @@ -370,7 +374,7 @@ class IdeficsRMSNorm(nn.Module): hidden_states = hidden_states.to(self.weight.dtype) return self.weight * hidden_states - else: + elif IS_CUDA_SYSTEM: # faster post attention rms norm unwrap = False if len(hidden_states.shape) > 2: @@ -402,6 +406,32 @@ class IdeficsRMSNorm(nn.Module): normed_hidden_states = normed_hidden_states.view(*shape) return normed_hidden_states + elif IS_ROCM_SYSTEM: + # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not. + if residual is not None: + hidden_states += residual + residual = hidden_states + + unwrap = False + if len(hidden_states.shape) > 2: + unwrap = True + shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, shape[-1]) + + out = torch.empty_like(hidden_states) + layernorm_ops.rms_norm( + out, + hidden_states, + self.weight.data, + self.variance_epsilon, + ) + + if unwrap: + out = out.view(*shape) + + return out + else: + raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.") # this was adapted from LlamaMLP @@ -581,15 +611,12 @@ class IdeficsAttention(nn.Module): position_ids.view(-1), max_s, hidden_states.dtype ) - shape = query_states.shape - query_states = self.rotary_emb( - query_states.view(-1, *shape[2:]), cos, sin - ).view(shape) - - shape = key_states.shape - key_states = self.rotary_emb( - key_states.reshape(-1, *shape[2:]), cos, sin - ).view(shape) + query_shape = query_states.shape + key_shape = key_states.shape + self.rotary_emb(query_states.view(-1, *query_shape[2:]), key_states.reshape(-1, *key_shape[2:]), cos, sin) + + query_states = query_states.view(query_shape) + key_states = key_states.view(key_shape) query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py index 8f0fcee6..aca95e11 100644 --- a/server/text_generation_server/utils/flash_attn.py +++ b/server/text_generation_server/utils/flash_attn.py @@ -3,6 +3,8 @@ import torch from loguru import logger +from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM + if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": raise ImportError("`USE_FLASH_ATTENTION` is false.") @@ -15,7 +17,8 @@ is_sm8x = major == 8 and minor >= 0 is_sm90 = major == 9 and minor == 0 HAS_FLASH_ATTN = False -HAS_FLASH_ATTN_V2 = False +HAS_FLASH_ATTN_V2_CUDA = False +HAS_FLASH_ATTN_V2_ROCM = False try: try: import flash_attn_2_cuda @@ -30,7 +33,8 @@ try: f"GPU with CUDA capability {major} {minor} is not supported for " "Flash Attention V2" ) - HAS_FLASH_ATTN_V2 = True + HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM + HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM except ImportError as e: try: import flash_attn_cuda @@ -41,10 +45,17 @@ except ImportError as e: "or install flash attention with `cd server && make install install-flash-attention`" ) from e - if not 
(is_sm75 or is_sm8x or is_sm90): + if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90): raise ImportError( f"GPU with CUDA capability {major} {minor} is not supported" ) from e + elif IS_ROCM_SYSTEM: + for idx in range(torch.cuda.device_count()): + if "MI210" not in torch.cuda.get_device_name(idx) and "MI250" not in torch.cuda.get_device_name(idx): + raise ImportError( + f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" + ) + logger.warning(f"Unable to use Flash Attention V2: {e}") HAS_FLASH_ATTN = True @@ -59,7 +70,7 @@ def attention( softmax_scale, window_size_left=-1, ): - if HAS_FLASH_ATTN_V2: + if HAS_FLASH_ATTN_V2_CUDA: return flash_attn_2_cuda.varlen_fwd( q, k, @@ -78,8 +89,28 @@ def attention( False, None, ) - - if HAS_FLASH_ATTN: + elif HAS_FLASH_ATTN_V2_ROCM: + if window_size_left != -1: + raise ValueError(f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left}).") + + # RoCm flash API does not take the window_size_left and window_size_right arguments. + return flash_attn_2_cuda.varlen_fwd( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + True, + False, + None, + ) + elif HAS_FLASH_ATTN: if window_size_left != -1: raise NotImplementedError( "window_size_left is only available with flash attn v2" diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py new file mode 100644 index 00000000..428c9f3e --- /dev/null +++ b/server/text_generation_server/utils/import_utils.py @@ -0,0 +1,4 @@ +import torch + +IS_ROCM_SYSTEM = torch.version.hip is not None +IS_CUDA_SYSTEM = torch.version.cuda is not None diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index e6a90116..a93ccd0e 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -12,14 +12,13 @@ HAS_BITS_AND_BYTES = True try: import bitsandbytes as bnb from bitsandbytes.nn import Int8Params, Params4bit - except ImportError: HAS_BITS_AND_BYTES = False from accelerate import init_empty_weights from text_generation_server.utils.gptq.quant_linear import QuantLinear - +from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM HAS_AWQ = True try: @@ -525,11 +524,14 @@ class TensorParallelEmbedding(nn.Module): try: - import dropout_layer_norm + if IS_CUDA_SYSTEM: + import dropout_layer_norm + else: + dropout_layer_norm = None class FastLayerNorm(nn.LayerNorm): def forward(self, hidden_states, residual=None): - if hidden_states.shape[-1] > 8192: + if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM: if residual is not None: hidden_states += residual residual = hidden_states @@ -561,14 +563,16 @@ try: residual = hidden_states return normed_hidden_states, residual - except ImportError: pass try: - from flash_attn.layers.rotary import RotaryEmbedding - import rotary_emb + if IS_CUDA_SYSTEM: + from flash_attn.layers.rotary import RotaryEmbedding + import rotary_emb + elif IS_ROCM_SYSTEM: + from vllm import pos_encoding_ops def _create_inv_freq(dim, base, device): inv_freq = 1.0 / ( @@ -597,6 +601,37 @@ try: self.scaling_factor = scaling_factor self.dynamic_args = None + def forward(self, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): + # Such controlflows may add some overhead. 
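+ # Both code paths rotate query and key in place: flash-attn's apply_rotary kernel on CUDA, vLLM's fused rotary kernel on RoCm.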
+ if IS_CUDA_SYSTEM: + rotary_dim = cos.shape[-1] + q1 = query[..., :rotary_dim] + q2 = query[..., rotary_dim : 2 * rotary_dim] + + rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) + + k1 = key[..., :rotary_dim] + k2 = key[..., rotary_dim : 2 * rotary_dim] + + rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) + elif IS_ROCM_SYSTEM: + # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. + # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 + + head_size = query.shape[-1] + + # Inplace operation, updating query and key. + pos_encoding_ops.rotary_embedding( + query, + key, + head_size, + cos, + sin, + True + ) + else: + raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.") + @classmethod def static(cls, config, dim, base, device): inv_freq = _create_inv_freq(dim, base, device) @@ -699,21 +734,19 @@ try: """ Return cos and sin for the asked position ids """ + if IS_ROCM_SYSTEM: + # For RoCm, we always use float cos/sin to avoid a cast. + # For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26 + # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal. + dtype = torch.float32 self._update_cos_sin_cache(dtype, position_ids.device, max_s) cos = torch.index_select(self._cos_cached, 0, position_ids) sin = torch.index_select(self._sin_cached, 0, position_ids) + # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow. return cos.unsqueeze(1), sin.unsqueeze(1) - def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): - rotary_dim = cos.shape[-1] - x1 = x[..., :rotary_dim] - x2 = x[..., rotary_dim : 2 * rotary_dim] - - rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False) - return x - class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): def __init__(self, dim, max_position_embeddings, base, device, scaling_factor): inv_freq = _create_inv_freq(dim, base, device) @@ -722,7 +755,7 @@ try: self.max_position_embeddings = max_position_embeddings self.base = base - def _update_cos_sin_cache(self, dtype, device, seqlen): + def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, # or if we're on a new device (possibly due to tracing for instance) if (
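For reference, a minimal eager-mode sketch of the NeoX-style rotation that both rotary kernels used above (flash-attn's `apply_rotary` on CUDA, vLLM's `pos_encoding_ops.rotary_embedding` on RoCm) apply to the first `2 * rotary_dim` channels of query and key; the shapes and the out-of-place form are illustrative assumptions, since the real kernels update query and key in place:

```python
import torch

def apply_rotary_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Reference NeoX-style rotary embedding.

    x:   [num_tokens, num_heads, head_size]
    cos: [num_tokens, 1, rotary_dim]
    sin: [num_tokens, 1, rotary_dim]
    """
    rotary_dim = cos.shape[-1]
    x1 = x[..., :rotary_dim]
    x2 = x[..., rotary_dim : 2 * rotary_dim]

    out = x.clone()
    # Rotate the two halves of the rotary channels; channels past 2 * rotary_dim are left untouched.
    out[..., :rotary_dim] = x1 * cos - x2 * sin
    out[..., rotary_dim : 2 * rotary_dim] = x1 * sin + x2 * cos
    return out
```

Both kernels implement this same rotation; the platforms differ only in whether query and key are processed by a single fused launch (RoCm, vLLM) or by two separate launches (CUDA, flash-attn).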