From 88e2997b9c47b725dcb48d68b5c8ea448b45d591 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Fri, 6 Sep 2024 12:23:18 +0000 Subject: [PATCH] style --- Dockerfile_amd | 2 ++ docs/source/installation_amd.md | 6 ++++++ server/text_generation_server/layers/attention/rocm.py | 2 +- 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/Dockerfile_amd b/Dockerfile_amd index fd612af5..90a1ddf5 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -152,6 +152,7 @@ ENV HIP_FORCE_DEV_KERNARG=1 # On MI250 and MI300, performances for flash with Triton FA are slightly better than CK. # However, Triton requires a tunning for each prompt length, which is prohibitive. ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0 +ENV ROCM_USE_CUSTOM_PAGED_ATTN=1 FROM base AS kernel-builder @@ -246,3 +247,4 @@ COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh RUN chmod +x /tgi-entrypoint.sh ENTRYPOINT ["/tgi-entrypoint.sh"] +CMD ["--json-output"] diff --git a/docs/source/installation_amd.md b/docs/source/installation_amd.md index 931a9e3a..8bf60830 100644 --- a/docs/source/installation_amd.md +++ b/docs/source/installation_amd.md @@ -31,6 +31,12 @@ Two implementations of Flash Attention are available for ROCm, the first is [ROC By default, the Composable Kernel implementation is used. However, the Triton implementation has slightly lower latency on MI250 and MI300, but requires a warmup which can be prohibitive as it needs to be done again for each new prompt length. If needed, FA Triton impelmentation can be enabled with `--env ROCM_USE_FLASH_ATTN_V2_TRITON="0"` when launching TGI's docker container. +## Custom PagedAttention + +For better performance on ROCm, a custom Paged Attention kernel is available and is enabled by default. To disable it and fall back to the PagedAttention v2 kernel, set the environment variable `ROCM_USE_CUSTOM_PAGED_ATTN=0`. + +The custom kernel supports bf16 and fp16 data types, block size of 16, head size of 128, a maximum context length of 16k, and GQA ratios between 1 and 16. For other configurations, we use the PagedAttention v2 kernel. + ## Unsupported features The following features are currently not supported in the ROCm version of TGI, and the supported may be extended in the future: diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index bd033017..58165dc7 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -15,7 +15,7 @@ _PARTITION_SIZE_CUSTOM = 256 use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"} ENGINE = "triton" if use_triton else "ck" -custom_attn_available = os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "1") != "0" +custom_attn_available = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0" if custom_attn_available: from vllm._custom_C import paged_attention_custom