fix style

Mohit Sharma 2024-09-09 10:13:59 +00:00
parent 88e2997b9c
commit 3f2dc61500
4 changed files with 12 additions and 2 deletions


@@ -153,6 +153,7 @@ ENV HIP_FORCE_DEV_KERNARG=1
# However, Triton requires a tuning for each prompt length, which is prohibitive.
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
FROM base AS kernel-builder


@@ -25,6 +25,10 @@ Experimentally, on MI300X, we noticed a 6-8% latency improvement when using Tuna
TunableOp is enabled by default; the warmup may take 1-2 minutes. If you would like to disable TunableOp, please pass `--env PYTORCH_TUNABLEOP_ENABLED="0"` when launching TGI's Docker container.
TunableOp tuning is disabled by default after the warmup phase. If you wish to keep tuning enabled for the entire run, set the environment variable `PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=1`.
Note: With tuning enabled, every time a new input shape is encountered, tuning will be performed, which can slow down the first inference for that shape.
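As an illustration only (not part of this commit), a launch command that keeps tuning enabled after the warmup phase might look like the sketch below; the image tag, port mapping, model id, and ROCm device flags are assumed placeholders following common ROCm Docker usage, so adjust them to your setup.

```bash
# Hedged sketch: keep TunableOp tuning active for the whole run.
# The image tag and model id below are placeholders, not values from this commit.
docker run --rm -it \
    --device=/dev/kfd --device=/dev/dri --group-add video \
    --ipc=host --shm-size 64g -p 8080:80 \
    --env PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP="1" \
    ghcr.io/huggingface/text-generation-inference:latest-rocm \
    --model-id <model-id>
```

To disable TunableOp entirely, replace the last `--env` line with `--env PYTORCH_TUNABLEOP_ENABLED="0"` as noted above.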
## Flash attention implementation
Two implementations of Flash Attention are available for ROCm: the first is [ROCm/flash-attention](https://github.com/ROCm/flash-attention), based on a [Composable Kernel](https://github.com/ROCm/composable_kernel) (CK) implementation, and the second is a [Triton implementation](https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/layers/attention/flash_attn_triton.py).
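For completeness, and as a hedged sketch rather than text from this commit: the Dockerfile change above defaults `ROCM_USE_FLASH_ATTN_V2_TRITON` to `0` (the CK path), so opting into the Triton implementation at launch time could look like the following, with the same placeholder image and model id as before.

```bash
# Hedged sketch: switch to the Triton Flash Attention implementation.
docker run --rm -it \
    --device=/dev/kfd --device=/dev/dri --group-add video \
    --env ROCM_USE_FLASH_ATTN_V2_TRITON="1" \
    ghcr.io/huggingface/text-generation-inference:latest-rocm \
    --model-id <model-id>
```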


@@ -61,7 +61,11 @@ class FastLinearROCm(torch.nn.Module):
        weight = self.weight
        bias = self.bias

        if (
            SYSTEM == "rocm"
            and inp.numel() // inp.shape[-1] == 1
            and inp.dtype == torch.float16
        ):
            batched = False
            inp_shape = inp.shape


@@ -1175,7 +1175,8 @@ class FlashCausalLM(Model):
                    log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
                    self.tunableop_warmup(seqlen)
                torch.cuda.tunable.write_file(tunableop_filepath)
                if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
                    torch.cuda.tunable.tuning_enable(False)
            else:
                log_master(
                    logger.info,