From 3f2dc6150009b74c71d3cfeb6ea9d12bbc2f9f7c Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Mon, 9 Sep 2024 10:13:59 +0000
Subject: [PATCH] fix style

---
 Dockerfile_amd                                           | 1 +
 docs/source/installation_amd.md                          | 4 ++++
 server/text_generation_server/layers/linear.py           | 6 +++++-
 server/text_generation_server/models/flash_causal_lm.py  | 3 ++-
 4 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/Dockerfile_amd b/Dockerfile_amd
index 90a1ddf5..f01b160d 100644
--- a/Dockerfile_amd
+++ b/Dockerfile_amd
@@ -153,6 +153,7 @@ ENV HIP_FORCE_DEV_KERNARG=1
 # However, Triton requires a tunning for each prompt length, which is prohibitive.
 ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
 ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
+ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0

 FROM base AS kernel-builder

diff --git a/docs/source/installation_amd.md b/docs/source/installation_amd.md
index 8bf60830..070e268e 100644
--- a/docs/source/installation_amd.md
+++ b/docs/source/installation_amd.md
@@ -25,6 +25,10 @@ Experimentally, on MI300X, we noticed a 6-8% latency improvement when using Tuna

 TunableOp is enabled by default, the warmup may take 1-2 minutes. In case you would like to disable TunableOp, please pass `--env PYTORCH_TUNABLEOP_ENABLED="0"` when launcher TGI's docker container.

+TunableOp tuning is disabled by default after the warmup phase. If you wish to keep tuning enabled for the entire run, set the environment variable `PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=1`.
+
+Note: With tuning enabled, every time a new input shape is encountered, tuning will be performed, which can slow down the first inference for that shape.
+
 ## Flash attention implementation

 Two implementations of Flash Attention are available for ROCm, the first is [ROCm/flash-attention](https://github.com/ROCm/flash-attention) based on a [Composable Kernel](https://github.com/ROCm/composable_kernel) (CK) implementation, and the second is a [Triton implementation](https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/layers/attention/flash_attn_triton.py).

diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py
index 12d7f83a..78815d74 100644
--- a/server/text_generation_server/layers/linear.py
+++ b/server/text_generation_server/layers/linear.py
@@ -61,7 +61,11 @@ class FastLinearROCm(torch.nn.Module):
         weight = self.weight
         bias = self.bias

-        if SYSTEM == "rocm" and inp.numel() // inp.shape[-1] == 1:
+        if (
+            SYSTEM == "rocm"
+            and inp.numel() // inp.shape[-1] == 1
+            and inp.dtype == torch.float16
+        ):
             batched = False
             inp_shape = inp.shape

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 35388f49..7a3f57ab 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1175,7 +1175,8 @@ class FlashCausalLM(Model):
                     log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
                     self.tunableop_warmup(seqlen)
                 torch.cuda.tunable.write_file(tunableop_filepath)
-                torch.cuda.tunable.tuning_enable(False)
+                if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
+                    torch.cuda.tunable.tuning_enable(False)
             else:
                 log_master(
                     logger.info,
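For context, here is a minimal, standalone sketch (not part of the patch) of how the new `PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP` variable is intended to gate PyTorch's TunableOp tuning around a warmup phase. The `run_warmup_forward` and `tunableop_warmup` helpers, the GEMM shapes, the sequence lengths, and the results file name are hypothetical placeholders for TGI's actual warmup logic; only the `torch.cuda.tunable` calls and the env-var check mirror the patch.

```python
# Sketch only; assumes PyTorch >= 2.3 with TunableOp support on a ROCm/CUDA device.
import os

import torch


def run_warmup_forward(seqlen: int) -> None:
    # Hypothetical stand-in for a model forward pass at the given sequence length,
    # reduced to a single fp16 GEMM that TunableOp can tune.
    a = torch.randn(seqlen, 4096, device="cuda", dtype=torch.float16)
    b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    torch.matmul(a, b)


def tunableop_warmup(tuning_sequences, results_file="tunableop_results.csv"):
    torch.cuda.tunable.enable(True)         # use tuned GEMM solutions
    torch.cuda.tunable.tuning_enable(True)  # allow tuning of new shapes during warmup

    for seqlen in tuning_sequences:
        run_warmup_forward(seqlen)

    # Persist the tuned results so later runs can reuse them without re-tuning.
    torch.cuda.tunable.write_file(results_file)

    # Default behaviour introduced by this patch: freeze tuning once warmup is done,
    # unless the user explicitly opts in to keep tuning new shapes at runtime.
    if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
        torch.cuda.tunable.tuning_enable(False)


if __name__ == "__main__":
    tunableop_warmup([1, 2, 4, 8])
```

Keeping tuning enabled after warmup (setting the variable to `1`) matches the note added to the docs above: any request whose GEMM shape was not covered during warmup triggers a fresh tuning pass, slowing down the first inference for that shape.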