add tuned config

parent 50d239ba8f
commit 78776cdd25
@@ -328,6 +328,13 @@ ENV ATTENTION=paged
 ENV USE_PREFIX_CACHING=0
 ENV ROCM_USE_SKINNY_GEMM=1
 
+COPY ./rocm_tuned_ops/afo_tune_device_0_full.csv /afo_tune/
+RUN seq 1 7 | xargs -I{} cp /afo_tune/afo_tune_device_0_full.csv /afo_tune/afo_tune_device_{}_full.csv
+
+ENV PYTORCH_TUNABLEOP_FILENAME=/afo_tune/afo_tune_device_%d_full.csv
+ENV PYTORCH_TUNABLEOP_TUNING=0
+ENV PYTORCH_TUNABLEOP_ENABLED=1
+
 COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh
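For context on the `RUN seq 1 7 | xargs ... cp` step above: the CSV shipped in the image was tuned on device 0, and it is simply replicated for the remaining seven GPUs so that the `%d` placeholder in `PYTORCH_TUNABLEOP_FILENAME` resolves to an existing file for every device ordinal. A rough Python equivalent of that fan-out (a sketch only, assuming the eight identical GPUs of an MI300X node; it is not part of the image):

```python
import shutil
from pathlib import Path

tune_dir = Path("/afo_tune")
source = tune_dir / "afo_tune_device_0_full.csv"

# Reuse the device-0 tuning results verbatim for devices 1..7,
# matching the `seq 1 7 | xargs -I{} cp ...` line in the Dockerfile.
for device_ordinal in range(1, 8):
    shutil.copyfile(source, tune_dir / f"afo_tune_device_{device_ordinal}_full.csv")
```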
@@ -23,7 +23,17 @@ TGI's docker image for AMD GPUs integrates [PyTorch's TunableOp](https://github.
 
 Experimentally, on MI300X, we noticed a 6-8% latency improvement when using TunableOp on top of ROCm 6.1 and PyTorch 2.3.
 
-TunableOp is enabled by default, the warmup may take 1-2 minutes. In case you would like to disable TunableOp, please pass `--env PYTORCH_TUNABLEOP_ENABLED="0"` when launcher TGI's docker container.
+TunableOp is enabled by default; the warmup may take 1-2 minutes. In case you would like to disable TunableOp, please pass `--env PYTORCH_TUNABLEOP_ENABLED="0"` when launching TGI's docker container.
+
+TGI's ROCm image comes preloaded with tuned configurations for commonly occurring GEMMs. However, if you want to enable tuning for other GEMMs during the warmup process, please pass `--env PYTORCH_TUNABLEOP_TUNING="1"` when launching TGI's docker container.
+
+When tuning is enabled, there are two types of tuning performed in TGI:
+* Decode: Enabled by default.
+* Prefill: Disabled by default due to its longer warmup time.
+
+To enable tuning for prefill for specific input lengths, pass `--env PYTORCH_TUNABLEOP_PREFILL_SEQLENS=<seqlen1>,<seqlen2>,...` to the docker container.
+
+Note: if a shape already exists in the tuned configurations, the tuning will be skipped.
 
 ## Flash attention implementation
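If you want to verify inside the container that these settings took effect, PyTorch exposes the TunableOp state under `torch.cuda.tunable`. A minimal sketch, assuming a ROCm build of PyTorch 2.3 or later (not TGI code):

```python
import torch.cuda.tunable as tunable

# Mirrors PYTORCH_TUNABLEOP_ENABLED and PYTORCH_TUNABLEOP_TUNING respectively.
print("TunableOp enabled:", tunable.is_enabled())
print("Tuning enabled:", tunable.tuning_is_enabled())

# File TunableOp reads/writes its selected GEMM kernels from on this device.
print("Results file:", tunable.get_filename())
```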
(File diff suppressed because it is too large.)
@@ -1322,6 +1322,18 @@ class FlashCausalLM(Model):
             else:
                 tuning_sequences = [1, 2, 3, 4, 5, 6, 7]
 
+            tuning_sequences_prefill = None
+            if (
+                os.environ.get("PYTORCH_TUNABLEOP_PREFILL_SEQLENS") is not None
+                and os.environ.get("PYTORCH_TUNABLEOP_PREFILL_SEQLENS") != ""
+            ):
+                tuning_sequences_prefill = [
+                    int(val)
+                    for val in os.environ[
+                        "PYTORCH_TUNABLEOP_PREFILL_SEQLENS"
+                    ].split(",")
+                ]
+
             tunableop_filepath = os.path.join(
                 HUGGINGFACE_HUB_CACHE,
                 f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
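As a standalone illustration of the parsing added above (the value `1024,2048,4096` is hypothetical), `PYTORCH_TUNABLEOP_PREFILL_SEQLENS` is just a comma-separated list of integers:

```python
import os

# Hypothetical value, as passed via `--env PYTORCH_TUNABLEOP_PREFILL_SEQLENS=...`
os.environ["PYTORCH_TUNABLEOP_PREFILL_SEQLENS"] = "1024,2048,4096"

raw = os.environ.get("PYTORCH_TUNABLEOP_PREFILL_SEQLENS")
tuning_sequences_prefill = [int(val) for val in raw.split(",")] if raw else None
print(tuning_sequences_prefill)  # [1024, 2048, 4096]
```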
@@ -1329,11 +1341,9 @@ class FlashCausalLM(Model):
 
             log_master(
                 logger.info,
-                f"PyTorch TunableOp is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
-            )
-
-            torch.cuda.tunable.set_filename(
-                tunableop_filepath, insert_device_ordinal=False
+                f"PyTorch TunableOp is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target decode lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])} "
+                f"and prefill lengths {', '.join([str(seqlen) for seqlen in tuning_sequences_prefill]) if tuning_sequences_prefill is not None else 'N/A'}, "
+                f"with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
             )
 
             if os.path.isfile(tunableop_filepath):
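The `torch.cuda.tunable.set_filename(...)` call removed here did in Python what the `PYTORCH_TUNABLEOP_FILENAME` variable in the Dockerfile now does through the environment: it tells TunableOp which CSV to read previously tuned GEMMs from and write new results to. A hedged sketch of that file API (standalone, with a hypothetical path, not TGI code):

```python
import os
import torch.cuda.tunable as tunable

results_path = "/tmp/tunableop_example.csv"  # hypothetical path
tunable.set_filename(results_path, insert_device_ordinal=False)

# Reload earlier tuning results, if any, before serving traffic.
if os.path.isfile(results_path):
    tunable.read_file(results_path)

# ... run warmup GEMMs here ...

# Persist whatever was tuned during this run.
tunable.write_file(results_path)
```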
@@ -1346,9 +1356,21 @@ class FlashCausalLM(Model):
             os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)
 
             for seqlen in tuning_sequences:
-                log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
+                log_master(
+                    logger.info, f"Warming up TunableOp for Decode seqlen={seqlen}"
+                )
                 self.tunableop_warmup(seqlen)
             torch.cuda.tunable.write_file(tunableop_filepath)
+
+            if tuning_sequences_prefill is not None:
+                for seqlen in tuning_sequences_prefill:
+                    for bs in tuning_sequences:
+                        log_master(
+                            logger.info,
+                            f"Warming up TunableOp for Prefill seqlen={seqlen}, bs={bs}",
+                        )
+                        self.tunableop_prefill_warmup(bs, seqlen)
+                torch.cuda.tunable.write_file(tunableop_filepath)
             if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
                 torch.cuda.tunable.tuning_enable(False)
         else:
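These warmup loops are what populate the tuning file: each `tunableop_warmup` / `tunableop_prefill_warmup` call runs the model's GEMMs once so that TunableOp can benchmark candidate kernels for those exact shapes. Stripped of the model, the mechanism looks roughly like this (a sketch assuming a ROCm/CUDA build of PyTorch 2.3+ and an available GPU; the shapes are arbitrary):

```python
import torch
import torch.cuda.tunable as tunable

tunable.enable(True)         # same effect as PYTORCH_TUNABLEOP_ENABLED=1
tunable.tuning_enable(True)  # same effect as PYTORCH_TUNABLEOP_TUNING=1

a = torch.randn(1, 4096, device="cuda", dtype=torch.float16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)

# The first matmul with a new shape triggers benchmarking of candidate GEMM
# kernels; the winner is cached and reused for later calls with that shape.
torch.matmul(a, b)

tunable.write_file("tunableop_example.csv")  # persist the selections
tunable.tuning_enable(False)                 # keep using results, stop tuning
```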
@@ -1409,6 +1431,40 @@ class FlashCausalLM(Model):
             prefill_cache_indices=None,
         )
 
+    def tunableop_prefill_warmup(self, bs: int, seqlen: int):
+        input_ids = torch.zeros(seqlen * bs, dtype=torch.int64, device=self.device)
+        position_ids = torch.zeros(seqlen * bs, dtype=torch.int32, device=self.device)
+        slots = torch.arange(seqlen * bs + 1, dtype=torch.int64, device=self.device)
+
+        # Dummy value, some models (starcoder2) don't accept `None`.
+        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * seqlen
+        prefix_lens_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
+
+        cu_seqlen_prefill = torch.arange(
+            0, (bs + 1) * seqlen, seqlen, device=self.device, dtype=torch.int32
+        )
+        max_s = seqlen
+        seqlen = Seqlen(
+            input_lengths=input_lengths,
+            prefix_lengths=prefix_lens_tensor,
+            cu_seqlen_q=cu_seqlen_prefill,
+            max_q=seqlen,
+            max_k=seqlen,
+        )
+
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=self.kv_cache,
+            block_tables=None,
+            seqlen=seqlen,
+            slots=slots,
+            max_s=max_s,
+            lm_head_indices=None,
+            prefill_cache_indices=None,
+        )
+
     def forward(
         self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
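One detail of `tunableop_prefill_warmup` worth spelling out: `cu_seqlen_prefill` holds the cumulative token offsets of the `bs` equal-length dummy prompts packed into one flat batch, which is what the prefill GEMMs see during warmup. A small standalone illustration with plain CPU tensors (not TGI code):

```python
import torch

bs, seqlen = 2, 4

# Same construction as in tunableop_prefill_warmup above:
# sequence boundaries [0, seqlen, 2*seqlen, ..., bs*seqlen].
cu_seqlen_prefill = torch.arange(0, (bs + 1) * seqlen, seqlen, dtype=torch.int32)
print(cu_seqlen_prefill)  # tensor([0, 4, 8], dtype=torch.int32)

# The flattened dummy inputs then carry bs * seqlen tokens in total.
input_ids = torch.zeros(seqlen * bs, dtype=torch.int64)
print(input_ids.shape)  # torch.Size([8])
```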