MI300 compatibility (#1764)
Adds support for AMD Instinct MI300 in TGI.
Most changes are:
* Support PyTorch TunableOp to pick the GEMM/GEMV kernels for decoding
https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable.
TunableOp is disabled by default, and can be enabled with
`PYTORCH_TUNABLEOP_ENABLED=1`.
* Update ROCm dockerfile to PyTorch 2.3 (actually patched with changes
from https://github.com/pytorch/pytorch/pull/124362)
* Support SILU & Linear custom kernels contributed by AMD
* Update vLLM paged attention to https://github.com/fxmarty/rocm-vllm/,
branching out of a much more recent commit
3489ce7936
* Support FA2 Triton kernel as recommended by AMD. Can be used by
specifying `ROCM_USE_FLASH_ATTN_V2_TRITON=1`.
* Update dockerfile to ROCm 6.1
By default, TunableOp tuning results are saved in `/data` (e.g.
`/data/tunableop_meta-llama-Llama-2-70b-chat-hf_tp1_rank0.csv`) in order
to avoid to have to rerun the tuning at each `docker run`.
Example:
```
Validator,PT_VERSION,2.3.0
Validator,ROCM_VERSION,6.1.0.0-82-5fabb4c
Validator,HIPBLASLT_VERSION,0.7.0-1549b021
Validator,GCN_ARCH_NAME,gfx942:sramecc+:xnack-
Validator,ROCBLAS_VERSION,4.1.0-cefa4a9b-dirty
GemmTunableOp_Half_TN,tn_8192_7_28672,Gemm_Rocblas_45475,0.132098
GemmTunableOp_Half_TN,tn_10240_4_8192,Gemm_Rocblas_45546,0.0484431
GemmTunableOp_Half_TN,tn_32000_6_8192,Default,0.149546
GemmTunableOp_Half_TN,tn_32000_3_8192,Gemm_Rocblas_45520,0.147119
GemmTunableOp_Half_TN,tn_8192_3_28672,Gemm_Rocblas_45475,0.132645
GemmTunableOp_Half_TN,tn_10240_3_8192,Gemm_Rocblas_45546,0.0482971
GemmTunableOp_Half_TN,tn_57344_5_8192,Gemm_Rocblas_45520,0.255694
GemmTunableOp_Half_TN,tn_10240_7_8192,Gemm_Rocblas_45517,0.0482522
GemmTunableOp_Half_TN,tn_8192_3_8192,Gemm_Rocblas_45546,0.0444671
GemmTunableOp_Half_TN,tn_8192_5_8192,Gemm_Rocblas_45546,0.0445834
GemmTunableOp_Half_TN,tn_57344_7_8192,Gemm_Rocblas_45520,0.25622
GemmTunableOp_Half_TN,tn_8192_2_28672,Gemm_Rocblas_45475,0.132122
GemmTunableOp_Half_TN,tn_8192_4_8192,Gemm_Rocblas_45517,0.0453191
GemmTunableOp_Half_TN,tn_10240_5_8192,Gemm_Rocblas_45517,0.0482514
GemmTunableOp_Half_TN,tn_8192_5_28672,Gemm_Rocblas_45542,0.133914
GemmTunableOp_Half_TN,tn_8192_2_8192,Gemm_Rocblas_45517,0.0446516
GemmTunableOp_Half_TN,tn_8192_1_28672,Gemm_Hipblaslt_TN_10814,0.131953
GemmTunableOp_Half_TN,tn_10240_2_8192,Gemm_Rocblas_45546,0.0481043
GemmTunableOp_Half_TN,tn_32000_4_8192,Gemm_Rocblas_45520,0.147497
GemmTunableOp_Half_TN,tn_8192_6_28672,Gemm_Rocblas_45529,0.134895
GemmTunableOp_Half_TN,tn_57344_2_8192,Gemm_Rocblas_45520,0.254716
GemmTunableOp_Half_TN,tn_57344_4_8192,Gemm_Rocblas_45520,0.255731
GemmTunableOp_Half_TN,tn_10240_6_8192,Gemm_Rocblas_45517,0.0484816
GemmTunableOp_Half_TN,tn_57344_3_8192,Gemm_Rocblas_45520,0.254701
GemmTunableOp_Half_TN,tn_8192_4_28672,Gemm_Rocblas_45475,0.132159
GemmTunableOp_Half_TN,tn_32000_2_8192,Default,0.147524
GemmTunableOp_Half_TN,tn_32000_5_8192,Default,0.147074
GemmTunableOp_Half_TN,tn_8192_6_8192,Gemm_Rocblas_45546,0.0454045
GemmTunableOp_Half_TN,tn_57344_6_8192,Gemm_Rocblas_45520,0.255582
GemmTunableOp_Half_TN,tn_32000_7_8192,Default,0.146705
GemmTunableOp_Half_TN,tn_8192_7_8192,Gemm_Rocblas_45546,0.0445489
```
---------
Co-authored-by: Mohit Sharma <mohit21sharma.ms@gmail.com>
This commit is contained in:
parent
a60fa8406a
commit
232e8d5227
|
@ -290,6 +290,9 @@ jobs:
|
||||||
# with sigstore/fulcio when running outside of PRs.
|
# with sigstore/fulcio when running outside of PRs.
|
||||||
id-token: write
|
id-token: write
|
||||||
security-events: write
|
security-events: write
|
||||||
|
outputs:
|
||||||
|
# env is not available in the later `container:`, but previous job outputs are.
|
||||||
|
short_sha: ${{ env.GITHUB_SHA_SHORT }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
@ -392,3 +395,37 @@ jobs:
|
||||||
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
|
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
|
||||||
label: ${{ needs.start-runner.outputs.label }}
|
label: ${{ needs.start-runner.outputs.label }}
|
||||||
ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
|
ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
|
||||||
|
|
||||||
|
integration-tests-rocm:
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
needs:
|
||||||
|
- start-runner
|
||||||
|
- build-and-push-image
|
||||||
|
- integration-tests
|
||||||
|
- build-and-push-image-rocm
|
||||||
|
- stop-runner
|
||||||
|
runs-on: [self-hosted, docker-gpu, amd-gpu, multi-gpu, mi300]
|
||||||
|
container:
|
||||||
|
image: registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ needs.build-and-push-image-rocm.outputs.short_sha }}-rocm
|
||||||
|
options: --device /dev/kfd --device /dev/dri --env ROCR_VISIBLE_DEVICES --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/cache
|
||||||
|
env:
|
||||||
|
DOCKER_VOLUME: /cache
|
||||||
|
steps:
|
||||||
|
- name: ROCM-SMI
|
||||||
|
run: |
|
||||||
|
rocm-smi
|
||||||
|
- name: ROCM-INFO
|
||||||
|
run: |
|
||||||
|
rocminfo | grep "Agent" -A 14
|
||||||
|
- name: Show ROCR environment
|
||||||
|
run: |
|
||||||
|
echo "ROCR: $ROCR_VISIBLE_DEVICES"
|
||||||
|
- name: Install
|
||||||
|
run: |
|
||||||
|
make install-integration-tests
|
||||||
|
- name: Run tests
|
||||||
|
run: |
|
||||||
|
export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||||
|
pytest -s -vv integration-tests
|
||||||
|
|
|
@ -36,7 +36,7 @@ COPY launcher launcher
|
||||||
RUN cargo build --release
|
RUN cargo build --release
|
||||||
|
|
||||||
# Text Generation Inference base image for RoCm
|
# Text Generation Inference base image for RoCm
|
||||||
FROM rocm/dev-ubuntu-22.04:5.7 as base
|
FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update as base
|
||||||
|
|
||||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||||
build-essential \
|
build-essential \
|
||||||
|
@ -50,13 +50,24 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
|
||||||
# Needed to build VLLM & flash.
|
# Needed to build VLLM & flash.
|
||||||
rocthrust-dev \
|
rocthrust-dev \
|
||||||
hipsparse-dev \
|
hipsparse-dev \
|
||||||
hipblas-dev && \
|
hipblas-dev \
|
||||||
|
hipblaslt-dev \
|
||||||
|
rocblas-dev \
|
||||||
|
hiprand-dev \
|
||||||
|
rocrand-dev \
|
||||||
|
miopen-hip-dev \
|
||||||
|
hipfft-dev \
|
||||||
|
hipcub-dev \
|
||||||
|
hipsolver-dev \
|
||||||
|
rccl-dev \
|
||||||
|
cmake \
|
||||||
|
python3-dev && \
|
||||||
rm -rf /var/lib/apt/lists/*
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Keep in sync with `server/pyproject.toml
|
# Keep in sync with `server/pyproject.toml
|
||||||
ARG MAMBA_VERSION=23.1.0-1
|
ARG MAMBA_VERSION=23.1.0-1
|
||||||
ARG PYTORCH_VERSION='2.2.0.dev0'
|
ARG PYTORCH_VERSION='2.3.0'
|
||||||
ARG ROCM_VERSION='5.7'
|
ARG ROCM_VERSION='6.0.2'
|
||||||
ARG PYTHON_VERSION='3.10.10'
|
ARG PYTHON_VERSION='3.10.10'
|
||||||
# Automatically set by buildx
|
# Automatically set by buildx
|
||||||
ARG TARGETPLATFORM
|
ARG TARGETPLATFORM
|
||||||
|
@ -75,12 +86,43 @@ RUN chmod +x ~/mambaforge.sh && \
|
||||||
mamba init && \
|
mamba init && \
|
||||||
rm ~/mambaforge.sh
|
rm ~/mambaforge.sh
|
||||||
|
|
||||||
# Install PyTorch 2.2 RC compiled against RoCm 5.7, as VLLM can not be compiled with RoCm 5.6.
|
# Install flash-attention, torch dependencies
|
||||||
RUN pip install torch --index-url https://download.pytorch.org/whl/test/rocm5.7/
|
RUN pip install numpy einops ninja --no-cache-dir
|
||||||
|
|
||||||
|
RUN conda install intel::mkl-static intel::mkl-include
|
||||||
|
RUN pip uninstall -y triton && \
|
||||||
|
git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
|
||||||
|
cd triton/python && \
|
||||||
|
pip install .
|
||||||
|
|
||||||
|
RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir
|
||||||
|
|
||||||
|
ARG _GLIBCXX_USE_CXX11_ABI="1"
|
||||||
|
ARG CMAKE_PREFIX_PATH="/opt/conda"
|
||||||
|
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
|
||||||
|
ARG BUILD_CAFFE2="0" \
|
||||||
|
BUILD_CAFFE2_OPS="0" \
|
||||||
|
USE_CUDA="0" \
|
||||||
|
USE_ROCM="1" \
|
||||||
|
BUILD_TEST="0" \
|
||||||
|
USE_FBGEMM="0" \
|
||||||
|
USE_NNPACK="0" \
|
||||||
|
USE_QNNPACK="0" \
|
||||||
|
USE_XNNPACK="0" \
|
||||||
|
USE_FLASH_ATTENTION="1" \
|
||||||
|
USE_MEM_EFF_ATTENTION="0"
|
||||||
|
|
||||||
|
RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
|
||||||
|
|
||||||
|
# Set as recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
|
||||||
|
ENV HIP_FORCE_DEV_KERNARG=1
|
||||||
|
|
||||||
|
# On MI300, performances for flash with Triton FA is very competitive (actually better than CK)
|
||||||
|
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=1
|
||||||
|
|
||||||
FROM base AS kernel-builder
|
FROM base AS kernel-builder
|
||||||
|
|
||||||
# Build vllm kernels
|
# # Build vllm kernels
|
||||||
FROM kernel-builder AS vllm-builder
|
FROM kernel-builder AS vllm-builder
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
|
|
||||||
|
@ -102,21 +144,21 @@ RUN make build-flash-attention-v2-rocm
|
||||||
FROM kernel-builder as custom-kernels-builder
|
FROM kernel-builder as custom-kernels-builder
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
COPY server/custom_kernels/ .
|
COPY server/custom_kernels/ .
|
||||||
RUN PYTORCH_ROCM_ARCH=gfx90a python setup.py build
|
RUN python setup.py build
|
||||||
|
|
||||||
# Build exllama kernels
|
# Build exllama kernels
|
||||||
FROM kernel-builder as exllama-kernels-builder
|
FROM kernel-builder as exllama-kernels-builder
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
COPY server/exllama_kernels/ .
|
COPY server/exllama_kernels/ .
|
||||||
|
|
||||||
RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
|
RUN python setup.py build
|
||||||
|
|
||||||
# Build exllama v2 kernels
|
# Build exllama v2 kernels
|
||||||
FROM kernel-builder as exllamav2-kernels-builder
|
FROM kernel-builder as exllamav2-kernels-builder
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
COPY server/exllamav2_kernels/ .
|
COPY server/exllamav2_kernels/ .
|
||||||
|
|
||||||
RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
|
RUN python setup.py build
|
||||||
|
|
||||||
FROM base as base-copy
|
FROM base as base-copy
|
||||||
|
|
||||||
|
@ -140,9 +182,6 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310
|
||||||
# Copy build artifacts from exllamav2 kernels builder
|
# Copy build artifacts from exllamav2 kernels builder
|
||||||
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||||
|
|
||||||
# Install flash-attention dependencies
|
|
||||||
RUN pip install einops --no-cache-dir
|
|
||||||
|
|
||||||
# Install server
|
# Install server
|
||||||
COPY proto proto
|
COPY proto proto
|
||||||
COPY server server
|
COPY server server
|
||||||
|
@ -160,7 +199,8 @@ COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bi
|
||||||
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
|
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||||
|
|
||||||
# AWS Sagemaker compatible image
|
# AWS Sagemaker compatible image
|
||||||
FROM base-copy as sagemaker
|
FROM base as sagemaker
|
||||||
|
|
||||||
COPY sagemaker-entrypoint.sh entrypoint.sh
|
COPY sagemaker-entrypoint.sh entrypoint.sh
|
||||||
RUN chmod +x entrypoint.sh
|
RUN chmod +x entrypoint.sh
|
||||||
|
|
||||||
|
@ -169,5 +209,8 @@ ENTRYPOINT ["./entrypoint.sh"]
|
||||||
# Final image
|
# Final image
|
||||||
FROM base-copy
|
FROM base-copy
|
||||||
|
|
||||||
ENTRYPOINT ["text-generation-launcher"]
|
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
|
||||||
|
RUN chmod +x /tgi-entrypoint.sh
|
||||||
|
|
||||||
|
ENTRYPOINT ["/tgi-entrypoint.sh"]
|
||||||
CMD ["--json-output"]
|
CMD ["--json-output"]
|
||||||
|
|
|
@ -3,8 +3,16 @@
|
||||||
title: Text Generation Inference
|
title: Text Generation Inference
|
||||||
- local: quicktour
|
- local: quicktour
|
||||||
title: Quick Tour
|
title: Quick Tour
|
||||||
|
- local: installation_nvidia
|
||||||
|
title: Using TGI with Nvidia GPUs
|
||||||
|
- local: installation_amd
|
||||||
|
title: Using TGI with AMD GPUs
|
||||||
|
- local: installation_gaudi
|
||||||
|
title: Using TGI with Intel Gaudi
|
||||||
|
- local: installation_inferentia
|
||||||
|
title: Using TGI with AWS Inferentia
|
||||||
- local: installation
|
- local: installation
|
||||||
title: Installation
|
title: Installation from source
|
||||||
- local: supported_models
|
- local: supported_models
|
||||||
title: Supported Models and Hardware
|
title: Supported Models and Hardware
|
||||||
- local: messages_api
|
- local: messages_api
|
||||||
|
|
|
@ -19,6 +19,6 @@ docker run --gpus all \
|
||||||
--shm-size 1g \
|
--shm-size 1g \
|
||||||
-e HUGGING_FACE_HUB_TOKEN=$token \
|
-e HUGGING_FACE_HUB_TOKEN=$token \
|
||||||
-p 8080:80 \
|
-p 8080:80 \
|
||||||
-v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 \
|
-v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0.3 \
|
||||||
--model-id $model
|
--model-id $model
|
||||||
```
|
```
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
# Installation
|
# Installation from source
|
||||||
|
|
||||||
This section explains how to install the CLI tool as well as installing TGI from source. **The strongly recommended approach is to use Docker, as it does not require much setup. Check [the Quick Tour](./quicktour) to learn how to run TGI with Docker.**
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
Installing TGI from source is not the recommended usage. We strongly recommend to use TGI through Docker, check the [Quick Tour](./quicktour), [Installation for Nvidia GPUs](./installation_nvidia) and [Installation for AMD GPUs](./installation_amd) to learn how to use TGI with Docker.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
## Install CLI
|
## Install CLI
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
# Using TGI with AMD GPUs
|
||||||
|
|
||||||
|
TGI is supported and tested on [AMD Instinct MI210](https://www.amd.com/en/products/accelerators/instinct/mi200/mi210.html), [MI250](https://www.amd.com/en/products/accelerators/instinct/mi200/mi250.html) and [MI300](https://www.amd.com/en/products/accelerators/instinct/mi300.html) GPUs. The support may be extended in the future. The recommended usage is through Docker. Make sure to check the [AMD documentation](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/docker.html) on how to use Docker with AMD GPUs.
|
||||||
|
|
||||||
|
On a server powered by AMD GPUs, TGI can be launched with the following command:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
model=teknium/OpenHermes-2.5-Mistral-7B
|
||||||
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
|
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
|
||||||
|
--device=/dev/kfd --device=/dev/dri --group-add video \
|
||||||
|
--ipc=host --shm-size 256g --net host -v $volume:/data \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:2.0.3-rocm \
|
||||||
|
--model-id $model
|
||||||
|
```
|
||||||
|
|
||||||
|
The launched TGI server can then be queried from clients, make sure to check out the [Consuming TGI](./basic_tutorials/consuming_tgi) guide.
|
||||||
|
|
||||||
|
## TunableOp
|
||||||
|
|
||||||
|
TGI's docker image for AMD GPUs integrates [PyTorch's TunableOp](https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable), which allows to do an additional warmup to select the best performing matrix multiplication (GEMM) kernel from rocBLAS or hipBLASLt.
|
||||||
|
|
||||||
|
Experimentally, on MI300X, we noticed a 6-8% latency improvement when using TunableOp on top of ROCm 6.1 and PyTorch 2.3.
|
||||||
|
|
||||||
|
TunableOp is enabled by default, the warmup may take 1-2 minutes. In case you would like to disable TunableOp, please pass `--env PYTORCH_TUNABLEOP_ENABLED="0"` when launcher TGI's docker container.
|
||||||
|
|
||||||
|
## Flash attention implementation
|
||||||
|
|
||||||
|
Two implementations of Flash Attention are available for ROCm, the first is [ROCm/flash-attention](https://github.com/ROCm/flash-attention) based on a [Composable Kernel](https://github.com/ROCm/composable_kernel) (CK) implementation, and the second is a [Triton implementation](https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/utils/flash_attn_triton.py).
|
||||||
|
|
||||||
|
By default, as its performances have experimentally been better, Triton implementation is used. It can be disabled (using CK implementation instead) by passing `--env ROCM_USE_FLASH_ATTN_V2_TRITON="0"` when launching TGI's docker container.
|
||||||
|
|
||||||
|
## Unsupported features
|
||||||
|
|
||||||
|
The following features are currently not supported in the ROCm version of TGI, and the supported may be extended in the future:
|
||||||
|
* Loading [AWQ](https://huggingface.co/docs/transformers/quantization#awq) checkpoints.
|
||||||
|
* Kernel for sliding window attention (Mistral)
|
|
@ -0,0 +1,3 @@
|
||||||
|
# Using TGI with Intel Gaudi
|
||||||
|
|
||||||
|
Check out this [repository](https://github.com/huggingface/tgi-gaudi) to serve models with TGI on Gaudi and Gaudi2 with [Optimum Habana](https://huggingface.co/docs/optimum/habana/index).
|
|
@ -0,0 +1,3 @@
|
||||||
|
# Using TGI with Inferentia
|
||||||
|
|
||||||
|
Check out this [guide](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference) on how to serve models with TGI on Inferentia2.
|
|
@ -0,0 +1,18 @@
|
||||||
|
# Using TGI with Nvidia GPUs
|
||||||
|
|
||||||
|
TGI optimized models are supported on NVIDIA [H100](https://www.nvidia.com/en-us/data-center/h100/), [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 12.2+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it.
|
||||||
|
|
||||||
|
For other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed.
|
||||||
|
|
||||||
|
TGI can be used on NVIDIA GPUs through its official docker image:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
model=teknium/OpenHermes-2.5-Mistral-7B
|
||||||
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
|
docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:2.0.3 \
|
||||||
|
--model-id $model
|
||||||
|
```
|
||||||
|
|
||||||
|
The launched TGI server can then be queried from clients, make sure to check out the [Consuming TGI](./basic_tutorials/consuming_tgi) guide.
|
|
@ -2,30 +2,27 @@
|
||||||
|
|
||||||
The easiest way of getting started is using the official Docker container. Install Docker following [their installation instructions](https://docs.docker.com/get-docker/).
|
The easiest way of getting started is using the official Docker container. Install Docker following [their installation instructions](https://docs.docker.com/get-docker/).
|
||||||
|
|
||||||
Let's say you want to deploy [teknium/OpenHermes-2.5-Mistral-7B](https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B) model with TGI. Here is an example on how to do that:
|
## Launching TGI
|
||||||
|
|
||||||
|
Let's say you want to deploy [teknium/OpenHermes-2.5-Mistral-7B](https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B) model with TGI on an Nvidia GPU. Here is an example on how to do that:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
model=teknium/OpenHermes-2.5-Mistral-7B
|
model=teknium/OpenHermes-2.5-Mistral-7B
|
||||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 --model-id $model
|
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:2.0.3 \
|
||||||
|
--model-id $model
|
||||||
```
|
```
|
||||||
|
|
||||||
<Tip warning={true}>
|
### Supported hardware
|
||||||
|
|
||||||
To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher.
|
TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on.
|
||||||
|
|
||||||
</Tip>
|
## Consuming TGI
|
||||||
|
|
||||||
TGI also supports ROCm-enabled AMD GPUs (only MI210 and MI250 are tested), details are available in the [Supported Hardware section](./supported_models#supported-hardware) and [AMD documentation](https://rocm.docs.amd.com/en/latest/deploy/docker.html). To launch TGI on ROCm GPUs, please use instead:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker run --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4-rocm --model-id $model
|
|
||||||
```
|
|
||||||
|
|
||||||
Once TGI is running, you can use the `generate` endpoint by doing requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. Below you can see a simple snippet to query the endpoint.
|
Once TGI is running, you can use the `generate` endpoint by doing requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. Below you can see a simple snippet to query the endpoint.
|
||||||
|
|
||||||
|
|
||||||
<inferencesnippet>
|
<inferencesnippet>
|
||||||
<python>
|
<python>
|
||||||
|
|
||||||
|
@ -91,7 +88,7 @@ curl 127.0.0.1:8080/generate \
|
||||||
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
|
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker run ghcr.io/huggingface/text-generation-inference:1.4 --help
|
docker run ghcr.io/huggingface/text-generation-inference:2.0.3 --help
|
||||||
```
|
```
|
||||||
|
|
||||||
</Tip>
|
</Tip>
|
||||||
|
|
|
@ -40,17 +40,3 @@ If you wish to serve a supported model that already exists on a local folder, ju
|
||||||
```bash
|
```bash
|
||||||
text-generation-launcher --model-id <PATH-TO-LOCAL-BLOOM>
|
text-generation-launcher --model-id <PATH-TO-LOCAL-BLOOM>
|
||||||
``````
|
``````
|
||||||
|
|
||||||
|
|
||||||
## Supported Hardware
|
|
||||||
|
|
||||||
TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 12.2+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed.
|
|
||||||
|
|
||||||
TGI also has support of ROCm-enabled AMD Instinct MI210 and MI250 GPUs, with paged attention, GPTQ quantization, flash attention v2 support. The following features are currently not supported in the ROCm version of TGI, and the supported may be extended in the future:
|
|
||||||
* Loading [AWQ](https://huggingface.co/docs/transformers/quantization#awq) checkpoints.
|
|
||||||
* Flash [layer norm kernel](https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm)
|
|
||||||
* Kernel for sliding window attention (Mistral)
|
|
||||||
|
|
||||||
TGI is also supported on the following AI hardware accelerators:
|
|
||||||
- *Habana first-gen Gaudi and Gaudi2:* check out this [repository](https://github.com/huggingface/tgi-gaudi) to serve models with TGI on Gaudi and Gaudi2 with [Optimum Habana](https://huggingface.co/docs/optimum/habana/index)
|
|
||||||
* *AWS Inferentia2:* check out this [guide](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference) on how to serve models with TGI on Inferentia2.
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9
|
flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9
|
||||||
flash_att_v2_commit_rocm := 8736558c287ff2ef28b24878e42828c595ac3e69
|
flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
|
||||||
|
|
||||||
|
|
||||||
flash-attention-v2-cuda:
|
flash-attention-v2-cuda:
|
||||||
|
@ -18,12 +18,12 @@ install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
|
||||||
flash-attention-v2-rocm:
|
flash-attention-v2-rocm:
|
||||||
# Clone flash attention
|
# Clone flash attention
|
||||||
pip install -U packaging ninja --no-cache-dir
|
pip install -U packaging ninja --no-cache-dir
|
||||||
git clone https://github.com/fxmarty/flash-attention-rocm flash-attention-v2
|
git clone https://github.com/ROCm/flash-attention.git flash-attention-v2
|
||||||
|
|
||||||
build-flash-attention-v2-rocm: flash-attention-v2-rocm
|
build-flash-attention-v2-rocm: flash-attention-v2-rocm
|
||||||
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm)
|
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm)
|
||||||
cd flash-attention-v2 && git submodule update --init --recursive
|
cd flash-attention-v2 && git submodule update --init --recursive
|
||||||
cd flash-attention-v2 && PYTORCH_ROCM_ARCH=gfx90a python setup.py build
|
cd flash-attention-v2 && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
|
||||||
|
|
||||||
install-flash-attention-v2-rocm: build-flash-attention-v2-rocm
|
install-flash-attention-v2-rocm: build-flash-attention-v2-rocm
|
||||||
cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install
|
cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install
|
||||||
|
|
|
@ -14,11 +14,11 @@ install-vllm-cuda: build-vllm-cuda
|
||||||
vllm-rocm:
|
vllm-rocm:
|
||||||
# Clone vllm
|
# Clone vllm
|
||||||
pip install -U ninja packaging --no-cache-dir
|
pip install -U ninja packaging --no-cache-dir
|
||||||
git clone https://github.com/fxmarty/vllm-public.git vllm
|
git clone https://github.com/fxmarty/rocm-vllm.git vllm
|
||||||
|
|
||||||
build-vllm-rocm: vllm-rocm
|
build-vllm-rocm: vllm-rocm
|
||||||
cd vllm && git fetch && git checkout ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
|
cd vllm && git fetch && git checkout ca6913b3c2ffacdcb7d15e914dc34adbc6c89479
|
||||||
cd vllm && python setup.py build
|
cd vllm && PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install
|
||||||
|
|
||||||
install-vllm-rocm: build-vllm-rocm
|
install-vllm-rocm: build-vllm-rocm
|
||||||
pip uninstall vllm -y || true
|
pip uninstall vllm -y || true
|
||||||
|
|
|
@ -10,8 +10,9 @@ __device__ __forceinline__ __half __compat_hrcp(__half x) {
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
|
__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
|
||||||
return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
|
return _Float16_2{
|
||||||
static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
|
_Float16_2{static_cast<_Float16>(1.0f),
|
||||||
|
static_cast<_Float16>(1.0f)} / x.data};
|
||||||
}
|
}
|
||||||
|
|
||||||
#define hrcp __compat_hrcp
|
#define hrcp __compat_hrcp
|
||||||
|
|
|
@ -72,7 +72,7 @@ if SYSTEM == "cuda":
|
||||||
return normed_hidden_states, residual
|
return normed_hidden_states, residual
|
||||||
|
|
||||||
elif SYSTEM == "rocm":
|
elif SYSTEM == "rocm":
|
||||||
from vllm import layernorm_ops
|
from vllm._C import ops
|
||||||
|
|
||||||
class FastLayerNorm(nn.LayerNorm):
|
class FastLayerNorm(nn.LayerNorm):
|
||||||
def forward(self, hidden_states, residual=None):
|
def forward(self, hidden_states, residual=None):
|
||||||
|
@ -172,7 +172,7 @@ class FastRMSNorm(nn.Module):
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
out = torch.empty_like(hidden_states)
|
out = torch.empty_like(hidden_states)
|
||||||
layernorm_ops.rms_norm(
|
ops.rms_norm(
|
||||||
out,
|
out,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
self.weight.data,
|
self.weight.data,
|
||||||
|
|
|
@ -2,6 +2,12 @@ import torch
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
|
||||||
|
if SYSTEM == "rocm":
|
||||||
|
try:
|
||||||
|
from vllm import _custom_C
|
||||||
|
except Exception as e:
|
||||||
|
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
|
||||||
|
|
||||||
|
|
||||||
class FastLinear(torch.nn.Module):
|
class FastLinear(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -29,8 +35,65 @@ class FastLinear(torch.nn.Module):
|
||||||
return F.linear(input, self.weight, self.bias)
|
return F.linear(input, self.weight, self.bias)
|
||||||
|
|
||||||
|
|
||||||
|
class FastLinearROCm(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
weight,
|
||||||
|
bias,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.weight = torch.nn.Parameter(weight)
|
||||||
|
if bias is not None:
|
||||||
|
self.bias = torch.nn.Parameter(bias)
|
||||||
|
else:
|
||||||
|
self.bias = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, config, prefix: str, weights, bias: bool):
|
||||||
|
weight = weights.get_tensor(f"{prefix}.weight")
|
||||||
|
if bias:
|
||||||
|
bias = weights.get_tensor(f"{prefix}.bias")
|
||||||
|
else:
|
||||||
|
bias = None
|
||||||
|
return cls(weight, bias)
|
||||||
|
|
||||||
|
def forward(self, inp: torch.Tensor) -> torch.Tensor:
|
||||||
|
weight = self.weight
|
||||||
|
bias = self.bias
|
||||||
|
|
||||||
|
if SYSTEM == "rocm" and inp.numel() // inp.shape[-1] == 1:
|
||||||
|
batched = False
|
||||||
|
inp_shape = inp.shape
|
||||||
|
|
||||||
|
if inp.dim() == 3:
|
||||||
|
inp = inp.view(-1, inp_shape[-1])
|
||||||
|
batched = True
|
||||||
|
|
||||||
|
m, k = weight.shape[0], inp_shape[1]
|
||||||
|
out = torch.empty(
|
||||||
|
inp_shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
|
||||||
|
)
|
||||||
|
if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
|
||||||
|
_custom_C.LLMM1(weight, inp, out, 8)
|
||||||
|
elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
|
||||||
|
_custom_C.LLMM1(weight, inp, out, 4)
|
||||||
|
else:
|
||||||
|
out = F.linear(inp, weight)
|
||||||
|
|
||||||
|
if batched:
|
||||||
|
out.view(*inp_shape[:-1], out.shape[-1])
|
||||||
|
|
||||||
|
if bias is not None:
|
||||||
|
out = out + bias
|
||||||
|
return out
|
||||||
|
return F.linear(inp, self.weight, self.bias)
|
||||||
|
|
||||||
|
|
||||||
def get_linear(weight, bias, quantize):
|
def get_linear(weight, bias, quantize):
|
||||||
if quantize is None:
|
if quantize is None:
|
||||||
|
if SYSTEM == "rocm":
|
||||||
|
linear = FastLinearROCm(weight, bias)
|
||||||
|
else:
|
||||||
linear = FastLinear(weight, bias)
|
linear = FastLinear(weight, bias)
|
||||||
elif quantize == "eetq":
|
elif quantize == "eetq":
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -8,7 +8,7 @@ if SYSTEM == "cuda":
|
||||||
from flash_attn.layers.rotary import RotaryEmbedding
|
from flash_attn.layers.rotary import RotaryEmbedding
|
||||||
import rotary_emb
|
import rotary_emb
|
||||||
elif SYSTEM == "rocm":
|
elif SYSTEM == "rocm":
|
||||||
from vllm import pos_encoding_ops
|
from vllm._C import ops
|
||||||
|
|
||||||
|
|
||||||
def _create_inv_freq(dim, base, device):
|
def _create_inv_freq(dim, base, device):
|
||||||
|
@ -66,7 +66,7 @@ class PositionRotaryEmbedding(nn.Module):
|
||||||
head_size = query.shape[-1]
|
head_size = query.shape[-1]
|
||||||
|
|
||||||
# Inplace operation, updating query and key.
|
# Inplace operation, updating query and key.
|
||||||
pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
|
ops.rotary_embedding(query, key, head_size, cos, sin, True)
|
||||||
elif SYSTEM == "xpu":
|
elif SYSTEM == "xpu":
|
||||||
ipex.llm.functional.rotary_embedding(
|
ipex.llm.functional.rotary_embedding(
|
||||||
query, key, sin, cos, query.size(-1), True
|
query, key, sin, cos, query.size(-1), True
|
||||||
|
|
|
@ -46,6 +46,7 @@ class BLOOMSharded(CausalLM):
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
|
|
|
@ -69,7 +69,7 @@ class CohereRotary(PositionRotaryEmbedding):
|
||||||
|
|
||||||
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
|
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
|
||||||
elif SYSTEM == "rocm":
|
elif SYSTEM == "rocm":
|
||||||
from vllm import pos_encoding_ops
|
from vllm._C import ops
|
||||||
|
|
||||||
# NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
|
# NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
|
||||||
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
|
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
|
||||||
|
@ -77,7 +77,7 @@ class CohereRotary(PositionRotaryEmbedding):
|
||||||
head_size = query.shape[-1]
|
head_size = query.shape[-1]
|
||||||
|
|
||||||
# Inplace operation, updating query and key.
|
# Inplace operation, updating query and key.
|
||||||
pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, False)
|
ops.rotary_embedding(query, key, head_size, cos, sin, False)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
|
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
|
||||||
|
|
|
@ -22,10 +22,12 @@ from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.utils import paged_attention, flash_attn
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
|
@ -38,6 +40,12 @@ from text_generation_server.layers.layernorm import (
|
||||||
FastRMSNorm,
|
FastRMSNorm,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if SYSTEM == "rocm":
|
||||||
|
try:
|
||||||
|
from vllm import _custom_C
|
||||||
|
except Exception as e:
|
||||||
|
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
|
||||||
|
|
||||||
|
|
||||||
def load_attention(config, prefix, weights):
|
def load_attention(config, prefix, weights):
|
||||||
bias = config.attention_bias
|
bias = config.attention_bias
|
||||||
|
@ -182,14 +190,16 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||||
class LlamaMLP(nn.Module):
|
class LlamaMLP(nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
act = config.hidden_act
|
self.hidden_act = config.hidden_act
|
||||||
self.act = (
|
self.act = (
|
||||||
ACT2FN[act]
|
ACT2FN[self.hidden_act]
|
||||||
if "gelu" not in act
|
if "gelu" not in self.hidden_act
|
||||||
else lambda x: torch.nn.functional.gelu(
|
else lambda x: torch.nn.functional.gelu(
|
||||||
x,
|
x,
|
||||||
approximate=(
|
approximate=(
|
||||||
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
|
"tanh"
|
||||||
|
if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
|
||||||
|
else "none"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -221,6 +231,20 @@ class LlamaMLP(nn.Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
|
if (
|
||||||
|
SYSTEM == "rocm"
|
||||||
|
and self.hidden_act == "silu"
|
||||||
|
and hidden_states.shape[0] == 1
|
||||||
|
):
|
||||||
|
out = torch.empty(
|
||||||
|
hidden_states.shape[0],
|
||||||
|
self.intermediate_size,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
|
||||||
|
return self.down_proj(out)
|
||||||
|
else:
|
||||||
gate_up_states = self.gate_up_proj(hidden_states)
|
gate_up_states = self.gate_up_proj(hidden_states)
|
||||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||||
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
||||||
|
|
|
@ -26,6 +26,7 @@ from transformers.activations import ACT2FN
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.utils import paged_attention, flash_attn
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
|
@ -40,6 +41,13 @@ from text_generation_server.layers.layernorm import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if SYSTEM == "rocm":
|
||||||
|
try:
|
||||||
|
from vllm import _custom_C
|
||||||
|
except Exception as e:
|
||||||
|
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
|
||||||
|
|
||||||
|
|
||||||
class MistralConfig(PretrainedConfig):
|
class MistralConfig(PretrainedConfig):
|
||||||
model_type = "mistral"
|
model_type = "mistral"
|
||||||
|
|
||||||
|
@ -251,14 +259,16 @@ class MistralAttention(torch.nn.Module):
|
||||||
class MistralMLP(nn.Module):
|
class MistralMLP(nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
act = config.hidden_act
|
self.hidden_act = config.hidden_act
|
||||||
self.act = (
|
self.act = (
|
||||||
ACT2FN[act]
|
ACT2FN[self.hidden_act]
|
||||||
if "gelu" not in act
|
if "gelu" not in self.hidden_act
|
||||||
else lambda x: torch.nn.functional.gelu(
|
else lambda x: torch.nn.functional.gelu(
|
||||||
x,
|
x,
|
||||||
approximate=(
|
approximate=(
|
||||||
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
|
"tanh"
|
||||||
|
if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
|
||||||
|
else "none"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -281,6 +291,20 @@ class MistralMLP(nn.Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
|
if (
|
||||||
|
SYSTEM == "rocm"
|
||||||
|
and self.hidden_act == "silu"
|
||||||
|
and hidden_states.shape[0] == 1
|
||||||
|
):
|
||||||
|
out = torch.empty(
|
||||||
|
hidden_states.shape[0],
|
||||||
|
self.intermediate_size,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
|
||||||
|
return self.down_proj(out)
|
||||||
|
else:
|
||||||
gate_up_states = self.gate_up_proj(hidden_states)
|
gate_up_states = self.gate_up_proj(hidden_states)
|
||||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||||
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
||||||
|
|
|
@ -60,7 +60,7 @@ from text_generation_server.utils.import_utils import SYSTEM
|
||||||
if SYSTEM == "cuda":
|
if SYSTEM == "cuda":
|
||||||
import dropout_layer_norm
|
import dropout_layer_norm
|
||||||
elif SYSTEM == "rocm":
|
elif SYSTEM == "rocm":
|
||||||
from vllm import layernorm_ops
|
from vllm._C import ops
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Unsupported system {SYSTEM}")
|
raise RuntimeError(f"Unsupported system {SYSTEM}")
|
||||||
|
|
||||||
|
@ -420,7 +420,7 @@ class IdeficsRMSNorm(nn.Module):
|
||||||
hidden_states = hidden_states.reshape(-1, shape[-1])
|
hidden_states = hidden_states.reshape(-1, shape[-1])
|
||||||
|
|
||||||
out = torch.empty_like(hidden_states)
|
out = torch.empty_like(hidden_states)
|
||||||
layernorm_ops.rms_norm(
|
ops.rms_norm(
|
||||||
out,
|
out,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
self.weight.data,
|
self.weight.data,
|
||||||
|
|
|
@ -12,6 +12,9 @@ from dataclasses import dataclass
|
||||||
from opentelemetry import trace
|
from opentelemetry import trace
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
from typing import Optional, Tuple, List, Type, Dict
|
from typing import Optional, Tuple, List, Type, Dict
|
||||||
|
|
||||||
|
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
|
||||||
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.models import Model
|
from text_generation_server.models import Model
|
||||||
from text_generation_server.utils.tokens import batch_top_tokens
|
from text_generation_server.utils.tokens import batch_top_tokens
|
||||||
from text_generation_server.utils.speculate import get_speculate
|
from text_generation_server.utils.speculate import get_speculate
|
||||||
|
@ -28,6 +31,7 @@ from text_generation_server.models.cache_manager import (
|
||||||
)
|
)
|
||||||
from text_generation_server.pb import generate_pb2
|
from text_generation_server.pb import generate_pb2
|
||||||
from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
|
from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
|
||||||
|
import text_generation_server.models.globals as tgi_globals
|
||||||
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
|
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
|
||||||
from text_generation_server.utils.dist import MEMORY_FRACTION
|
from text_generation_server.utils.dist import MEMORY_FRACTION
|
||||||
|
|
||||||
|
@ -783,6 +787,9 @@ class FlashCausalLM(Model):
|
||||||
)
|
)
|
||||||
max_bt = batch.max_blocks
|
max_bt = batch.max_blocks
|
||||||
max_s = max_bt * get_cache_manager().block_size
|
max_s = max_bt * get_cache_manager().block_size
|
||||||
|
|
||||||
|
if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
|
||||||
|
torch.cuda.tunable.tuning_enable(False)
|
||||||
_, batch, _ = self.generate_token(batch)
|
_, batch, _ = self.generate_token(batch)
|
||||||
except torch.cuda.OutOfMemoryError as e:
|
except torch.cuda.OutOfMemoryError as e:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
@ -820,6 +827,49 @@ class FlashCausalLM(Model):
|
||||||
self.device,
|
self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if SYSTEM == "rocm":
|
||||||
|
if (
|
||||||
|
os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
|
||||||
|
or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
|
||||||
|
):
|
||||||
|
if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
|
||||||
|
torch.cuda.tunable.tuning_enable(True)
|
||||||
|
|
||||||
|
if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None:
|
||||||
|
tuning_sequences = [
|
||||||
|
int(val)
|
||||||
|
for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
tuning_sequences = CUDA_GRAPHS
|
||||||
|
|
||||||
|
tunableop_filepath = os.path.join(
|
||||||
|
HUGGINGFACE_HUB_CACHE,
|
||||||
|
f"tunableop_{tgi_globals.MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`."
|
||||||
|
)
|
||||||
|
|
||||||
|
if os.path.isfile(tunableop_filepath):
|
||||||
|
logger.info(
|
||||||
|
f"The file {tunableop_filepath} already exists and will be reused."
|
||||||
|
)
|
||||||
|
torch.cuda.tunable.read_file(tunableop_filepath)
|
||||||
|
|
||||||
|
os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)
|
||||||
|
|
||||||
|
for seqlen in tuning_sequences:
|
||||||
|
logger.info(f"Warming up TunableOp for seqlen={seqlen}")
|
||||||
|
self.tunableop_warmup(seqlen)
|
||||||
|
torch.cuda.tunable.write_file(tunableop_filepath)
|
||||||
|
torch.cuda.tunable.tuning_enable(False)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp."
|
||||||
|
)
|
||||||
|
|
||||||
if CUDA_GRAPHS:
|
if CUDA_GRAPHS:
|
||||||
try:
|
try:
|
||||||
logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
|
logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
|
||||||
|
@ -834,6 +884,27 @@ class FlashCausalLM(Model):
|
||||||
|
|
||||||
return int(num_blocks * BLOCK_SIZE)
|
return int(num_blocks * BLOCK_SIZE)
|
||||||
|
|
||||||
|
def tunableop_warmup(self, seqlen: int):
|
||||||
|
input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
|
||||||
|
position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
|
||||||
|
slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
|
||||||
|
kv_cache = get_cache_manager().kv_cache
|
||||||
|
|
||||||
|
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
|
||||||
|
self.model.forward(
|
||||||
|
input_ids=input_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
cu_seqlen_prefill=torch.tensor(
|
||||||
|
[0, seqlen], device=self.device, dtype=torch.int32
|
||||||
|
),
|
||||||
|
kv_cache=get_cache_manager().kv_cache,
|
||||||
|
block_tables=None,
|
||||||
|
input_lengths=None,
|
||||||
|
slots=slots,
|
||||||
|
max_s=seqlen,
|
||||||
|
lm_head_indices=None,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, batch: FlashCausalLMBatch
|
self, batch: FlashCausalLMBatch
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
@ -1113,8 +1184,6 @@ class FlashCausalLM(Model):
|
||||||
next_token_texts = []
|
next_token_texts = []
|
||||||
left = 0
|
left = 0
|
||||||
|
|
||||||
logger.debug(f"Accepted ids {n_accepted_ids}")
|
|
||||||
|
|
||||||
current_stopped = False
|
current_stopped = False
|
||||||
for j in range(index, index + n_accepted_ids):
|
for j in range(index, index + n_accepted_ids):
|
||||||
# Generated token
|
# Generated token
|
||||||
|
|
|
@ -15,11 +15,10 @@ from text_generation_server.utils import (
|
||||||
weight_files,
|
weight_files,
|
||||||
Weights,
|
Weights,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
|
|
||||||
|
|
||||||
class FlashGPT2(FlashCausalLM):
|
class FlashGPT2(FlashCausalLM):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@ -15,3 +15,12 @@ else:
|
||||||
cuda_graphs = None
|
cuda_graphs = None
|
||||||
|
|
||||||
CUDA_GRAPHS = cuda_graphs
|
CUDA_GRAPHS = cuda_graphs
|
||||||
|
|
||||||
|
# This is overridden at model loading.
|
||||||
|
global MODEL_ID
|
||||||
|
MODEL_ID = None
|
||||||
|
|
||||||
|
|
||||||
|
def set_model_id(model_id: str):
|
||||||
|
global MODEL_ID
|
||||||
|
MODEL_ID = model_id
|
||||||
|
|
|
@ -21,6 +21,7 @@ from text_generation_server.models.vlm_causal_lm import (
|
||||||
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
|
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
|
||||||
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
|
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
|
||||||
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
|
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
|
||||||
|
from text_generation_server.models.globals import set_model_id
|
||||||
|
|
||||||
|
|
||||||
class SignalHandler:
|
class SignalHandler:
|
||||||
|
@ -252,6 +253,7 @@ def serve(
|
||||||
while signal_handler.KEEP_PROCESSING:
|
while signal_handler.KEEP_PROCESSING:
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
set_model_id(model_id)
|
||||||
asyncio.run(
|
asyncio.run(
|
||||||
serve_inner(
|
serve_inner(
|
||||||
model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
|
model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
|
||||||
|
|
|
@ -2,14 +2,18 @@ import os
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
import math
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
from text_generation_server.utils.flash_attn_triton import triton_attention
|
||||||
|
|
||||||
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
||||||
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
||||||
HAS_FLASH_ATTN = True
|
HAS_FLASH_ATTN = False
|
||||||
HAS_FLASH_ATTN_V2_CUDA = False
|
HAS_FLASH_ATTN_V2_CUDA = False
|
||||||
HAS_FLASH_ATTN_V2_ROCM = False
|
HAS_FLASH_ATTN_V2_ROCM = False
|
||||||
|
ROCM_USE_FLASH_ATTN_V2_CK = False
|
||||||
|
ROCM_USE_FLASH_ATTN_V2_TRITON = False
|
||||||
|
|
||||||
if SYSTEM == "xpu":
|
if SYSTEM == "xpu":
|
||||||
import intel_extension_for_pytorch as ipex
|
import intel_extension_for_pytorch as ipex
|
||||||
|
@ -57,10 +61,21 @@ if SYSTEM in {"cuda", "rocm"}:
|
||||||
is_sm75 = major == 7 and minor == 5
|
is_sm75 = major == 7 and minor == 5
|
||||||
is_sm8x = major == 8 and minor >= 0
|
is_sm8x = major == 8 and minor >= 0
|
||||||
is_sm90 = major == 9 and minor == 0
|
is_sm90 = major == 9 and minor == 0
|
||||||
|
is_sm94 = major == 9 and minor == 4
|
||||||
|
|
||||||
|
if SYSTEM == "rocm":
|
||||||
|
if (
|
||||||
|
os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true"
|
||||||
|
or os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "0") == "1"
|
||||||
|
):
|
||||||
|
ROCM_USE_FLASH_ATTN_V2_TRITON = True
|
||||||
|
logger.info("ROCm: using Flash Attention 2 Triton implementation.")
|
||||||
|
else:
|
||||||
|
ROCM_USE_FLASH_ATTN_V2_CK = True
|
||||||
|
logger.info(
|
||||||
|
"ROCm: using Flash Attention 2 Composable Kernel implementation."
|
||||||
|
)
|
||||||
|
|
||||||
HAS_FLASH_ATTN = False
|
|
||||||
HAS_FLASH_ATTN_V2_CUDA = False
|
|
||||||
HAS_FLASH_ATTN_V2_ROCM = False
|
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
import flash_attn_2_cuda
|
import flash_attn_2_cuda
|
||||||
|
@ -71,11 +86,16 @@ if SYSTEM in {"cuda", "rocm"}:
|
||||||
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
||||||
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
|
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
|
||||||
)
|
)
|
||||||
if not (is_sm8x or is_sm90):
|
if SYSTEM == "cuda" and not (is_sm8x or is_sm90):
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
f"GPU with CUDA capability {major} {minor} is not supported for "
|
f"GPU with CUDA capability {major} {minor} is not supported for "
|
||||||
"Flash Attention V2"
|
"Flash Attention V2"
|
||||||
)
|
)
|
||||||
|
elif SYSTEM == "rocm" and not (is_sm8x or is_sm90 or is_sm94):
|
||||||
|
raise ImportError(
|
||||||
|
f"AMD GPU with compute capability {major} {minor} is not supported for "
|
||||||
|
"Flash Attention V2"
|
||||||
|
)
|
||||||
HAS_FLASH_ATTN_V2_CUDA = SYSTEM == "cuda"
|
HAS_FLASH_ATTN_V2_CUDA = SYSTEM == "cuda"
|
||||||
HAS_FLASH_ATTN_V2_ROCM = SYSTEM == "rocm"
|
HAS_FLASH_ATTN_V2_ROCM = SYSTEM == "rocm"
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
|
@ -142,7 +162,7 @@ if HAS_FLASH_ATTN_V2_CUDA:
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif HAS_FLASH_ATTN_V2_ROCM:
|
elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
|
||||||
|
|
||||||
def attention(
|
def attention(
|
||||||
q,
|
q,
|
||||||
|
@ -153,6 +173,7 @@ elif HAS_FLASH_ATTN_V2_ROCM:
|
||||||
max_s,
|
max_s,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
window_size_left=-1,
|
window_size_left=-1,
|
||||||
|
causal=True,
|
||||||
):
|
):
|
||||||
if window_size_left <= 0 and window_size_left != -1:
|
if window_size_left <= 0 and window_size_left != -1:
|
||||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||||
|
@ -174,11 +195,38 @@ elif HAS_FLASH_ATTN_V2_ROCM:
|
||||||
0.0,
|
0.0,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
False,
|
False,
|
||||||
True,
|
causal,
|
||||||
False,
|
False,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
|
||||||
|
|
||||||
|
def attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
softmax_scale,
|
||||||
|
window_size_left=-1,
|
||||||
|
causal=True,
|
||||||
|
):
|
||||||
|
output, _ = triton_attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
max_s,
|
||||||
|
causal,
|
||||||
|
softmax_scale,
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
elif HAS_FLASH_ATTN:
|
elif HAS_FLASH_ATTN:
|
||||||
|
|
||||||
def attention(
|
def attention(
|
||||||
|
|
|
@ -0,0 +1,816 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
"""
|
||||||
|
Fused Attention
|
||||||
|
===============
|
||||||
|
|
||||||
|
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
|
||||||
|
(https://tridao.me/publications/flash2/flash2.pdf)
|
||||||
|
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
|
||||||
|
|
||||||
|
Features supported:
|
||||||
|
|
||||||
|
1) Fwd with causal masking
|
||||||
|
2) Any sequence lengths without padding (currently fwd kernel only)
|
||||||
|
3) Support for different sequence lengths for q and k
|
||||||
|
4) Nested tensor API currently does not support dropout or bias.
|
||||||
|
|
||||||
|
Not currently supported:
|
||||||
|
|
||||||
|
1) Non power of two head dims
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
torch_dtype: tl.constexpr = torch.float16
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def cdiv_fn(x, y):
|
||||||
|
return (x + y - 1) // y
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def max_fn(x, y):
|
||||||
|
return tl.math.max(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||||
|
ms = tl.arange(0, m)
|
||||||
|
ns = tl.arange(0, n)
|
||||||
|
return philox_offset + ms[:, None] * stride + ns[None, :]
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||||
|
rng_offsets = dropout_offsets(
|
||||||
|
philox_seed, philox_offset, dropout_p, m, n, stride
|
||||||
|
).to(tl.uint32)
|
||||||
|
# TODO: use tl.randint for better performance
|
||||||
|
return tl.rand(philox_seed, rng_offsets)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||||
|
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
|
||||||
|
rng_keep = rng_output > dropout_p
|
||||||
|
return rng_keep
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def load_fn(block_ptr, first, second, pad):
|
||||||
|
if first and second:
|
||||||
|
tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
|
||||||
|
elif first:
|
||||||
|
tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad)
|
||||||
|
elif second:
|
||||||
|
tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad)
|
||||||
|
else:
|
||||||
|
tensor = tl.load(block_ptr)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _attn_fwd_inner(
|
||||||
|
acc,
|
||||||
|
l_i,
|
||||||
|
m_i,
|
||||||
|
q,
|
||||||
|
K_block_ptr,
|
||||||
|
V_block_ptr,
|
||||||
|
start_m,
|
||||||
|
actual_seqlen_k,
|
||||||
|
dropout_p,
|
||||||
|
philox_seed,
|
||||||
|
batch_philox_offset,
|
||||||
|
encoded_softmax_block_ptr,
|
||||||
|
block_min,
|
||||||
|
block_max,
|
||||||
|
offs_n_causal,
|
||||||
|
masked_blocks,
|
||||||
|
n_extra_tokens,
|
||||||
|
bias_ptr,
|
||||||
|
IS_CAUSAL: tl.constexpr,
|
||||||
|
BLOCK_M: tl.constexpr,
|
||||||
|
BLOCK_DMODEL: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr,
|
||||||
|
OFFS_M: tl.constexpr,
|
||||||
|
OFFS_N: tl.constexpr,
|
||||||
|
PRE_LOAD_V: tl.constexpr,
|
||||||
|
MASK_STEPS: tl.constexpr,
|
||||||
|
ENABLE_DROPOUT: tl.constexpr,
|
||||||
|
RETURN_ENCODED_SOFTMAX: tl.constexpr,
|
||||||
|
PADDED_HEAD: tl.constexpr,
|
||||||
|
):
|
||||||
|
# loop over k, v, and update accumulator
|
||||||
|
for start_n in range(block_min, block_max, BLOCK_N):
|
||||||
|
# For padded blocks, we will overrun the tensor size if
|
||||||
|
# we load all BLOCK_N. For others, the blocks are all within range.
|
||||||
|
k = load_fn(
|
||||||
|
K_block_ptr,
|
||||||
|
PADDED_HEAD,
|
||||||
|
MASK_STEPS and (n_extra_tokens != 0),
|
||||||
|
"zero",
|
||||||
|
)
|
||||||
|
if PRE_LOAD_V:
|
||||||
|
v = load_fn(
|
||||||
|
V_block_ptr,
|
||||||
|
MASK_STEPS and (n_extra_tokens != 0),
|
||||||
|
PADDED_HEAD,
|
||||||
|
"zero",
|
||||||
|
)
|
||||||
|
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||||
|
# We start from end of seqlen_k so only the first iteration would need
|
||||||
|
# to be checked for padding if it is not a multiple of block_n
|
||||||
|
# TODO: This can be optimized to only be true for the padded block.
|
||||||
|
if MASK_STEPS: # noqa: SIM102
|
||||||
|
# If this is the last block / iteration, we want to
|
||||||
|
# mask if the sequence length is not a multiple of block size
|
||||||
|
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps
|
||||||
|
# if not is_modulo_mn. last step might get wasted but that is okay.
|
||||||
|
# check if this masking works for that case.
|
||||||
|
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
|
||||||
|
boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)
|
||||||
|
size_n = start_n + OFFS_N[None, :]
|
||||||
|
mask = size_n < boundary_m[:, None]
|
||||||
|
qk = tl.where(mask, qk, float("-inf"))
|
||||||
|
if IS_CAUSAL:
|
||||||
|
causal_boundary = start_n + offs_n_causal
|
||||||
|
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
|
||||||
|
qk = tl.where(causal_mask, qk, float("-inf"))
|
||||||
|
# -- compute qk ----
|
||||||
|
qk += tl.dot(q, k)
|
||||||
|
if bias_ptr is not None:
|
||||||
|
bias = load_fn(
|
||||||
|
bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero"
|
||||||
|
)
|
||||||
|
# While bias is added after multiplying qk with sm_scale, our
|
||||||
|
# optimization to use 2^x instead of e^x results in an additional
|
||||||
|
# scale factor of log2(e) which we must also multiply the bias with.
|
||||||
|
qk += bias * 1.44269504089
|
||||||
|
m_ij = tl.maximum(m_i, tl.max(qk, 1))
|
||||||
|
qk = qk - m_ij[:, None]
|
||||||
|
p = tl.math.exp2(qk)
|
||||||
|
|
||||||
|
# CAVEAT: Must update l_ij before applying dropout
|
||||||
|
l_ij = tl.sum(p, 1)
|
||||||
|
if ENABLE_DROPOUT:
|
||||||
|
philox_offset = (
|
||||||
|
batch_philox_offset
|
||||||
|
+ start_m * BLOCK_M * actual_seqlen_k
|
||||||
|
+ start_n
|
||||||
|
- BLOCK_N
|
||||||
|
)
|
||||||
|
keep = dropout_mask(
|
||||||
|
philox_seed,
|
||||||
|
philox_offset,
|
||||||
|
dropout_p,
|
||||||
|
BLOCK_M,
|
||||||
|
BLOCK_N,
|
||||||
|
actual_seqlen_k,
|
||||||
|
)
|
||||||
|
if RETURN_ENCODED_SOFTMAX:
|
||||||
|
tl.store(
|
||||||
|
encoded_softmax_block_ptr,
|
||||||
|
tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty),
|
||||||
|
)
|
||||||
|
p = tl.where(keep, p, 0.0)
|
||||||
|
elif RETURN_ENCODED_SOFTMAX:
|
||||||
|
tl.store(
|
||||||
|
encoded_softmax_block_ptr,
|
||||||
|
p.to(encoded_softmax_block_ptr.type.element_ty),
|
||||||
|
)
|
||||||
|
# -- update output accumulator --
|
||||||
|
alpha = tl.math.exp2(m_i - m_ij)
|
||||||
|
acc = acc * alpha[:, None]
|
||||||
|
if not PRE_LOAD_V:
|
||||||
|
v = load_fn(
|
||||||
|
V_block_ptr,
|
||||||
|
MASK_STEPS and (n_extra_tokens != 0),
|
||||||
|
PADDED_HEAD,
|
||||||
|
"zero",
|
||||||
|
)
|
||||||
|
# -- update m_i and l_i
|
||||||
|
l_i = l_i * alpha + l_ij
|
||||||
|
# update m_i and l_i
|
||||||
|
m_i = m_ij
|
||||||
|
acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
|
||||||
|
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
|
||||||
|
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
|
||||||
|
if bias_ptr is not None:
|
||||||
|
bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
|
||||||
|
if RETURN_ENCODED_SOFTMAX:
|
||||||
|
encoded_softmax_block_ptr = tl.advance(
|
||||||
|
encoded_softmax_block_ptr, (0, BLOCK_N)
|
||||||
|
)
|
||||||
|
return acc, l_i, m_i
|
||||||
|
|
||||||
|
|
||||||
|
@triton.autotune(
|
||||||
|
configs=[
|
||||||
|
triton.Config(
|
||||||
|
{
|
||||||
|
"BLOCK_M": 256,
|
||||||
|
"BLOCK_N": 64,
|
||||||
|
"waves_per_eu": 2,
|
||||||
|
"PRE_LOAD_V": False,
|
||||||
|
},
|
||||||
|
num_stages=1,
|
||||||
|
num_warps=8,
|
||||||
|
),
|
||||||
|
triton.Config(
|
||||||
|
{
|
||||||
|
"BLOCK_M": 128,
|
||||||
|
"BLOCK_N": 128,
|
||||||
|
"waves_per_eu": 2,
|
||||||
|
"PRE_LOAD_V": False,
|
||||||
|
},
|
||||||
|
num_stages=1,
|
||||||
|
num_warps=4,
|
||||||
|
),
|
||||||
|
triton.Config(
|
||||||
|
{
|
||||||
|
"BLOCK_M": 256,
|
||||||
|
"BLOCK_N": 128,
|
||||||
|
"waves_per_eu": 2,
|
||||||
|
"PRE_LOAD_V": False,
|
||||||
|
},
|
||||||
|
num_stages=1,
|
||||||
|
num_warps=8,
|
||||||
|
),
|
||||||
|
triton.Config(
|
||||||
|
{
|
||||||
|
"BLOCK_M": 128,
|
||||||
|
"BLOCK_N": 64,
|
||||||
|
"waves_per_eu": 3,
|
||||||
|
"PRE_LOAD_V": True,
|
||||||
|
},
|
||||||
|
num_stages=1,
|
||||||
|
num_warps=4,
|
||||||
|
),
|
||||||
|
triton.Config(
|
||||||
|
{
|
||||||
|
"BLOCK_M": 128,
|
||||||
|
"BLOCK_N": 64,
|
||||||
|
"waves_per_eu": 3,
|
||||||
|
"PRE_LOAD_V": False,
|
||||||
|
},
|
||||||
|
num_stages=1,
|
||||||
|
num_warps=4,
|
||||||
|
),
|
||||||
|
triton.Config(
|
||||||
|
{
|
||||||
|
"BLOCK_M": 64,
|
||||||
|
"BLOCK_N": 64,
|
||||||
|
"waves_per_eu": 4,
|
||||||
|
"PRE_LOAD_V": False,
|
||||||
|
},
|
||||||
|
num_stages=1,
|
||||||
|
num_warps=8,
|
||||||
|
),
|
||||||
|
triton.Config(
|
||||||
|
{
|
||||||
|
"BLOCK_M": 32,
|
||||||
|
"BLOCK_N": 32,
|
||||||
|
"waves_per_eu": 4,
|
||||||
|
"PRE_LOAD_V": False,
|
||||||
|
},
|
||||||
|
num_stages=1,
|
||||||
|
num_warps=8,
|
||||||
|
),
|
||||||
|
# TODO: This config fails with head_size not pow2 with data mismatches.
|
||||||
|
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
|
||||||
|
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
|
||||||
|
triton.Config(
|
||||||
|
{
|
||||||
|
"BLOCK_M": 16,
|
||||||
|
"BLOCK_N": 16,
|
||||||
|
"waves_per_eu": 1,
|
||||||
|
"PRE_LOAD_V": False,
|
||||||
|
},
|
||||||
|
num_stages=1,
|
||||||
|
num_warps=4,
|
||||||
|
),
|
||||||
|
triton.Config(
|
||||||
|
{
|
||||||
|
"BLOCK_M": 128,
|
||||||
|
"BLOCK_N": 64,
|
||||||
|
"waves_per_eu": 1,
|
||||||
|
"PRE_LOAD_V": False,
|
||||||
|
},
|
||||||
|
num_stages=1,
|
||||||
|
num_warps=4,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
key=["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"],
|
||||||
|
)
|
||||||
|
@triton.jit
|
||||||
|
def attn_fwd(
|
||||||
|
Q,
|
||||||
|
K,
|
||||||
|
V,
|
||||||
|
bias,
|
||||||
|
sm_scale,
|
||||||
|
L,
|
||||||
|
Out,
|
||||||
|
stride_qz,
|
||||||
|
stride_qh,
|
||||||
|
stride_qm,
|
||||||
|
stride_qk,
|
||||||
|
stride_kz,
|
||||||
|
stride_kh,
|
||||||
|
stride_kn,
|
||||||
|
stride_kk,
|
||||||
|
stride_vz,
|
||||||
|
stride_vh,
|
||||||
|
stride_vk,
|
||||||
|
stride_vn,
|
||||||
|
stride_oz,
|
||||||
|
stride_oh,
|
||||||
|
stride_om,
|
||||||
|
stride_on,
|
||||||
|
stride_bz,
|
||||||
|
stride_bh,
|
||||||
|
stride_bm,
|
||||||
|
stride_bn,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
dropout_p,
|
||||||
|
philox_seed,
|
||||||
|
philox_offset_base,
|
||||||
|
encoded_softmax,
|
||||||
|
HQ: tl.constexpr,
|
||||||
|
HK: tl.constexpr,
|
||||||
|
ACTUAL_BLOCK_DMODEL: tl.constexpr,
|
||||||
|
MAX_SEQLENS_Q: tl.constexpr,
|
||||||
|
MAX_SEQLENS_K: tl.constexpr,
|
||||||
|
VARLEN: tl.constexpr,
|
||||||
|
IS_CAUSAL: tl.constexpr,
|
||||||
|
BLOCK_M: tl.constexpr,
|
||||||
|
BLOCK_DMODEL: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr,
|
||||||
|
PRE_LOAD_V: tl.constexpr,
|
||||||
|
BIAS_TYPE: tl.constexpr,
|
||||||
|
ENABLE_DROPOUT: tl.constexpr,
|
||||||
|
RETURN_ENCODED_SOFTMAX: tl.constexpr,
|
||||||
|
):
|
||||||
|
start_m = tl.program_id(0)
|
||||||
|
off_h_q = tl.program_id(1)
|
||||||
|
off_z = tl.program_id(2)
|
||||||
|
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||||
|
offs_n = tl.arange(0, BLOCK_N)
|
||||||
|
if VARLEN:
|
||||||
|
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
|
||||||
|
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
|
||||||
|
seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
|
||||||
|
# We have a one-size-fits-all grid in id(0). Some seqlens might be too
|
||||||
|
# small for all start_m so for those we return early.
|
||||||
|
if start_m * BLOCK_M > seqlen_q:
|
||||||
|
return
|
||||||
|
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
|
||||||
|
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
|
||||||
|
seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
|
||||||
|
else:
|
||||||
|
cu_seqlens_q_start = 0
|
||||||
|
cu_seqlens_k_start = 0
|
||||||
|
seqlen_q = MAX_SEQLENS_Q
|
||||||
|
seqlen_k = MAX_SEQLENS_K
|
||||||
|
|
||||||
|
# Now we compute whether we need to exit early due to causal masking.
|
||||||
|
# This is because for seqlen_q > seqlen_k, M rows of the attn scores
|
||||||
|
# are completely masked, resulting in 0s written to the output, and
|
||||||
|
# inf written to LSE. We don't need to do any GEMMs in this case.
|
||||||
|
# This block of code determines what N is, and if this WG is operating
|
||||||
|
# on those M rows.
|
||||||
|
n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
|
||||||
|
if IS_CAUSAL:
|
||||||
|
# If seqlen_q == seqlen_k, the attn scores are a square matrix.
|
||||||
|
# If seqlen_q != seqlen_k, attn scores are rectangular which means
|
||||||
|
# the causal mask boundary is bottom right aligned, and ends at either
|
||||||
|
# the top edge (seqlen_q < seqlen_k) or left edge.
|
||||||
|
# This captures the decrease in n_blocks if we have a rectangular attn
|
||||||
|
# matrix
|
||||||
|
n_blocks_seqlen = cdiv_fn(
|
||||||
|
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N
|
||||||
|
)
|
||||||
|
# This is what adjusts the block_max for the current WG, only
|
||||||
|
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
|
||||||
|
n_blocks = min(n_blocks, n_blocks_seqlen)
|
||||||
|
# If we have no blocks after adjusting for seqlen deltas, this WG is
|
||||||
|
# part of the blocks that are all 0. We exit early.
|
||||||
|
if n_blocks <= 0:
|
||||||
|
o_offset = (
|
||||||
|
off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
|
||||||
|
)
|
||||||
|
O_block_ptr = tl.make_block_ptr(
|
||||||
|
base=Out + o_offset,
|
||||||
|
shape=(seqlen_q, BLOCK_DMODEL),
|
||||||
|
strides=(stride_om, stride_on),
|
||||||
|
offsets=(start_m * BLOCK_M, 0),
|
||||||
|
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||||
|
order=(1, 0),
|
||||||
|
)
|
||||||
|
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
|
||||||
|
# We still need to write 0s to the result
|
||||||
|
# tl.store(O_block_ptr,
|
||||||
|
# acc.to(Out.type.element_ty), boundary_check=(0,1))
|
||||||
|
# l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
|
||||||
|
# + offs_m
|
||||||
|
# We store inf to LSE, not -inf because in the bwd pass,
|
||||||
|
# we subtract this
|
||||||
|
# from qk which makes it -inf, such that exp(qk - inf) = 0
|
||||||
|
# for these masked blocks.
|
||||||
|
# l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
|
||||||
|
# tl.store(l_ptrs, l)
|
||||||
|
# TODO: Should dropout and return encoded softmax be handled here?
|
||||||
|
return
|
||||||
|
|
||||||
|
# If MQA / GQA, set the K and V head offsets appropriately.
|
||||||
|
GROUP_SIZE: tl.constexpr = HQ // HK
|
||||||
|
if GROUP_SIZE != 1:
|
||||||
|
off_h_k = off_h_q // GROUP_SIZE
|
||||||
|
else:
|
||||||
|
off_h_k = off_h_q
|
||||||
|
|
||||||
|
n_extra_tokens = 0
|
||||||
|
if seqlen_k < BLOCK_N:
|
||||||
|
n_extra_tokens = BLOCK_N - seqlen_k
|
||||||
|
elif seqlen_k % BLOCK_N:
|
||||||
|
n_extra_tokens = seqlen_k % BLOCK_N
|
||||||
|
PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
|
||||||
|
|
||||||
|
# Compute pointers for all the tensors used in this kernel.
|
||||||
|
q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm
|
||||||
|
Q_block_ptr = tl.make_block_ptr(
|
||||||
|
base=Q + q_offset,
|
||||||
|
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
|
||||||
|
strides=(stride_qm, stride_qk),
|
||||||
|
offsets=(start_m * BLOCK_M, 0),
|
||||||
|
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||||
|
order=(1, 0),
|
||||||
|
)
|
||||||
|
k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn
|
||||||
|
K_block_ptr = tl.make_block_ptr(
|
||||||
|
base=K + k_offset,
|
||||||
|
shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
|
||||||
|
strides=(stride_kk, stride_kn),
|
||||||
|
offsets=(0, 0),
|
||||||
|
block_shape=(BLOCK_DMODEL, BLOCK_N),
|
||||||
|
order=(0, 1),
|
||||||
|
)
|
||||||
|
v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk
|
||||||
|
V_block_ptr = tl.make_block_ptr(
|
||||||
|
base=V + v_offset,
|
||||||
|
shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
|
||||||
|
strides=(stride_vk, stride_vn),
|
||||||
|
offsets=(0, 0),
|
||||||
|
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||||
|
order=(1, 0),
|
||||||
|
)
|
||||||
|
if BIAS_TYPE != 0:
|
||||||
|
bias_ptr = tl.make_block_ptr(
|
||||||
|
base=bias + off_h_q * stride_bh,
|
||||||
|
shape=(seqlen_q, seqlen_k),
|
||||||
|
strides=(stride_bm, stride_bn),
|
||||||
|
offsets=(start_m * BLOCK_M, 0),
|
||||||
|
block_shape=(BLOCK_M, BLOCK_N),
|
||||||
|
order=(1, 0),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
bias_ptr = None
|
||||||
|
if ENABLE_DROPOUT:
|
||||||
|
batch_philox_offset = (
|
||||||
|
philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
batch_philox_offset = 0
|
||||||
|
# We can ask to return the dropout mask without actually doing any dropout.
|
||||||
|
# In this case, we return an invalid pointer so indicate the mask is not i
|
||||||
|
# valid.
|
||||||
|
# TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
|
||||||
|
if RETURN_ENCODED_SOFTMAX:
|
||||||
|
encoded_softmax_block_ptr = tl.make_block_ptr(
|
||||||
|
base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
|
||||||
|
shape=(seqlen_q, seqlen_k),
|
||||||
|
strides=(seqlen_k, 1),
|
||||||
|
offsets=(start_m * BLOCK_M, 0),
|
||||||
|
block_shape=(BLOCK_M, BLOCK_N),
|
||||||
|
order=(1, 0),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
encoded_softmax_block_ptr = 0
|
||||||
|
# initialize pointer to m and l
|
||||||
|
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||||
|
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
|
||||||
|
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||||
|
# scale sm_scale by log_2(e) and use 2^x in the loop as we do not
|
||||||
|
# have native e^x support in HW.
|
||||||
|
qk_scale = sm_scale * 1.44269504089
|
||||||
|
# Q is loaded once at the beginning and shared by all N blocks.
|
||||||
|
q = load_fn(Q_block_ptr, True, PADDED_HEAD, "zero")
|
||||||
|
q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
|
||||||
|
|
||||||
|
# Here we compute how many full and masked blocks we have.
|
||||||
|
padded_block_k = n_extra_tokens != 0
|
||||||
|
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
|
||||||
|
if IS_CAUSAL:
|
||||||
|
# There are always at least BLOCK_M // BLOCK_N masked blocks.
|
||||||
|
# Additionally there might be one more due to dissimilar seqlens.
|
||||||
|
masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
|
||||||
|
else:
|
||||||
|
# Padding on Q does not need to be masked in the FA loop.
|
||||||
|
masked_blocks = padded_block_k
|
||||||
|
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional
|
||||||
|
# block. In this case we might exceed n_blocks so pick the min.
|
||||||
|
masked_blocks = min(masked_blocks, n_blocks)
|
||||||
|
n_full_blocks = n_blocks - masked_blocks
|
||||||
|
block_min = 0
|
||||||
|
block_max = n_blocks * BLOCK_N
|
||||||
|
# Compute for full blocks. Here we set causal to false regardless of its
|
||||||
|
# value because there is no masking. Similarly we do not need padding.
|
||||||
|
if n_full_blocks > 0:
|
||||||
|
block_max = (n_blocks - masked_blocks) * BLOCK_N
|
||||||
|
acc, l_i, m_i = _attn_fwd_inner(
|
||||||
|
acc,
|
||||||
|
l_i,
|
||||||
|
m_i,
|
||||||
|
q,
|
||||||
|
K_block_ptr,
|
||||||
|
V_block_ptr,
|
||||||
|
start_m,
|
||||||
|
seqlen_k,
|
||||||
|
dropout_p,
|
||||||
|
philox_seed,
|
||||||
|
batch_philox_offset,
|
||||||
|
encoded_softmax_block_ptr,
|
||||||
|
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
|
||||||
|
block_min,
|
||||||
|
block_max,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
bias_ptr,
|
||||||
|
# IS_CAUSAL, ....
|
||||||
|
False,
|
||||||
|
BLOCK_M,
|
||||||
|
BLOCK_DMODEL,
|
||||||
|
BLOCK_N,
|
||||||
|
offs_m,
|
||||||
|
offs_n,
|
||||||
|
# _, MASK_STEPS, ...
|
||||||
|
PRE_LOAD_V,
|
||||||
|
False,
|
||||||
|
ENABLE_DROPOUT,
|
||||||
|
RETURN_ENCODED_SOFTMAX,
|
||||||
|
PADDED_HEAD,
|
||||||
|
)
|
||||||
|
block_min = block_max
|
||||||
|
block_max = n_blocks * BLOCK_N
|
||||||
|
|
||||||
|
tl.debug_barrier()
|
||||||
|
# Remaining blocks, if any, are full / not masked.
|
||||||
|
if masked_blocks > 0:
|
||||||
|
offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
|
||||||
|
K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
|
||||||
|
V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
|
||||||
|
if bias_ptr is not None:
|
||||||
|
bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
|
||||||
|
if RETURN_ENCODED_SOFTMAX:
|
||||||
|
encoded_softmax_block_ptr = tl.advance(
|
||||||
|
encoded_softmax_block_ptr, (0, n_full_blocks)
|
||||||
|
)
|
||||||
|
acc, l_i, m_i = _attn_fwd_inner(
|
||||||
|
acc,
|
||||||
|
l_i,
|
||||||
|
m_i,
|
||||||
|
q,
|
||||||
|
K_block_ptr,
|
||||||
|
V_block_ptr,
|
||||||
|
start_m,
|
||||||
|
seqlen_k,
|
||||||
|
dropout_p,
|
||||||
|
philox_seed,
|
||||||
|
batch_philox_offset,
|
||||||
|
encoded_softmax_block_ptr,
|
||||||
|
block_min,
|
||||||
|
block_max,
|
||||||
|
offs_n_causal,
|
||||||
|
masked_blocks,
|
||||||
|
n_extra_tokens,
|
||||||
|
bias_ptr,
|
||||||
|
IS_CAUSAL,
|
||||||
|
BLOCK_M,
|
||||||
|
BLOCK_DMODEL,
|
||||||
|
BLOCK_N,
|
||||||
|
offs_m,
|
||||||
|
offs_n,
|
||||||
|
# _, MASK_STEPS, ...
|
||||||
|
PRE_LOAD_V,
|
||||||
|
True,
|
||||||
|
ENABLE_DROPOUT,
|
||||||
|
RETURN_ENCODED_SOFTMAX,
|
||||||
|
PADDED_HEAD,
|
||||||
|
)
|
||||||
|
# epilogue
|
||||||
|
acc = acc / l_i[:, None]
|
||||||
|
if ENABLE_DROPOUT:
|
||||||
|
acc = acc / (1 - dropout_p)
|
||||||
|
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
|
||||||
|
# then we have one block with a row of all NaNs which come from computing
|
||||||
|
# softmax over a row of all -infs (-inf - inf = NaN). We check for that here
|
||||||
|
# and store 0s where there are NaNs as these rows should've been zeroed out.
|
||||||
|
end_m_idx = (start_m + 1) * BLOCK_M
|
||||||
|
start_m_idx = start_m * BLOCK_M
|
||||||
|
causal_start_idx = seqlen_q - seqlen_k
|
||||||
|
acc = acc.to(Out.type.element_ty)
|
||||||
|
if IS_CAUSAL: # noqa: SIM102
|
||||||
|
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
|
||||||
|
out_mask_boundary = tl.full(
|
||||||
|
(BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32
|
||||||
|
)
|
||||||
|
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
|
||||||
|
out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :]
|
||||||
|
z = 0.0
|
||||||
|
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
|
||||||
|
# write back LSE
|
||||||
|
# l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
|
||||||
|
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last
|
||||||
|
# few rows. This is only true for the last M block. For others,
|
||||||
|
# overflow_size will be -ve
|
||||||
|
# overflow_size = end_m_idx - seqlen_q
|
||||||
|
# if overflow_size > 0:
|
||||||
|
# boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
|
||||||
|
# # This is a > check because mask being 0 blocks the store.
|
||||||
|
# l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
|
||||||
|
# tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
|
||||||
|
# else:
|
||||||
|
# tl.store(l_ptrs, m_i + tl.math.log2(l_i))
|
||||||
|
|
||||||
|
# write back O
|
||||||
|
o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
|
||||||
|
O_block_ptr = tl.make_block_ptr(
|
||||||
|
base=Out + o_offset,
|
||||||
|
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
|
||||||
|
strides=(stride_om, stride_on),
|
||||||
|
offsets=(start_m * BLOCK_M, 0),
|
||||||
|
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||||
|
order=(1, 0),
|
||||||
|
)
|
||||||
|
# Need boundary check on this to make sure the padding from the
|
||||||
|
# Q and KV tensors in both dims are not part of what we store back.
|
||||||
|
# TODO: Do the boundary check optionally.
|
||||||
|
tl.store(O_block_ptr, acc, boundary_check=(0, 1))
|
||||||
|
|
||||||
|
|
||||||
|
def check_args(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
o,
|
||||||
|
varlen=True,
|
||||||
|
max_seqlens=None,
|
||||||
|
cu_seqlens_q=None,
|
||||||
|
cu_seqlens_k=None,
|
||||||
|
):
|
||||||
|
assert q.dim() == k.dim() and q.dim() == v.dim()
|
||||||
|
if varlen:
|
||||||
|
assert q.dim() == 3
|
||||||
|
total_q, nheads_q, head_size = q.shape
|
||||||
|
total_k, nheads_k, _ = k.shape
|
||||||
|
assert cu_seqlens_q is not None
|
||||||
|
assert cu_seqlens_k is not None
|
||||||
|
assert len(cu_seqlens_q) == len(cu_seqlens_k)
|
||||||
|
else:
|
||||||
|
assert q.dim() == 4
|
||||||
|
batch, nheads_q, seqlen_q, head_size = q.shape
|
||||||
|
_, nheads_k, seqlen_k, _ = k.shape
|
||||||
|
assert max_seqlens > 0
|
||||||
|
assert k.shape == v.shape
|
||||||
|
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
|
||||||
|
# TODO: Change assert if we support qkl f8 and v f16
|
||||||
|
assert q.dtype == k.dtype and q.dtype == v.dtype
|
||||||
|
# TODO: Fix assert to check head size <=256 once supported
|
||||||
|
assert head_size <= 128
|
||||||
|
assert o.shape == q.shape
|
||||||
|
assert (nheads_q % nheads_k) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class _attention(torch.autograd.Function):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(
|
||||||
|
ctx,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
o,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlens_q,
|
||||||
|
max_seqlens_k,
|
||||||
|
causal=False,
|
||||||
|
sm_scale=1.0,
|
||||||
|
bias=None,
|
||||||
|
):
|
||||||
|
if o is None:
|
||||||
|
o = torch.empty_like(q, dtype=v.dtype)
|
||||||
|
|
||||||
|
check_args(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
o,
|
||||||
|
varlen=True,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
)
|
||||||
|
if True: # varlen
|
||||||
|
total_q, nheads_q, head_size = q.shape
|
||||||
|
total_k, nheads_k, _ = k.shape
|
||||||
|
batch = len(cu_seqlens_q) - 1
|
||||||
|
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
|
||||||
|
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
|
||||||
|
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
|
||||||
|
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
|
||||||
|
else:
|
||||||
|
batch, seqlen_q, nheads_q, head_size = q.shape
|
||||||
|
_, seqlen_k, nheads_k, _ = k.shape
|
||||||
|
q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
|
||||||
|
k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
|
||||||
|
v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
|
||||||
|
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
|
||||||
|
|
||||||
|
# Get closest power of 2 over or equal to 32.
|
||||||
|
padded_d_model = 1 << (head_size - 1).bit_length()
|
||||||
|
padded_d_model = max(padded_d_model, 16)
|
||||||
|
|
||||||
|
grid = lambda META: (
|
||||||
|
triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
|
||||||
|
nheads_q,
|
||||||
|
batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
encoded_softmax = None
|
||||||
|
|
||||||
|
# Seed the RNG so we get reproducible results for testing.
|
||||||
|
philox_seed = 0x1BF52
|
||||||
|
philox_offset = 0x1D4B42
|
||||||
|
|
||||||
|
if bias is not None:
|
||||||
|
bias_strides = (
|
||||||
|
bias.stride(0),
|
||||||
|
bias.stride(1),
|
||||||
|
bias.stride(2),
|
||||||
|
bias.stride(3),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
bias_strides = (0, 0, 0, 0)
|
||||||
|
|
||||||
|
attn_fwd[grid](
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
bias,
|
||||||
|
sm_scale,
|
||||||
|
None,
|
||||||
|
o,
|
||||||
|
*q_strides,
|
||||||
|
*k_strides,
|
||||||
|
*v_strides,
|
||||||
|
*o_strides,
|
||||||
|
*bias_strides,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
dropout_p=0.0,
|
||||||
|
philox_seed=philox_seed,
|
||||||
|
philox_offset_base=philox_offset,
|
||||||
|
encoded_softmax=encoded_softmax,
|
||||||
|
HQ=nheads_q,
|
||||||
|
HK=nheads_k,
|
||||||
|
ACTUAL_BLOCK_DMODEL=head_size,
|
||||||
|
MAX_SEQLENS_Q=max_seqlens_q,
|
||||||
|
MAX_SEQLENS_K=max_seqlens_k,
|
||||||
|
IS_CAUSAL=causal,
|
||||||
|
VARLEN=True,
|
||||||
|
BLOCK_DMODEL=padded_d_model,
|
||||||
|
BIAS_TYPE=0 if bias is None else 1,
|
||||||
|
ENABLE_DROPOUT=False,
|
||||||
|
RETURN_ENCODED_SOFTMAX=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx.grid = grid
|
||||||
|
ctx.sm_scale = sm_scale
|
||||||
|
ctx.BLOCK_DMODEL = head_size
|
||||||
|
ctx.causal = causal
|
||||||
|
ctx.dropout_p = 0.0
|
||||||
|
ctx.philox_seed = philox_seed
|
||||||
|
ctx.philox_offset = philox_offset
|
||||||
|
ctx.encoded_softmax = encoded_softmax
|
||||||
|
ctx.return_encoded_softmax = False
|
||||||
|
return o, encoded_softmax
|
||||||
|
|
||||||
|
|
||||||
|
triton_attention = _attention.apply
|
|
@ -5,6 +5,14 @@ _PARTITION_SIZE = 512
|
||||||
|
|
||||||
if SYSTEM == "xpu":
|
if SYSTEM == "xpu":
|
||||||
import intel_extension_for_pytorch as ipex
|
import intel_extension_for_pytorch as ipex
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
from vllm._C import cache_ops
|
||||||
|
from vllm._C import ops
|
||||||
|
except Exception as e:
|
||||||
|
raise ImportError(
|
||||||
|
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def reshape_and_cache(
|
def reshape_and_cache(
|
||||||
|
@ -14,22 +22,14 @@ def reshape_and_cache(
|
||||||
value_cache: torch.Tensor,
|
value_cache: torch.Tensor,
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
):
|
):
|
||||||
if SYSTEM == "cuda":
|
if SYSTEM == "xpu":
|
||||||
from vllm._C import cache_ops
|
|
||||||
|
|
||||||
cache_ops.reshape_and_cache(
|
|
||||||
key, value, key_cache, value_cache, slots, "auto", 1.0
|
|
||||||
)
|
|
||||||
elif SYSTEM == "rocm":
|
|
||||||
from vllm import cache_ops
|
|
||||||
|
|
||||||
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
|
|
||||||
elif SYSTEM == "xpu":
|
|
||||||
ipex.llm.modules.PagedAttention.reshape_and_cache(
|
ipex.llm.modules.PagedAttention.reshape_and_cache(
|
||||||
key, value, key_cache, value_cache, slots
|
key, value, key_cache, value_cache, slots
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("vllm is not supported on your system")
|
cache_ops.reshape_and_cache(
|
||||||
|
key, value, key_cache, value_cache, slots, "auto", 1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def attention(
|
def attention(
|
||||||
|
@ -87,9 +87,6 @@ def attention(
|
||||||
# to parallelize.
|
# to parallelize.
|
||||||
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
|
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
|
||||||
if use_v1:
|
if use_v1:
|
||||||
if SYSTEM == "cuda":
|
|
||||||
from vllm._C import ops
|
|
||||||
|
|
||||||
ops.paged_attention_v1(
|
ops.paged_attention_v1(
|
||||||
out,
|
out,
|
||||||
query,
|
query,
|
||||||
|
@ -105,25 +102,6 @@ def attention(
|
||||||
"auto",
|
"auto",
|
||||||
1.0,
|
1.0,
|
||||||
)
|
)
|
||||||
elif SYSTEM == "rocm":
|
|
||||||
from vllm import attention_ops
|
|
||||||
|
|
||||||
attention_ops.paged_attention_v1(
|
|
||||||
out,
|
|
||||||
query,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
kv_head_mapping,
|
|
||||||
softmax_scale,
|
|
||||||
block_tables,
|
|
||||||
input_lengths,
|
|
||||||
block_size,
|
|
||||||
max_s,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError("vllm is not supported on your system")
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Run PagedAttention V2.
|
# Run PagedAttention V2.
|
||||||
assert _PARTITION_SIZE % block_size == 0
|
assert _PARTITION_SIZE % block_size == 0
|
||||||
|
@ -139,9 +117,6 @@ def attention(
|
||||||
)
|
)
|
||||||
max_logits = torch.empty_like(exp_sums)
|
max_logits = torch.empty_like(exp_sums)
|
||||||
|
|
||||||
if SYSTEM == "cuda":
|
|
||||||
from vllm._C import ops
|
|
||||||
|
|
||||||
ops.paged_attention_v2(
|
ops.paged_attention_v2(
|
||||||
out,
|
out,
|
||||||
exp_sums,
|
exp_sums,
|
||||||
|
@ -160,24 +135,3 @@ def attention(
|
||||||
"auto",
|
"auto",
|
||||||
1.0,
|
1.0,
|
||||||
)
|
)
|
||||||
elif SYSTEM == "rocm":
|
|
||||||
from vllm import attention_ops
|
|
||||||
|
|
||||||
attention_ops.paged_attention_v2(
|
|
||||||
out,
|
|
||||||
exp_sums,
|
|
||||||
max_logits,
|
|
||||||
tmp_output,
|
|
||||||
query,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
kv_head_mapping,
|
|
||||||
softmax_scale,
|
|
||||||
block_tables,
|
|
||||||
input_lengths,
|
|
||||||
block_size,
|
|
||||||
max_s,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError("vllm is not supported on your system")
|
|
||||||
|
|
Loading…
Reference in New Issue