Add RoCm support (#1243)
This PR adds support for AMD Instinct MI210 and MI250 GPUs, with paged attention and Flash Attention v2 support.

Remaining items to discuss, on top of possible others:

* Should we have a `ghcr.io/huggingface/text-generation-inference:1.1.0+rocm` hosted image, or is it too early?
* Should we set up a CI on MI210/MI250? I don't have access to the runners of TGI though.
* Are we comfortable with those changes being directly in TGI, or do we need a fork?

---------

Co-authored-by: Felix Marty <felix@hf.co>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
Co-authored-by: Your Name <you@example.com>

Parent: ed2a3f617e · Commit: b2b5df0e94
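At a glance, the PR keys every kernel choice off the PyTorch build: the new `text_generation_server.utils.import_utils` module exposes `IS_CUDA_SYSTEM` / `IS_ROCM_SYSTEM`, and the modeling code imports either the flash-attention CUDA extensions or the vLLM ROCm kernels accordingly. The sketch below condenses that dispatch pattern; it is an illustration of the mechanism described by the diff, not code copied from the PR.

```python
import torch

# Same detection as the new import_utils module added by this PR.
IS_ROCM_SYSTEM = torch.version.hip is not None
IS_CUDA_SYSTEM = torch.version.cuda is not None

if IS_CUDA_SYSTEM:
    # CUDA keeps the existing flash-attention extensions.
    import dropout_layer_norm          # fused (residual-add +) RMSNorm / LayerNorm
    import rotary_emb                  # rotary embedding kernel
elif IS_ROCM_SYSTEM:
    # ROCm swaps in the vLLM kernels, which compile with hipcc.
    from vllm import layernorm_ops     # rms_norm
    from vllm import pos_encoding_ops  # rotary_embedding
else:
    raise RuntimeError("Neither a CUDA nor a ROCm PyTorch build was detected.")
```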
.github/workflows/build.yaml

@@ -59,7 +59,7 @@ jobs:
  build-and-push-image:
    concurrency:
-      group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
+      group: ${{ github.workflow }}-build-and-push-image-${{ github.head_ref || github.run_id }}
      cancel-in-progress: true
    needs: start-runner # required to start the main job when the runner is ready
    runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner

@@ -146,6 +146,95 @@ jobs:
          cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=min
          cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=min

+  build-and-push-image-rocm:
+    concurrency:
+      group: ${{ github.workflow }}-build-and-push-image-rocm-${{ github.head_ref || github.run_id }}
+      cancel-in-progress: true
+    needs: start-runner # required to start the main job when the runner is ready
+    runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
+    permissions:
+      contents: write
+      packages: write
+      # This is used to complete the identity challenge
+      # with sigstore/fulcio when running outside of PRs.
+      id-token: write
+      security-events: write
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v3
+      - name: Initialize Docker Buildx
+        uses: docker/setup-buildx-action@v2.0.0
+        with:
+          install: true
+      - name: Inject slug/short variables
+        uses: rlespinasse/github-slug-action@v4.4.1
+      - name: Tailscale
+        uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966
+        with:
+          authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
+      - name: Login to GitHub Container Registry
+        if: github.event_name != 'pull_request'
+        uses: docker/login-action@v2
+        with:
+          registry: ghcr.io
+          username: ${{ github.actor }}
+          password: ${{ secrets.GITHUB_TOKEN }}
+      - name: Login to internal Container Registry
+        uses: docker/login-action@v2.1.0
+        with:
+          username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }}
+          password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }}
+          registry: registry.internal.huggingface.tech
+      - name: Login to Azure Container Registry
+        if: github.event_name != 'pull_request'
+        uses: docker/login-action@v2.1.0
+        with:
+          username: ${{ secrets.AZURE_DOCKER_USERNAME }}
+          password: ${{ secrets.AZURE_DOCKER_PASSWORD }}
+          registry: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io
+      # If pull request
+      - name: Extract metadata (tags, labels) for Docker
+        if: ${{ github.event_name == 'pull_request' }}
+        id: meta-pr
+        uses: docker/metadata-action@v4.3.0
+        with:
+          images: |
+            registry.internal.huggingface.tech/api-inference/community/text-generation-inference
+          tags: |
+            type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}-rocm
+      # If main, release or tag
+      - name: Extract metadata (tags, labels) for Docker
+        if: ${{ github.event_name != 'pull_request' }}
+        id: meta
+        uses: docker/metadata-action@v4.3.0
+        with:
+          flavor: |
+            latest=false
+          images: |
+            registry.internal.huggingface.tech/api-inference/community/text-generation-inference
+            ghcr.io/huggingface/text-generation-inference
+            db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
+          tags: |
+            type=semver,pattern={{version}}-rocm
+            type=semver,pattern={{major}}.{{minor}}-rocm
+            type=raw,value=latest-rocm,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
+            type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}-rocm
+      - name: Build and push Docker image
+        id: build-and-push
+        uses: docker/build-push-action@v4
+        with:
+          context: .
+          file: Dockerfile_amd
+          push: true
+          platforms: 'linux/amd64'
+          build-args: |
+            GIT_SHA=${{ env.GITHUB_SHA }}
+            DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}-rocm
+          tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
+          labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
+          cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-rocm,mode=min
+          cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-rocm,mode=min
+
  integration-tests:
    concurrency:
      group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}

@@ -153,6 +242,7 @@ jobs:
    needs:
      - start-runner
      - build-and-push-image # Wait for the docker image to be built
+      - build-and-push-image-rocm
    runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
    env:
      DOCKER_VOLUME: /cache

@@ -187,6 +277,7 @@ jobs:
    needs:
      - start-runner
      - build-and-push-image
+      - build-and-push-image-rocm
      - integration-tests
    runs-on: ubuntu-latest
    env:
Dockerfile

@@ -106,7 +106,7 @@ WORKDIR /usr/src
COPY server/Makefile-flash-att-v2 Makefile

# Build specific version of flash attention v2
-RUN make build-flash-attention-v2
+RUN make build-flash-attention-v2-cuda

# Build Transformers exllama kernels
FROM kernel-builder as exllama-kernels-builder

@@ -152,7 +152,7 @@ WORKDIR /usr/src
COPY server/Makefile-vllm Makefile

# Build specific version of vllm
-RUN make build-vllm
+RUN make build-vllm-cuda

# Text Generation Inference base image
FROM nvidia/cuda:12.1.0-base-ubuntu20.04 as base

@@ -209,7 +209,7 @@ COPY server server
COPY server/Makefile server/Makefile
RUN cd server && \
    make gen-server && \
-    pip install -r requirements.txt && \
+    pip install -r requirements_cuda.txt && \
    pip install ".[bnb, accelerate, quantize, peft]" --no-cache-dir

# Install benchmarker

@@ -224,7 +224,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
    g++ \
    && rm -rf /var/lib/apt/lists/*

-# AWS Sagemaker compatbile image
+# AWS Sagemaker compatible image
FROM base as sagemaker

COPY sagemaker-entrypoint.sh entrypoint.sh
Dockerfile_amd (new file, @@ -0,0 +1,153 @@)

# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.71 AS chef
WORKDIR /usr/src

ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

FROM chef as planner
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder

ARG GIT_SHA
ARG DOCKER_LABEL

RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
    rm -f $PROTOC_ZIP

COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json

COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher
RUN cargo build --release

# Text Generation Inference base image for RoCm
FROM rocm/dev-ubuntu-20.04:5.7 as base

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    build-essential \
    ca-certificates \
    ccache \
    curl \
    git \
    make \
    libssl-dev \
    g++ \
    # Needed to build VLLM & flash.
    rocthrust-dev \
    hipsparse-dev \
    hipblas-dev && \
    rm -rf /var/lib/apt/lists/*

# Keep in sync with `server/pyproject.toml
ARG MAMBA_VERSION=23.1.0-1
ARG PYTORCH_VERSION='2.2.0.dev0'
ARG ROCM_VERSION='5.7'
ARG PYTHON_VERSION='3.10.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH

# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
# Install mamba
# translating Docker's TARGETPLATFORM into mamba arches
RUN case ${TARGETPLATFORM} in \
    "linux/arm64") MAMBA_ARCH=aarch64 ;; \
    *) MAMBA_ARCH=x86_64 ;; \
    esac && \
    curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
    bash ~/mambaforge.sh -b -p /opt/conda && \
    mamba init && \
    rm ~/mambaforge.sh

# Install PyTorch nightly (2.2.0.dev2023) compiled against RoCm 5.7, as VLLM can not be compiled with RoCm 5.6.
RUN pip install --pre torch==2.2.0.dev20231106 --index-url https://download.pytorch.org/whl/nightly/rocm5.7

FROM base AS kernel-builder

# Build vllm kernels
FROM kernel-builder AS vllm-builder
WORKDIR /usr/src

COPY server/Makefile-vllm Makefile

# Build specific version of vllm
RUN make build-vllm-rocm

# Build Flash Attention v2 kernels
FROM kernel-builder AS flash-att-v2-builder
WORKDIR /usr/src

COPY server/Makefile-flash-att-v2 Makefile

# Build specific version of flash attention v2
RUN make build-flash-attention-v2-rocm

# Build Transformers CUDA kernels (gpt-neox and bloom)
FROM kernel-builder as custom-kernels-builder
WORKDIR /usr/src
COPY server/custom_kernels/ .
RUN PYTORCH_ROCM_ARCH=gfx90a python setup.py build

FROM base as base-copy

# Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/data \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80

# Copy builds artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Install flash-attention dependencies
RUN pip install einops --no-cache-dir

# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
RUN cd server && \
    make gen-server && \
    pip install -r requirements_rocm.txt && \
    pip install ".[accelerate, peft]" --no-cache-dir

# Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher

# AWS Sagemaker compatible image
FROM base-copy as sagemaker
COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh

ENTRYPOINT ["./entrypoint.sh"]

# Final image
FROM base-copy

ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]
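The image above installs the ROCm 5.7 PyTorch nightly wheel and copies the vLLM, flash-attention v2 and custom kernels into the conda environment. A quick way to confirm the ROCm build inside the container is a short Python check like the one below (an illustrative snippet, not part of the PR):

```python
import torch

# On a correct ROCm build, torch.version.hip is set and torch.version.cuda is None.
print("HIP version:", torch.version.hip)
print("CUDA version:", torch.version.cuda)
print("GPUs visible:", torch.cuda.device_count())
for idx in range(torch.cuda.device_count()):
    # Flash Attention v2 in this PR is only enabled on MI210 / MI250.
    print(idx, torch.cuda.get_device_name(idx))
```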
README.md

@@ -74,7 +74,9 @@ curl 127.0.0.1:8080/generate \
    -H 'Content-Type: application/json'
```

-**Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
+**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
+
+**Note:** TGI supports AMD Instinct MI210 and MI250 [to some extent](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.1.0+rocm --model-id $model` instead of the command above.

To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
```

@@ -189,7 +191,7 @@ sudo apt-get install libssl-dev gcc -y

### CUDA Kernels

-The custom CUDA kernels are only tested on NVIDIA A100s. If you have any installation or runtime issues, you can remove
+The custom CUDA kernels are only tested on NVIDIA A100, AMD MI210 and AMD MI250. If you have any installation or runtime issues, you can remove
the kernels by using the `DISABLE_CUSTOM_KERNELS=True` environment variable.

Be aware that the official Docker image has them enabled by default.
docs/source/quicktour.md

@@ -15,6 +15,8 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf

To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) . We also recommend using NVIDIA drivers with CUDA version 11.8 or higher.

+To use TGI on RoCm-enabled AMD GPUs (only MI210 and MI250 are tested), please use the image `ghcr.io/huggingface/text-generation-inference:1.1.1+rocm` instead. For details about the usage on RoCm, please refer to the [Supported Hardware section](./supported_models#supported-hardware) and [AMD documentation](https://rocm.docs.amd.com/en/latest/deploy/docker.html).
+
</Tip>

Once TGI is running, you can use the `generate` endpoint by doing requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. Below you can see a simple snippet to query the endpoint.
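Once the server is up, the `generate` endpoint is queried the same way whether the NVIDIA or the ROCm image is running. A minimal Python client, assuming the default port mapping from the docker command above (the payload shape follows TGI's documented `/generate` API; adjust prompt and parameters to taste):

```python
import requests

response = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "What is Deep Learning?",
        "parameters": {"max_new_tokens": 20},
    },
    headers={"Content-Type": "application/json"},
)
response.raise_for_status()
print(response.json()["generated_text"])
```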
docs/source/supported_models.md

@@ -39,9 +39,9 @@ text-generation-launcher --model-id <PATH-TO-LOCAL-BLOOM>

## Supported Hardware

-TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 11.8+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other hardware, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed.
+TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 11.8+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed.
+
+TGI also has support of RoCm-enabled AMD Instinct MI210 and MI250 GPUs, with paged attention and flash attention v2 support. The following features are missing from the RoCm version of TGI: quantization and flash [layer norm kernel](https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm).

TGI is also supported on the following AI hardware accelerators:
- *Habana first-gen Gaudi and Gaudi2:* check out this [example](https://github.com/huggingface/optimum-habana/tree/main/text-generation-inference) how to serve models with TGI on Gaudi and Gaudi2 with [Optimum Habana](https://huggingface.co/docs/optimum/habana/index)
server/Makefile

@@ -18,11 +18,12 @@ gen-server:

install: gen-server
    pip install pip --upgrade
-    pip install -r requirements.txt
+    pip install -r requirements_cuda.txt
    pip install -e ".[bnb, accelerate, quantize, peft]"

run-dev:
    SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded

export-requirements:
-    poetry export -o requirements.txt -E bnb --without-hashes
+    poetry export -o requirements_cuda.txt --extras bnb --without-hashes
+    poetry export -o requirements_rocm.txt --without-hashes
server/Makefile-flash-att

@@ -2,7 +2,7 @@ flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec

flash-attention:
    # Clone flash attention
-    pip install packaging
+    pip install -U packaging ninja --no-cache-dir
    git clone https://github.com/HazyResearch/flash-attention.git

build-flash-attention: flash-attention
server/Makefile-flash-att-v2

@@ -1,13 +1,26 @@
flash_att_v2_commit := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3

+build-flash-attention-v2-cuda: FLASH_ATTN_V2_COMMIT=02ac572f3ffc4f402e4183aaa6824b45859d3ed3
+build-flash-attention-v2-cuda: FLASH_REPOSITORY=https://github.com/HazyResearch/flash-attention.git
+build-flash-attention-v2-cuda: BRANCH=main
+build-flash-attention-v2-cuda: PYTORCH_ROCM_ARCH=""
+build-flash-attention-v2-cuda: build-flash-attention-v2
+
+build-flash-attention-v2-rocm: FLASH_ATTN_V2_COMMIT=8736558c287ff2ef28b24878e42828c595ac3e69
+build-flash-attention-v2-rocm: FLASH_REPOSITORY=https://github.com/fxmarty/flash-attention-rocm
+build-flash-attention-v2-rocm: BRANCH=remove-offload-arch-native
+build-flash-attention-v2-rocm: PYTORCH_ROCM_ARCH=gfx90a
+build-flash-attention-v2-rocm: build-flash-attention-v2
+
flash-attention-v2:
    # Clone flash attention
-    pip install packaging
+    pip install -U packaging ninja --no-cache-dir
-    git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2
+    git clone --single-branch --branch $(BRANCH) $(FLASH_REPOSITORY) flash-attention-v2

build-flash-attention-v2: flash-attention-v2
-    cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit)
+    cd flash-attention-v2 && git fetch && git checkout $(FLASH_ATTN_V2_COMMIT)
-    cd flash-attention-v2 && python setup.py build
+    cd flash-attention-v2 && git submodule update --init --recursive
+    cd flash-attention-v2 && PYTORCH_ROCM_ARCH=$(PYTORCH_ROCM_ARCH) python setup.py build

install-flash-attention-v2: build-flash-attention-v2
-    cd flash-attention-v2 && python setup.py install
+    cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install
server/Makefile-vllm

@@ -1,11 +1,20 @@
-vllm_commit := f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
+build-vllm-cuda: REPOSITORY=https://github.com/vllm-project/vllm.git
+build-vllm-cuda: VLLM_COMMIT=f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
+build-vllm-cuda: BRANCH=main
+build-vllm-cuda: build-vllm
+
+build-vllm-rocm: REPOSITORY=https://github.com/fxmarty/vllm-public.git
+build-vllm-rocm: VLLM_COMMIT=ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
+build-vllm-rocm: BRANCH=rotary-no-positions-split-cos-sin
+build-vllm-rocm: build-vllm
+
vllm:
    # Clone vllm
-    git clone https://github.com/vllm-project/vllm.git
+    pip install -U ninja packaging --no-cache-dir
+    git clone --single-branch --branch $(BRANCH) $(REPOSITORY) vllm

build-vllm: vllm
-    cd vllm && git fetch && git checkout $(vllm_commit)
+    cd vllm && git fetch && git checkout $(VLLM_COMMIT)
    cd vllm && python setup.py build

install-vllm: build-vllm
server/custom_kernels/setup.py

@@ -1,5 +1,10 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+import torch
+
+extra_compile_args = ["-std=c++17"]
+if not torch.version.hip:
+    extra_compile_args.append("-arch=compute_80")

setup(
    name="custom_kernels",

@@ -7,12 +12,12 @@ setup(
        CUDAExtension(
            name="custom_kernels.fused_bloom_attention_cuda",
            sources=["custom_kernels/fused_bloom_attention_cuda.cu"],
-            extra_compile_args=["-arch=compute_80", "-std=c++17"],
+            extra_compile_args=extra_compile_args,
        ),
        CUDAExtension(
            name="custom_kernels.fused_attention_cuda",
            sources=["custom_kernels/fused_attention_cuda.cu"],
-            extra_compile_args=["-arch=compute_80", "-std=c++17"],
+            extra_compile_args=extra_compile_args,
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
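The setup.py change drops the nvcc-only `-arch=compute_80` flag when PyTorch is a HIP build, since hipcc picks the GPU architecture from `PYTORCH_ROCM_ARCH` instead (Dockerfile_amd above builds with `PYTORCH_ROCM_ARCH=gfx90a`). A tiny check script, not part of the PR, to see which flags the logic above would select on the local torch build:

```python
import torch

extra_compile_args = ["-std=c++17"]
if not torch.version.hip:
    # nvcc flag; hipcc relies on PYTORCH_ROCM_ARCH (e.g. gfx90a for MI210/MI250) instead.
    extra_compile_args.append("-arch=compute_80")

print("HIP build:", bool(torch.version.hip))
print("extra_compile_args:", extra_compile_args)
```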
server/requirements_cuda.txt (new file, @@ -0,0 +1,46 @@)

backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.11.17 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.61.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.1.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.1 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.10.3 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
server/requirements_rocm.txt (new file, @@ -0,0 +1,46 @@)

backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.11.17 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.61.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.1.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.1 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.10.3 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

@@ -26,9 +26,6 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

-# Flash attention imports
-import dropout_layer_norm
-
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
    TensorParallelRowLinear,

@@ -38,6 +35,12 @@ from text_generation_server.utils.layers import (
    TensorParallelHead,
    get_linear,
)
+from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
+
+if IS_CUDA_SYSTEM:
+    import dropout_layer_norm
+elif IS_ROCM_SYSTEM:
+    from vllm import layernorm_ops


class LlamaConfig(PretrainedConfig):

@@ -120,7 +123,7 @@ class LlamaRMSNorm(nn.Module):
            hidden_states = hidden_states.to(self.weight.dtype)

            return self.weight * hidden_states, residual
-        else:
+        elif IS_CUDA_SYSTEM:
            # faster post attention rms norm
            normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,

@@ -143,6 +146,22 @@ class LlamaRMSNorm(nn.Module):
                res = hidden_states

            return normed_hidden_states, res
+        elif IS_ROCM_SYSTEM:
+            # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
+            if residual is not None:
+                hidden_states += residual
+            residual = hidden_states
+
+            out = torch.empty_like(hidden_states)
+            layernorm_ops.rms_norm(
+                out,
+                hidden_states,
+                self.weight.data,
+                self.variance_epsilon,
+            )
+            return out, residual
+        else:
+            raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")


def load_attention(config, prefix, weights):

@@ -204,9 +223,6 @@ class FlashLlamaAttention(torch.nn.Module):
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_heads

-        # self.rotary_emb = PositionRotaryEmbedding.load(
-        #     config=config, prefix=f"{prefix}.rotary_emb", weights=weights
-        # )
        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=self.head_size,

@@ -261,9 +277,8 @@ class FlashLlamaAttention(torch.nn.Module):
        )
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

-        self.rotary_emb(query, cos, sin)
-        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        paged_attention.reshape_and_cache(
            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots

@@ -297,7 +312,7 @@ class FlashLlamaAttention(torch.nn.Module):
            input_lengths,
            max_s,
        )

        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
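On ROCm the fused `dropout_add_ln_fwd` kernel from flash-attention is unavailable, so the RMSNorm forward above calls vLLM's `layernorm_ops.rms_norm` instead. Both kernels are expected to match the plain-PyTorch reference below (a hedged sketch for sanity-checking kernel parity, not code from the PR):

```python
import torch

def rms_norm_reference(hidden_states: torch.Tensor,
                       weight: torch.Tensor,
                       eps: float = 1e-6,
                       residual: torch.Tensor | None = None):
    # Optional residual-add, mirroring what the fused kernels do in the modeling code.
    if residual is not None:
        hidden_states = hidden_states + residual
    residual = hidden_states

    # RMSNorm: scale by the reciprocal root-mean-square over the last dimension.
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    normed = hidden_states * torch.rsqrt(variance + eps)
    return weight * normed.to(weight.dtype), residual
```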
server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py

@@ -26,11 +26,8 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

-# Flash attention imports
-import dropout_layer_norm
-
from text_generation_server.utils import paged_attention, flash_attn
-from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2
+from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2_ROCM, HAS_FLASH_ATTN_V2_CUDA
from text_generation_server.utils.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,

@@ -39,8 +36,14 @@ from text_generation_server.utils.layers import (
    TensorParallelHead,
    get_linear,
)
+from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
+
+if IS_CUDA_SYSTEM:
+    import dropout_layer_norm
+elif IS_ROCM_SYSTEM:
+    from vllm import layernorm_ops

-if not HAS_FLASH_ATTN_V2:
+if not HAS_FLASH_ATTN_V2_CUDA and not HAS_FLASH_ATTN_V2_ROCM:
    raise ImportError("Mistral model requires flash attn v2")

@@ -126,7 +129,7 @@ class MistralRMSNorm(nn.Module):
            hidden_states = hidden_states.to(self.weight.dtype)

            return self.weight * hidden_states, residual
-        else:
+        elif IS_CUDA_SYSTEM:
            # faster post attention rms norm
            normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,

@@ -149,6 +152,22 @@ class MistralRMSNorm(nn.Module):
                res = hidden_states

            return normed_hidden_states, res
+        elif IS_ROCM_SYSTEM:
+            # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
+            if residual is not None:
+                hidden_states += residual
+            residual = hidden_states
+
+            out = torch.empty_like(hidden_states)
+            layernorm_ops.rms_norm(
+                out,
+                hidden_states,
+                self.weight.data,
+                self.variance_epsilon,
+            )
+            return out, residual
+        else:
+            raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")


def load_attention(config, prefix, weights):

@@ -261,8 +280,7 @@ class MistralAttention(torch.nn.Module):
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

-        self.rotary_emb(query, cos, sin)
-        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        if prefill_cache_indices is not None:
            kv_to_cache = kv[prefill_cache_indices]
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py

@@ -135,8 +135,7 @@ class FlashNeoxAttention(torch.nn.Module):
        qkv = qkv.view(-1, 3, self.num_heads, self.head_size)

        # Inplace rotary
-        self.rotary_emb(qkv[:, 0], cos, sin)
-        self.rotary_emb(qkv[:, 1], cos, sin)
+        self.rotary_emb(qkv[:, 0], qkv[:, 1], cos, sin)

        paged_attention.reshape_and_cache(
            qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
server/text_generation_server/models/custom_modeling/flash_rw_modeling.py

@@ -185,8 +185,7 @@ class FlashRWAttention(torch.nn.Module):
        kv = kv.view(-1, 2, self.num_heads_kv, self.head_size)

        # Inplace rotary
-        self.rotary_emb(query, cos, sin)
-        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        paged_attention.reshape_and_cache(
            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots

@@ -301,8 +300,7 @@ class FlashRWLargeAttention(torch.nn.Module):
        query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)

        # Inplace rotary
-        self.rotary_emb(query, cos, sin)
-        self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
+        self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)

        paged_attention.reshape_and_cache(
            kv[:, :, 0].contiguous(),
server/text_generation_server/models/custom_modeling/idefics_modeling.py

@@ -55,8 +55,12 @@ from text_generation_server.utils.layers import (
    PositionRotaryEmbedding,
    FastLinear,
)
-import dropout_layer_norm
+from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM

+if IS_CUDA_SYSTEM:
+    import dropout_layer_norm
+elif IS_ROCM_SYSTEM:
+    from vllm import layernorm_ops

@dataclass
class BaseModelOutputWithPastImage(BaseModelOutputWithPast):

@@ -370,7 +374,7 @@ class IdeficsRMSNorm(nn.Module):
            hidden_states = hidden_states.to(self.weight.dtype)

            return self.weight * hidden_states
-        else:
+        elif IS_CUDA_SYSTEM:
            # faster post attention rms norm
            unwrap = False
            if len(hidden_states.shape) > 2:

@@ -402,6 +406,32 @@ class IdeficsRMSNorm(nn.Module):
                normed_hidden_states = normed_hidden_states.view(*shape)

            return normed_hidden_states
+        elif IS_ROCM_SYSTEM:
+            # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
+            if residual is not None:
+                hidden_states += residual
+            residual = hidden_states
+
+            unwrap = False
+            if len(hidden_states.shape) > 2:
+                unwrap = True
+                shape = hidden_states.shape
+                hidden_states = hidden_states.reshape(-1, shape[-1])
+
+            out = torch.empty_like(hidden_states)
+            layernorm_ops.rms_norm(
+                out,
+                hidden_states,
+                self.weight.data,
+                self.variance_epsilon,
+            )
+
+            if unwrap:
+                out = out.view(*shape)
+
+            return out
+        else:
+            raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")


# this was adapted from LlamaMLP

@@ -581,15 +611,12 @@ class IdeficsAttention(nn.Module):
            position_ids.view(-1), max_s, hidden_states.dtype
        )

-        shape = query_states.shape
-        query_states = self.rotary_emb(
-            query_states.view(-1, *shape[2:]), cos, sin
-        ).view(shape)
-        shape = key_states.shape
-        key_states = self.rotary_emb(
-            key_states.reshape(-1, *shape[2:]), cos, sin
-        ).view(shape)
+        query_shape = query_states.shape
+        key_shape = key_states.shape
+        self.rotary_emb(query_states.view(-1, *query_shape[2:]), key_states.reshape(-1, *key_shape[2:]), cos, sin)
+        query_states = query_states.view(query_shape)
+        key_states = key_states.view(key_shape)

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
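Because the vLLM `rms_norm` kernel expects a 2D `(tokens, hidden)` input, the Idefics branch above flattens any extra leading dimensions before the kernel call and restores them afterwards. The same pattern, written as a small standalone helper (a hedged illustration, not code from the PR):

```python
import torch

def apply_2d_kernel(hidden_states: torch.Tensor, kernel) -> torch.Tensor:
    """Flatten leading dims, run a kernel that only accepts 2D input, then restore the shape."""
    unwrap = hidden_states.dim() > 2
    shape = hidden_states.shape
    if unwrap:
        hidden_states = hidden_states.reshape(-1, shape[-1])

    out = kernel(hidden_states)

    if unwrap:
        out = out.view(*shape)
    return out

# Example with a trivial stand-in "kernel":
x = torch.randn(2, 3, 8)
y = apply_2d_kernel(x, lambda t: t * 2)
assert y.shape == x.shape
```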
server/text_generation_server/utils/flash_attn.py

@@ -3,6 +3,8 @@ import torch

from loguru import logger

+from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
+
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")

@@ -15,7 +17,8 @@ is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0

HAS_FLASH_ATTN = False
-HAS_FLASH_ATTN_V2 = False
+HAS_FLASH_ATTN_V2_CUDA = False
+HAS_FLASH_ATTN_V2_ROCM = False
try:
    try:
        import flash_attn_2_cuda

@@ -30,7 +33,8 @@ try:
            f"GPU with CUDA capability {major} {minor} is not supported for "
            "Flash Attention V2"
        )
-    HAS_FLASH_ATTN_V2 = True
+    HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
+    HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
except ImportError as e:
    try:
        import flash_attn_cuda

@@ -41,10 +45,17 @@ except ImportError as e:
            "or install flash attention with `cd server && make install install-flash-attention`"
        ) from e

-    if not (is_sm75 or is_sm8x or is_sm90):
+    if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
        raise ImportError(
            f"GPU with CUDA capability {major} {minor} is not supported"
        ) from e
+    elif IS_ROCM_SYSTEM:
+        for idx in range(torch.cuda.device_count()):
+            if "MI210" not in torch.cuda.get_device_name(idx) and "MI250" not in torch.cuda.get_device_name(idx):
+                raise ImportError(
+                    f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
+                )
+
    logger.warning(f"Unable to use Flash Attention V2: {e}")
    HAS_FLASH_ATTN = True

@@ -59,7 +70,7 @@ def attention(
    softmax_scale,
    window_size_left=-1,
):
-    if HAS_FLASH_ATTN_V2:
+    if HAS_FLASH_ATTN_V2_CUDA:
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,

@@ -78,8 +89,28 @@
            False,
            None,
        )
-    if HAS_FLASH_ATTN:
+    elif HAS_FLASH_ATTN_V2_ROCM:
+        if window_size_left != -1:
+            raise ValueError(f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left}).")
+
+        # RoCm flash API does not take the window_size_left and window_size_right arguments.
+        return flash_attn_2_cuda.varlen_fwd(
+            q,
+            k,
+            v,
+            out,
+            cu_seqlens,
+            cu_seqlens,
+            max_s,
+            max_s,
+            0.0,
+            softmax_scale,
+            False,
+            True,
+            False,
+            None,
+        )
+    elif HAS_FLASH_ATTN:
        if window_size_left != -1:
            raise NotImplementedError(
                "window_size_left is only available with flash attn v2"
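The split of the old `HAS_FLASH_ATTN_V2` flag into CUDA and ROCm variants lets callers tell the two backends apart; for instance, the ROCm path rejects `window_size_left != -1`, so sliding-window attention is effectively CUDA-only for now. A hedged sketch of how downstream code could branch on these flags (an illustration of the idea, not TGI's actual logic):

```python
from text_generation_server.utils.flash_attn import (
    HAS_FLASH_ATTN_V2_CUDA,
    HAS_FLASH_ATTN_V2_ROCM,
)

def supports_sliding_window() -> bool:
    # The ROCm flash-attention fork used here does not expose window_size_left/right,
    # so only the CUDA build can serve sliding-window attention at full speed.
    return HAS_FLASH_ATTN_V2_CUDA

if not (HAS_FLASH_ATTN_V2_CUDA or HAS_FLASH_ATTN_V2_ROCM):
    raise RuntimeError("Flash Attention v2 unavailable; Mistral-style models cannot be served.")
```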
server/text_generation_server/utils/import_utils.py (new file, @@ -0,0 +1,4 @@)

import torch

IS_ROCM_SYSTEM = torch.version.hip is not None
IS_CUDA_SYSTEM = torch.version.cuda is not None
server/text_generation_server/utils/layers.py

@@ -12,14 +12,13 @@ HAS_BITS_AND_BYTES = True
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Int8Params, Params4bit

except ImportError:
    HAS_BITS_AND_BYTES = False

from accelerate import init_empty_weights

from text_generation_server.utils.gptq.quant_linear import QuantLinear
+from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM

HAS_AWQ = True
try:

@@ -525,11 +524,14 @@ class TensorParallelEmbedding(nn.Module):


try:
-    import dropout_layer_norm
+    if IS_CUDA_SYSTEM:
+        import dropout_layer_norm
+    else:
+        dropout_layer_norm = None

    class FastLayerNorm(nn.LayerNorm):
        def forward(self, hidden_states, residual=None):
-            if hidden_states.shape[-1] > 8192:
+            if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
                if residual is not None:
                    hidden_states += residual
                residual = hidden_states

@@ -561,14 +563,16 @@ try:
                residual = hidden_states

            return normed_hidden_states, residual

except ImportError:
    pass


try:
-    from flash_attn.layers.rotary import RotaryEmbedding
-    import rotary_emb
+    if IS_CUDA_SYSTEM:
+        from flash_attn.layers.rotary import RotaryEmbedding
+        import rotary_emb
+    elif IS_ROCM_SYSTEM:
+        from vllm import pos_encoding_ops

    def _create_inv_freq(dim, base, device):
        inv_freq = 1.0 / (

@@ -597,6 +601,37 @@ try:
            self.scaling_factor = scaling_factor
            self.dynamic_args = None

+        def forward(self, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
+            # Such controlflows may add some overhead.
+            if IS_CUDA_SYSTEM:
+                rotary_dim = cos.shape[-1]
+                q1 = query[..., :rotary_dim]
+                q2 = query[..., rotary_dim : 2 * rotary_dim]
+
+                rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
+
+                k1 = key[..., :rotary_dim]
+                k2 = key[..., rotary_dim : 2 * rotary_dim]
+
+                rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
+            elif IS_ROCM_SYSTEM:
+                # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
+                # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
+
+                head_size = query.shape[-1]
+
+                # Inplace operation, updating query and key.
+                pos_encoding_ops.rotary_embedding(
+                    query,
+                    key,
+                    head_size,
+                    cos,
+                    sin,
+                    True
+                )
+            else:
+                raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")

        @classmethod
        def static(cls, config, dim, base, device):
            inv_freq = _create_inv_freq(dim, base, device)

@@ -699,21 +734,19 @@ try:
            """
            Return cos and sin for the asked position ids
            """
+            if IS_ROCM_SYSTEM:
+                # For RoCm, we always use float cos/sin to avoid a cast.
+                # For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26
+                # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.
+                dtype = torch.float32
+
            self._update_cos_sin_cache(dtype, position_ids.device, max_s)

            cos = torch.index_select(self._cos_cached, 0, position_ids)
            sin = torch.index_select(self._sin_cached, 0, position_ids)
+            # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
            return cos.unsqueeze(1), sin.unsqueeze(1)

-        def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
-            rotary_dim = cos.shape[-1]
-            x1 = x[..., :rotary_dim]
-            x2 = x[..., rotary_dim : 2 * rotary_dim]
-
-            rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
-            return x

    class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
        def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
            inv_freq = _create_inv_freq(dim, base, device)

@@ -722,7 +755,7 @@ try:
            self.max_position_embeddings = max_position_embeddings
            self.base = base

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Reset the tables if the sequence length has changed,
            # or if we're on a new device (possibly due to tracing for instance)
            if (
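The new `PositionRotaryEmbedding.forward` rotates query and key in one call: flash-attn's `rotary_emb.apply_rotary` on CUDA, vLLM's `pos_encoding_ops.rotary_embedding` on ROCm. Up to layout details, both are expected to implement the usual non-interleaved rotation sketched below in plain PyTorch (a hedged reference for testing, not code from the PR; it assumes `cos`/`sin` are already gathered per token with shape `(tokens, 1, rotary_dim)` as returned by `get_cos_sin`):

```python
import torch

def rotary_reference(query: torch.Tensor, key: torch.Tensor,
                     cos: torch.Tensor, sin: torch.Tensor):
    """Rotate the first 2 * rotary_dim channels of query and key; returns rotated copies."""
    rotary_dim = cos.shape[-1]

    def rotate(x: torch.Tensor) -> torch.Tensor:
        x = x.clone()
        x1 = x[..., :rotary_dim].clone()
        x2 = x[..., rotary_dim : 2 * rotary_dim].clone()
        # Standard non-interleaved rotation: (x1, x2) -> (x1*cos - x2*sin, x1*sin + x2*cos)
        x[..., :rotary_dim] = x1 * cos - x2 * sin
        x[..., rotary_dim : 2 * rotary_dim] = x1 * sin + x2 * cos
        return x

    return rotate(query), rotate(key)
```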