# Using TGI with AMD GPUs
TGI is supported and tested on [AMD Instinct MI210](https://www.amd.com/en/products/accelerators/instinct/mi200/mi210.html), [MI250](https://www.amd.com/en/products/accelerators/instinct/mi200/mi250.html) and [MI300](https://www.amd.com/en/products/accelerators/instinct/mi300.html) GPUs. The support may be extended in the future. The recommended usage is through Docker. Make sure to check the [AMD documentation](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/docker.html) on how to use Docker with AMD GPUs.
On a server powered by AMD GPUs, TGI can be launched with the following command:
```bash
model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--device=/dev/kfd --device=/dev/dri --group-add video \
--ipc=host --shm-size 256g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.4.1-rocm \
--model-id $model
```
The launched TGI server can then be queried from clients; make sure to check out the [Consuming TGI](./basic_tutorials/consuming_tgi) guide.
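As a quick smoke test, the server can be queried with `curl`. The sketch below is a minimal client, assuming the `/generate` endpoint with its `inputs` and `parameters.max_new_tokens` request fields. `TGI_URL`, `generate_payload` and `generate` are illustrative names, and the default port 80 is an assumption about the image's configuration: check the launcher logs, or pass `--port` explicitly, if your server listens elsewhere.

```bash
#!/usr/bin/env bash
# Minimal client sketch for a TGI server started with the command above.
# Assumption: with --net host the server is reachable on the host's port 80;
# override TGI_URL to match your setup.
TGI_URL="${TGI_URL:-http://localhost:80}"

# Build the JSON body for the /generate endpoint.
# The prompt is inserted verbatim, so it must not contain '"' or '\'.
generate_payload() {
  local prompt="$1"
  local max_new_tokens="${2:-20}"
  printf '{"inputs":"%s","parameters":{"max_new_tokens":%d}}' \
    "$prompt" "$max_new_tokens"
}

# Send one non-streaming generation request and print the JSON response.
generate() {
  curl -sS -X POST "${TGI_URL}/generate" \
    -H 'Content-Type: application/json' \
    -d "$(generate_payload "$@")"
}
```

For example, `generate "What is Deep Learning?" 32` returns a JSON object whose `generated_text` field holds the completion.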
## TunableOp
TGI's docker image for AMD GPUs integrates [PyTorch's TunableOp](https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable), which performs an additional warmup to select the best-performing matrix multiplication (GEMM) kernel from rocBLAS or hipBLASLt.
Experimentally, on MI300X, we noticed a 6-8% latency improvement when using TunableOp on top of ROCm 6.1 and PyTorch 2.3.
TunableOp is enabled by default, and its warmup may take 1-2 minutes. To disable TunableOp, pass `--env PYTORCH_TUNABLEOP_ENABLED="0"` when launching TGI's docker container.
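Tuning results are saved as CSV files under `/data` (for example `/data/tunableop_meta-llama-Llama-2-70b-chat-hf_tp1_rank0.csv`), so that mounting the volume avoids rerunning the tuning at every `docker run`. This naming and layout come from the change that introduced TunableOp support and may differ in newer releases. Assuming that layout (rows starting with `Validator` record library versions; the other rows hold the op, the GEMM shape, the selected kernel and a timing), the hypothetical helper below counts which backend TunableOp picked across the tuned shapes.

```bash
#!/usr/bin/env bash
# Hypothetical helper: summarize which GEMM backend TunableOp selected.
# Assumed row layout: OP,SHAPE,KERNEL,TIME; "Validator,..." rows are metadata.
summarize_tunableop() {
  local csv="$1"
  awk -F',' '
    $1 == "Validator" { next }            # skip library/version metadata rows
    NF >= 4 {
      kernel = $3
      if (kernel ~ /^Gemm_Rocblas_/)        backend = "rocblas"
      else if (kernel ~ /^Gemm_Hipblaslt_/) backend = "hipblaslt"
      else                                  backend = "default"
      count[backend]++
    }
    END { for (b in count) print b, count[b] }
  ' "$csv" | sort
}
```

Running `summarize_tunableop /data/tunableop_<model>_tp1_rank0.csv` prints one `backend count` line per backend, which gives a rough idea of how often the tuned choice differs from the default kernel.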
## Flash attention implementation
Two implementations of Flash Attention are available for ROCm: the first is [ROCm/flash-attention](https://github.com/ROCm/flash-attention), based on a [Composable Kernel](https://github.com/ROCm/composable_kernel) (CK) implementation, and the second is a [Triton implementation](https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/layers/attention/flash_attn_triton.py).
By default, the Composable Kernel implementation is used. The Triton implementation has slightly lower latency on MI250 and MI300, but it requires a warmup that can be prohibitive, since the warmup must be repeated for each new prompt length. If needed, the Triton implementation can be enabled by passing `--env ROCM_USE_FLASH_ATTN_V2_TRITON="1"` when launching TGI's docker container.
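A sketch of the launch command from the top of this page with the Triton implementation switched on. It reuses that command's flags and image tag and only adds `--env ROCM_USE_FLASH_ATTN_V2_TRITON=1`; the `launch_tgi_triton` function and its `DRY_RUN` switch are hypothetical conveniences for inspecting the command before running it.

```bash
#!/usr/bin/env bash
# Launch TGI on ROCm with the Triton Flash Attention implementation.
# Set DRY_RUN=1 to print the command instead of executing it.
launch_tgi_triton() {
  local model="$1"
  local volume="${2:-$PWD/data}"   # shared volume, as in the main launch command
  local cmd=(
    docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined
    --device=/dev/kfd --device=/dev/dri --group-add video
    --ipc=host --shm-size 256g --net host -v "$volume:/data"
    --env ROCM_USE_FLASH_ATTN_V2_TRITON=1
    ghcr.io/huggingface/text-generation-inference:2.4.1-rocm
    --model-id "$model"
  )
  if [[ "${DRY_RUN:-0}" == "1" ]]; then
    printf '%s\n' "${cmd[*]}"
  else
    "${cmd[@]}"
  fi
}
```

`DRY_RUN=1 launch_tgi_triton teknium/OpenHermes-2.5-Mistral-7B` prints the full command; drop `DRY_RUN=1` to start the container. Expect the first requests at each new prompt length to be slower while the Triton kernels warm up.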
## Custom PagedAttention
For better performance on ROCm, a custom Paged Attention kernel is available and is enabled by default. To disable it and fall back to the PagedAttention v2 kernel, set the environment variable `ROCM_USE_CUSTOM_PAGED_ATTN=0`.
The custom kernel supports bf16 and fp16 data types, block size of 16, head size of 128, a maximum context length of 16k, and GQA ratios between 1 and 16. For other configurations, we use the PagedAttention v2 kernel.
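The constraints above can be restated as a small eligibility check. This is a hypothetical helper that mirrors the documented rules, not TGI's actual dispatch code; it assumes that "16k" means 16384 tokens, that the GQA ratio is the number of query heads divided by the number of key/value heads, and that the two head counts divide evenly.

```bash
#!/usr/bin/env bash
# Hypothetical check: would the custom PagedAttention kernel apply?
# Usage: custom_paged_attn_eligible DTYPE BLOCK_SIZE HEAD_SIZE MAX_CONTEXT NUM_HEADS NUM_KV_HEADS
# Returns 0 if eligible, 1 if the PagedAttention v2 fallback would be used.
custom_paged_attn_eligible() {
  local dtype="$1" block_size="$2" head_size="$3" max_context="$4"
  local num_heads="$5" num_kv_heads="$6"

  # The kernel can be switched off explicitly.
  [[ "${ROCM_USE_CUSTOM_PAGED_ATTN:-1}" == "0" ]] && return 1

  # Only bf16 and fp16 are supported.
  case "$dtype" in
    bf16|bfloat16|fp16|float16) ;;
    *) return 1 ;;
  esac

  (( block_size == 16 )) || return 1
  (( head_size == 128 )) || return 1
  (( max_context <= 16384 )) || return 1   # "16k" assumed to mean 16384 tokens

  # GQA ratio = query heads per key/value head (assumed to divide evenly).
  (( num_kv_heads > 0 && num_heads % num_kv_heads == 0 )) || return 1
  local gqa_ratio=$(( num_heads / num_kv_heads ))
  (( gqa_ratio >= 1 && gqa_ratio <= 16 ))
}
```

For example, a bf16 model with 32 query heads, 8 key/value heads (GQA ratio 4), head size 128 and block size 16 passes the check for contexts up to 16384 tokens, while the same model served in fp32 would fall back to PagedAttention v2.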
## Unsupported features
The following features are currently not supported in the ROCm version of TGI, and support may be extended in the future:
* Loading [AWQ](https://huggingface.co/docs/transformers/quantization#awq) checkpoints.
* Kernel for sliding window attention (Mistral).