add tuned config

Mohit Sharma 2024-10-03 11:59:14 +00:00
parent 50d239ba8f
commit 78776cdd25
4 changed files with 16685 additions and 7 deletions


@@ -328,6 +328,13 @@ ENV ATTENTION=paged
ENV USE_PREFIX_CACHING=0
ENV ROCM_USE_SKINNY_GEMM=1
COPY ./rocm_tuned_ops/afo_tune_device_0_full.csv /afo_tune/
RUN seq 1 7 | xargs -I{} cp /afo_tune/afo_tune_device_0_full.csv /afo_tune/afo_tune_device_{}_full.csv
ENV PYTORCH_TUNABLEOP_FILENAME=/afo_tune/afo_tune_device_%d_full.csv
ENV PYTORCH_TUNABLEOP_TUNING=0
ENV PYTORCH_TUNABLEOP_ENABLED=1
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh
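The environment variables set here map onto PyTorch's `torch.cuda.tunable` Python API (available since PyTorch 2.3). A minimal sketch of the equivalent configuration, for illustration only rather than the code path TGI actually runs:

```python
import torch

torch.cuda.tunable.enable(True)          # PYTORCH_TUNABLEOP_ENABLED=1
torch.cuda.tunable.tuning_enable(False)  # PYTORCH_TUNABLEOP_TUNING=0: reuse tuned GEMMs, no re-tuning
# PYTORCH_TUNABLEOP_FILENAME points TunableOp at the pre-tuned CSVs; its %d placeholder
# selects the per-device file, which is why the RUN line above copies one CSV per GPU.
# Here device 0's file is set explicitly instead of relying on the placeholder.
torch.cuda.tunable.set_filename("/afo_tune/afo_tune_device_0_full.csv", insert_device_ordinal=False)
```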


@@ -23,7 +23,17 @@ TGI's docker image for AMD GPUs integrates [PyTorch's TunableOp](https://github.
Experimentally, on MI300X, we noticed a 6-8% latency improvement when using TunableOp on top of ROCm 6.1 and PyTorch 2.3.
-TunableOp is enabled by default, the warmup may take 1-2 minutes. In case you would like to disable TunableOp, please pass `--env PYTORCH_TUNABLEOP_ENABLED="0"` when launcher TGI's docker container.
+TunableOp is enabled by default, the warmup may take 1-2 minutes. In case you would like to disable TunableOp, please pass `--env PYTORCH_TUNABLEOP_ENABLED="0"` when launching TGI's docker container.
TGI's ROCm image comes preloaded with tuned configurations for commonly occurring GEMMs. However, if you want to enable tuning for other GEMMs during the warmup process, please pass `--env PYTORCH_TUNABLEOP_TUNING="1"` when launching TGI's docker container.
When tuning is enabled, there are two types of tuning performed in TGI:
* Decode: Enabled by default.
* Prefill: Disabled by default due to its longer warmup time.
To enable prefill tuning for specific input lengths, pass `--env PYTORCH_TUNABLEOP_PREFILL_SEQLENS=<seqlen1>,<seqlen2>,...` to the docker container.
Note: if a shape already exists in the tuned configurations, tuning for it will be skipped.
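A rough sketch of how these flags are interpreted during warmup (mirroring the `flash_causal_lm.py` change in this commit); the sequence lengths below are example values:

```python
import os

# Decode GEMMs are warmed up for small batch sizes by default.
decode_seqlens = [1, 2, 3, 4, 5, 6, 7]

# Prefill GEMMs are only tuned when explicit lengths are requested, e.g.
# PYTORCH_TUNABLEOP_PREFILL_SEQLENS="512,1024" (example values); each prefill
# length is then warmed up at every decode batch size above.
prefill_seqlens = None
raw = os.environ.get("PYTORCH_TUNABLEOP_PREFILL_SEQLENS", "")
if raw:
    prefill_seqlens = [int(val) for val in raw.split(",")]
```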
## Flash attention implementation

File diff suppressed because it is too large


@@ -1322,6 +1322,18 @@ class FlashCausalLM(Model):
else:
tuning_sequences = [1, 2, 3, 4, 5, 6, 7]
tuning_sequences_prefill = None
if (
os.environ.get("PYTORCH_TUNABLEOP_PREFILL_SEQLENS") is not None
and os.environ.get("PYTORCH_TUNABLEOP_PREFILL_SEQLENS") != ""
):
tuning_sequences_prefill = [
int(val)
for val in os.environ[
"PYTORCH_TUNABLEOP_PREFILL_SEQLENS"
].split(",")
]
tunableop_filepath = os.path.join(
HUGGINGFACE_HUB_CACHE,
f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
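For instance, with a hypothetical model and topology, the per-rank CSV written under `HUGGINGFACE_HUB_CACHE` would be named as follows:

```python
model_id, world_size, rank = "meta-llama/Llama-3.1-8B-Instruct", 1, 0  # hypothetical values
print(f"tunableop_{model_id.replace('/', '-')}_tp{world_size}_rank{rank}.csv")
# tunableop_meta-llama-Llama-3.1-8B-Instruct_tp1_rank0.csv
```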
@@ -1329,11 +1341,9 @@ class FlashCausalLM(Model):
log_master(
logger.info,
-f"PyTorch TunableOp is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
-)
-torch.cuda.tunable.set_filename(
-tunableop_filepath, insert_device_ordinal=False
+f"PyTorch TunableOp is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target decode lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])} "
+f"and prefill lengths {', '.join([str(seqlen) for seqlen in tuning_sequences_prefill]) if tuning_sequences_prefill is not None else 'N/A'}, "
+f"with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
)
if os.path.isfile(tunableop_filepath):
@@ -1346,9 +1356,21 @@ class FlashCausalLM(Model):
os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)
for seqlen in tuning_sequences:
-log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
+log_master(
+logger.info, f"Warming up TunableOp for Decode seqlen={seqlen}"
+)
self.tunableop_warmup(seqlen)
torch.cuda.tunable.write_file(tunableop_filepath)
if tuning_sequences_prefill is not None:
for seqlen in tuning_sequences_prefill:
for bs in tuning_sequences:
log_master(
logger.info,
f"Warming up TunableOp for Prefill seqlen={seqlen}, bs={bs}",
)
self.tunableop_prefill_warmup(bs, seqlen)
torch.cuda.tunable.write_file(tunableop_filepath)
if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
torch.cuda.tunable.tuning_enable(False)
else:
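After warmup, the tuned GEMM selections persist in that CSV, so a later run can reload them instead of re-tuning. A minimal sketch using the `torch.cuda.tunable` API, assuming `read_file` is available and with a placeholder path rather than anything this commit creates:

```python
import torch

torch.cuda.tunable.enable(True)
torch.cuda.tunable.tuning_enable(False)  # look up tuned GEMMs only, no re-tuning
# Placeholder path; in TGI this would be the tunableop_*.csv written during warmup.
torch.cuda.tunable.read_file("/data/tunableop_my-model_tp1_rank0.csv")
```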
@@ -1409,6 +1431,40 @@ class FlashCausalLM(Model):
prefill_cache_indices=None,
)
def tunableop_prefill_warmup(self, bs: int, seqlen: int):
input_ids = torch.zeros(seqlen * bs, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(seqlen * bs, dtype=torch.int32, device=self.device)
slots = torch.arange(seqlen * bs + 1, dtype=torch.int64, device=self.device)
# Dummy value, some models (starcoder2) don't accept `None`.
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * seqlen
prefix_lens_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
cu_seqlen_prefill = torch.arange(
0, (bs + 1) * seqlen, seqlen, device=self.device, dtype=torch.int32
)
max_s = seqlen
seqlen = Seqlen(
input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=seqlen,
max_k=seqlen,
)
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=self.kv_cache,
block_tables=None,
seqlen=seqlen,
slots=slots,
max_s=max_s,
lm_head_indices=None,
prefill_cache_indices=None,
)
def forward(
self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
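For reference, `tunableop_prefill_warmup` above flattens `bs` dummy sequences of length `seqlen` into a single token stream and records the per-sequence boundaries in `cu_seqlen_prefill`. A small worked example with illustrative values:

```python
import torch

bs, seqlen = 2, 512  # illustrative values
input_ids = torch.zeros(seqlen * bs, dtype=torch.int64)  # 1024 flattened dummy tokens
cu_seqlen_prefill = torch.arange(0, (bs + 1) * seqlen, seqlen, dtype=torch.int32)
print(cu_seqlen_prefill.tolist())  # [0, 512, 1024]: sequence i spans [cu[i], cu[i+1])
```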