fix style

parent 88e2997b9c
commit 3f2dc61500
```diff
@@ -153,6 +153,7 @@ ENV HIP_FORCE_DEV_KERNARG=1
 # However, Triton requires a tuning for each prompt length, which is prohibitive.
 ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
 ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
+ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
 
 FROM base AS kernel-builder
 
```
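This hunk bakes a default for the new flag into the image. As a point of reference, here is a minimal Python sketch of how the server side can read that variable, mirroring the check added in the last hunk below; the helper name `tuning_after_warmup_enabled` is hypothetical, not from this commit:

```python
import os

def tuning_after_warmup_enabled() -> bool:
    """Hypothetical helper mirroring the check this commit adds.

    The Dockerfile default of "0" means TunableOp tuning stops once the
    warmup phase completes; users can opt back in at container launch
    with --env PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=1.
    """
    return os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") == "1"
```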
```diff
@@ -25,6 +25,10 @@ Experimentally, on MI300X, we noticed a 6-8% latency improvement when using TunableOp.
 
 TunableOp is enabled by default; the warmup may take 1-2 minutes. If you would like to disable TunableOp, pass `--env PYTORCH_TUNABLEOP_ENABLED="0"` when launching TGI's docker container.
 
+TunableOp tuning is disabled by default after the warmup phase. If you wish to keep tuning enabled for the entire run, set the environment variable `PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=1`.
+
+Note: with tuning enabled, every time a new input shape is encountered, tuning will be performed, which can slow down the first inference for that shape.
+
 ## Flash attention implementation
 
 Two implementations of Flash Attention are available for ROCm: the first is [ROCm/flash-attention](https://github.com/ROCm/flash-attention), based on a [Composable Kernel](https://github.com/ROCm/composable_kernel) (CK) implementation, and the second is a [Triton implementation](https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/layers/attention/flash_attn_triton.py).
```
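The docs distinguish two switches: whether TunableOp is used at all (`PYTORCH_TUNABLEOP_ENABLED`) and whether tuning keeps running after warmup (`PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP`). A short sketch of the documented post-warmup state, assuming PyTorch's `torch.cuda.tunable` module (the same API the code hunks below use):

```python
import torch

# TunableOp stays enabled, so GEMM selections tuned during warmup are used...
torch.cuda.tunable.enable(True)
# ...but with tuning itself switched off (the documented default), an unseen
# input shape falls back to the default kernel instead of paying a tuning
# cost on the first inference for that shape.
torch.cuda.tunable.tuning_enable(False)

assert torch.cuda.tunable.is_enabled()
assert not torch.cuda.tunable.tuning_is_enabled()
```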
```diff
@@ -61,7 +61,11 @@ class FastLinearROCm(torch.nn.Module):
         weight = self.weight
         bias = self.bias
 
-        if SYSTEM == "rocm" and inp.numel() // inp.shape[-1] == 1:
+        if (
+            SYSTEM == "rocm"
+            and inp.numel() // inp.shape[-1] == 1
+            and inp.dtype == torch.float16
+        ):
             batched = False
             inp_shape = inp.shape
```
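The reformatted condition also makes the guard easier to read: `inp.numel() // inp.shape[-1]` is the number of rows flowing into the linear layer, so the specialized ROCm path is taken only for a single float16 row, e.g. one decoding token at batch size 1. A small illustration (the tensor shapes are made up for the example):

```python
import torch

# One decode token: all leading dimensions multiply to 1.
decode_step = torch.zeros(1, 1, 4096, dtype=torch.float16)
# A prefill of 32 tokens: 32 rows reach the linear layer.
prefill = torch.zeros(1, 32, 4096, dtype=torch.float16)

assert decode_step.numel() // decode_step.shape[-1] == 1  # takes the fast path
assert prefill.numel() // prefill.shape[-1] == 32         # uses the regular GEMM
```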
```diff
@@ -1175,7 +1175,8 @@ class FlashCausalLM(Model):
                 log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
                 self.tunableop_warmup(seqlen)
             torch.cuda.tunable.write_file(tunableop_filepath)
-            torch.cuda.tunable.tuning_enable(False)
+            if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
+                torch.cuda.tunable.tuning_enable(False)
         else:
             log_master(
                 logger.info,
```
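Putting the pieces together, a condensed, standalone sketch of the warmup flow this hunk implements; the results path is assumed, and `run_warmup_passes` is a stand-in for the model's `tunableop_warmup` calls:

```python
import os
import torch

tunableop_filepath = "tunableop_results.csv"  # assumed path for this sketch

def run_warmup_passes() -> None:
    # Stand-in for self.tunableop_warmup(seqlen) over the tuning sequences:
    # run representative forward passes so TunableOp can tune its GEMMs.
    ...

run_warmup_passes()
# Persist the tuned selections, then stop tuning unless the user opted to
# keep tuning on for the whole run.
torch.cuda.tunable.write_file(tunableop_filepath)
if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
    torch.cuda.tunable.tuning_enable(False)
```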