This commit is contained in:
Mohit Sharma 2024-09-06 12:23:18 +00:00
parent ff0505e7f9
commit 88e2997b9c
3 changed files with 9 additions and 1 deletions

View File

@ -152,6 +152,7 @@ ENV HIP_FORCE_DEV_KERNARG=1
# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
# However, Triton requires a tunning for each prompt length, which is prohibitive.
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
FROM base AS kernel-builder
@ -246,3 +247,4 @@ COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh
ENTRYPOINT ["/tgi-entrypoint.sh"]
CMD ["--json-output"]

View File

@ -31,6 +31,12 @@ Two implementations of Flash Attention are available for ROCm, the first is [ROC
By default, the Composable Kernel implementation is used. However, the Triton implementation has slightly lower latency on MI250 and MI300, but requires a warmup which can be prohibitive as it needs to be done again for each new prompt length. If needed, FA Triton impelmentation can be enabled with `--env ROCM_USE_FLASH_ATTN_V2_TRITON="0"` when launching TGI's docker container.
## Custom PagedAttention
For better performance on ROCm, a custom Paged Attention kernel is available and is enabled by default. To disable it and fall back to the PagedAttention v2 kernel, set the environment variable `ROCM_USE_CUSTOM_PAGED_ATTN=0`.
The custom kernel supports bf16 and fp16 data types, block size of 16, head size of 128, a maximum context length of 16k, and GQA ratios between 1 and 16. For other configurations, we use the PagedAttention v2 kernel.
## Unsupported features
The following features are currently not supported in the ROCm version of TGI, and the supported may be extended in the future:

View File

@ -15,7 +15,7 @@ _PARTITION_SIZE_CUSTOM = 256
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"
custom_attn_available = os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "1") != "0"
custom_attn_available = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
if custom_attn_available:
from vllm._custom_C import paged_attention_custom