add tuned config
parent 50d239ba8f
commit 78776cdd25

@@ -328,6 +328,13 @@ ENV ATTENTION=paged
 ENV USE_PREFIX_CACHING=0
 ENV ROCM_USE_SKINNY_GEMM=1
 
+COPY ./rocm_tuned_ops/afo_tune_device_0_full.csv /afo_tune/
+RUN seq 1 7 | xargs -I{} cp /afo_tune/afo_tune_device_0_full.csv /afo_tune/afo_tune_device_{}_full.csv
+
+ENV PYTORCH_TUNABLEOP_FILENAME=/afo_tune/afo_tune_device_%d_full.csv
+ENV PYTORCH_TUNABLEOP_TUNING=0
+ENV PYTORCH_TUNABLEOP_ENABLED=1
+
 COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh
 
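For reference, the three `PYTORCH_TUNABLEOP_*` variables set above map onto PyTorch's `torch.cuda.tunable` Python API (available since PyTorch 2.3), and the `%d` placeholder in `PYTORCH_TUNABLEOP_FILENAME` is substituted with the device ordinal, which is why the `RUN seq 1 7 | xargs ...` line replicates device 0's tuned CSV for ordinals 1-7. A minimal standalone sketch of the equivalent configuration (the shapes and the device-indexed path are illustrative, not part of this commit):

```python
import torch

# Equivalent of ENV PYTORCH_TUNABLEOP_ENABLED=1: TunableOp intercepts GEMMs.
torch.cuda.tunable.enable(True)

# Equivalent of ENV PYTORCH_TUNABLEOP_TUNING=0: look up tuned solutions,
# but do not benchmark shapes that are missing from the CSV.
torch.cuda.tunable.tuning_enable(False)

# Load the pre-tuned solutions shipped in the image for the current device.
torch.cuda.tunable.read_file(
    f"/afo_tune/afo_tune_device_{torch.cuda.current_device()}_full.csv"
)

# A GEMM whose shape is present in the CSV now runs the pre-picked kernel.
a = torch.randn(1, 4096, device="cuda", dtype=torch.float16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
c = a @ b
```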
@@ -23,7 +23,17 @@ TGI's docker image for AMD GPUs integrates [PyTorch's TunableOp](https://github.
 
 Experimentally, on MI300X, we noticed a 6-8% latency improvement when using TunableOp on top of ROCm 6.1 and PyTorch 2.3.
 
-TunableOp is enabled by default, the warmup may take 1-2 minutes. In case you would like to disable TunableOp, please pass `--env PYTORCH_TUNABLEOP_ENABLED="0"` when launcher TGI's docker container.
+TunableOp is enabled by default; the warmup may take 1-2 minutes. In case you would like to disable TunableOp, please pass `--env PYTORCH_TUNABLEOP_ENABLED="0"` when launching TGI's docker container.
+
+TGI's ROCm image comes preloaded with tuned configurations for commonly occurring GEMMs. However, if you want to enable tuning for other GEMMs during the warmup process, please pass `--env PYTORCH_TUNABLEOP_TUNING="1"` when launching TGI's docker container.
+
+When tuning is enabled, there are two types of tuning performed in TGI:
+
+* Decode: Enabled by default.
+* Prefill: Disabled by default due to its longer warmup time.
+
+To enable tuning for prefill for specific input lengths, pass `--env PYTORCH_TUNABLEOP_PREFILL_SEQLENS=<seqlen1>,<seqlen2>,...` to the docker container.
+
+Note: if a shape already exists in the tuned configurations, its tuning will be skipped.
 
 ## Flash attention implementation
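To make the tuning workflow concrete, below is a standalone sketch of what `PYTORCH_TUNABLEOP_TUNING="1"` amounts to under the hood: TunableOp benchmarks each new GEMM shape it encounters and the winning solutions can be written to a CSV for reuse. The shapes and the output filename are illustrative, not TGI's:

```python
import torch

torch.cuda.tunable.enable(True)         # PYTORCH_TUNABLEOP_ENABLED=1
torch.cuda.tunable.tuning_enable(True)  # PYTORCH_TUNABLEOP_TUNING=1

# The first GEMM of each shape triggers benchmarking across the available
# backends (e.g. rocBLAS and hipBLASLt on ROCm); later calls of the same
# shape reuse the cached winner.
for m in (1, 2, 4):
    a = torch.randn(m, 4096, device="cuda", dtype=torch.float16)
    b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    _ = a @ b

# Persist the tuned solutions, as TGI's warmup does after its tuning loops.
torch.cuda.tunable.write_file("my_tuned_gemms.csv")
```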
File diff suppressed because it is too large.
@@ -1322,6 +1322,18 @@ class FlashCausalLM(Model):
             else:
                 tuning_sequences = [1, 2, 3, 4, 5, 6, 7]
 
+            tuning_sequences_prefill = None
+            if (
+                os.environ.get("PYTORCH_TUNABLEOP_PREFILL_SEQLENS") is not None
+                and os.environ.get("PYTORCH_TUNABLEOP_PREFILL_SEQLENS") != ""
+            ):
+                tuning_sequences_prefill = [
+                    int(val)
+                    for val in os.environ[
+                        "PYTORCH_TUNABLEOP_PREFILL_SEQLENS"
+                    ].split(",")
+                ]
+
             tunableop_filepath = os.path.join(
                 HUGGINGFACE_HUB_CACHE,
                 f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
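A quick illustration of the parsing added above, with a hypothetical value; the variable is split on commas and each element must parse as an integer (an empty element, e.g. from a trailing comma, would raise `ValueError`):

```python
import os

os.environ["PYTORCH_TUNABLEOP_PREFILL_SEQLENS"] = "512,1024,4096"  # hypothetical

tuning_sequences_prefill = [
    int(val)
    for val in os.environ["PYTORCH_TUNABLEOP_PREFILL_SEQLENS"].split(",")
]
print(tuning_sequences_prefill)  # [512, 1024, 4096]
```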
@@ -1329,11 +1341,9 @@ class FlashCausalLM(Model):
 
             log_master(
                 logger.info,
-                f"PyTorch TunableOp is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
+                f"PyTorch TunableOp is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target decode lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])} "
+                f"and prefill lengths {', '.join([str(seqlen) for seqlen in tuning_sequences_prefill]) if tuning_sequences_prefill is not None else 'N/A'}, "
+                f"with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
             )
-            torch.cuda.tunable.set_filename(
-                tunableop_filepath, insert_device_ordinal=False
-            )
 
             if os.path.isfile(tunableop_filepath):
@@ -1346,9 +1356,21 @@ class FlashCausalLM(Model):
             os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)
 
             for seqlen in tuning_sequences:
-                log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
+                log_master(
+                    logger.info, f"Warming up TunableOp for Decode seqlen={seqlen}"
+                )
                 self.tunableop_warmup(seqlen)
                 torch.cuda.tunable.write_file(tunableop_filepath)
+
+            if tuning_sequences_prefill is not None:
+                for seqlen in tuning_sequences_prefill:
+                    for bs in tuning_sequences:
+                        log_master(
+                            logger.info,
+                            f"Warming up TunableOp for Prefill seqlen={seqlen}, bs={bs}",
+                        )
+                        self.tunableop_prefill_warmup(bs, seqlen)
+                torch.cuda.tunable.write_file(tunableop_filepath)
 
             if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
                 torch.cuda.tunable.tuning_enable(False)
             else:
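Note that the added prefill loop warms up every combination of the requested prefill lengths and the decode tuning sequences (used here as batch sizes), so warmup cost grows multiplicatively with each extra prefill length. A small sketch with assumed values matching the defaults above:

```python
tuning_sequences = [1, 2, 3, 4, 5, 6, 7]  # default decode tuning sequences
tuning_sequences_prefill = [512, 1024]    # e.g. PYTORCH_TUNABLEOP_PREFILL_SEQLENS=512,1024

combos = [
    (bs, seqlen)
    for seqlen in tuning_sequences_prefill
    for bs in tuning_sequences
]
print(len(combos))  # 14 prefill warmup forward passes
```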
@@ -1409,6 +1431,40 @@ class FlashCausalLM(Model):
             prefill_cache_indices=None,
         )
 
+    def tunableop_prefill_warmup(self, bs: int, seqlen: int):
+        input_ids = torch.zeros(seqlen * bs, dtype=torch.int64, device=self.device)
+        position_ids = torch.zeros(seqlen * bs, dtype=torch.int32, device=self.device)
+        slots = torch.arange(seqlen * bs + 1, dtype=torch.int64, device=self.device)
+
+        # Dummy value, some models (starcoder2) don't accept `None`.
+        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * seqlen
+        prefix_lens_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
+
+        cu_seqlen_prefill = torch.arange(
+            0, (bs + 1) * seqlen, seqlen, device=self.device, dtype=torch.int32
+        )
+        max_s = seqlen
+        seqlen = Seqlen(
+            input_lengths=input_lengths,
+            prefix_lengths=prefix_lens_tensor,
+            cu_seqlen_q=cu_seqlen_prefill,
+            max_q=seqlen,
+            max_k=seqlen,
+        )
+
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=self.kv_cache,
+            block_tables=None,
+            seqlen=seqlen,
+            slots=slots,
+            max_s=max_s,
+            lm_head_indices=None,
+            prefill_cache_indices=None,
+        )
+
     def forward(
         self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
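One detail of the new `tunableop_prefill_warmup` worth spelling out: the dummy batch is laid out varlen-style, i.e. `bs` equal-length sequences concatenated into a single flat token dimension, with `cu_seqlen_prefill` holding the cumulative sequence boundaries produced by `torch.arange(0, (bs + 1) * seqlen, seqlen)`. A tiny sketch with toy values (CPU tensors for illustration):

```python
import torch

bs, seqlen = 3, 4  # assumed toy values

# bs sequences of length seqlen, flattened into bs * seqlen token slots
input_ids = torch.zeros(seqlen * bs, dtype=torch.int64)

# Cumulative boundaries: sequence i occupies tokens [cu[i], cu[i + 1])
cu_seqlen_prefill = torch.arange(0, (bs + 1) * seqlen, seqlen, dtype=torch.int32)
print(cu_seqlen_prefill)  # tensor([ 0,  4,  8, 12], dtype=torch.int32)
```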