style
This commit is contained in:
parent
ff0505e7f9
commit
88e2997b9c
|
@ -152,6 +152,7 @@ ENV HIP_FORCE_DEV_KERNARG=1
|
|||
# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
|
||||
# However, Triton requires a tunning for each prompt length, which is prohibitive.
|
||||
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
|
||||
ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
|
||||
|
||||
FROM base AS kernel-builder
|
||||
|
||||
|
@ -246,3 +247,4 @@ COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
|
|||
RUN chmod +x /tgi-entrypoint.sh
|
||||
|
||||
ENTRYPOINT ["/tgi-entrypoint.sh"]
|
||||
CMD ["--json-output"]
|
||||
|
|
|
@ -31,6 +31,12 @@ Two implementations of Flash Attention are available for ROCm, the first is [ROC
|
|||
|
||||
By default, the Composable Kernel implementation is used. However, the Triton implementation has slightly lower latency on MI250 and MI300, but requires a warmup which can be prohibitive as it needs to be done again for each new prompt length. If needed, FA Triton impelmentation can be enabled with `--env ROCM_USE_FLASH_ATTN_V2_TRITON="0"` when launching TGI's docker container.
|
||||
|
||||
## Custom PagedAttention
|
||||
|
||||
For better performance on ROCm, a custom Paged Attention kernel is available and is enabled by default. To disable it and fall back to the PagedAttention v2 kernel, set the environment variable `ROCM_USE_CUSTOM_PAGED_ATTN=0`.
|
||||
|
||||
The custom kernel supports bf16 and fp16 data types, block size of 16, head size of 128, a maximum context length of 16k, and GQA ratios between 1 and 16. For other configurations, we use the PagedAttention v2 kernel.
|
||||
|
||||
## Unsupported features
|
||||
|
||||
The following features are currently not supported in the ROCm version of TGI, and the supported may be extended in the future:
|
||||
|
|
|
@ -15,7 +15,7 @@ _PARTITION_SIZE_CUSTOM = 256
|
|||
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
|
||||
ENGINE = "triton" if use_triton else "ck"
|
||||
|
||||
custom_attn_available = os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "1") != "0"
|
||||
custom_attn_available = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
|
||||
if custom_attn_available:
|
||||
from vllm._custom_C import paged_attention_custom
|
||||
|
||||
|
|
Loading…
Reference in New Issue