Hotfixing auto length (warmup max_s was wrong). (#2716)

Nicolas Patry 2024-11-04 16:55:54 +08:00 committed by GitHub
parent 08c4184eb2
commit a5593ba83e
3 changed files with 2 additions and 11 deletions


@@ -1687,13 +1687,6 @@ fn main() -> Result<(), LauncherError> {
     let max_position_embeddings = if let Some(config) = &config {
         if let Some(max_position_embeddings) = config.max_position_embeddings {
             if max_position_embeddings > max_default {
-                let max = max_position_embeddings;
-                if args.max_input_tokens.is_none()
-                    && args.max_total_tokens.is_none()
-                    && args.max_batch_prefill_tokens.is_none()
-                {
-                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
-                }
                 max_default
             } else {
                 max_position_embeddings
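Note that the removed lines only emitted an informational log when the user had set none of the limit flags; the length-selection logic itself is unchanged and still caps the model's advertised context to the launcher default. A rough sketch of that unchanged selection (in Python for brevity, not the launcher's actual Rust code; the fallback value for a config without max_position_embeddings is an assumption):

def effective_max_position_embeddings(config_max_position_embeddings, max_default, fallback):
    # No value advertised by the model config: fall back (assumed behaviour).
    if config_max_position_embeddings is None:
        return fallback
    # Cap to the launcher default to save VRAM; users can still raise the
    # limit explicitly via --max-input-tokens / --max-total-tokens /
    # --max-batch-prefill-tokens.
    if config_max_position_embeddings > max_default:
        return max_default
    return config_max_position_embeddings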


@@ -1532,8 +1532,6 @@ class FlashCausalLM(Model):
             self.kv_cache_dtype,
             self.device,
         )
-        max_bt = batch.max_blocks
-        max_s = max_bt * BLOCK_SIZE
         batch_num_blocks = batch.num_blocks
 
         if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
@@ -1651,7 +1649,7 @@ class FlashCausalLM(Model):
                 # Warmup cuda graphs
                 for bs in CUDA_GRAPHS:
                     if self.speculate is None or self.speculate + 1 <= bs:
-                        self.cuda_graph_warmup(bs, max_s, max_bt)
+                        self.cuda_graph_warmup(bs, max_total_tokens, max_total_tokens)
             except torch.cuda.OutOfMemoryError:
                 logger.exception("Decode cuda graph warmup failed")
         else:
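This is the hotfix itself: the bounds passed to cuda_graph_warmup previously came from the synthetic warmup batch (max_bt = batch.max_blocks, max_s = max_bt * BLOCK_SIZE), so the captured CUDA graphs were sized to whatever that batch happened to allocate rather than to the configured limit. They are now both taken from max_total_tokens. A minimal numeric sketch of the difference (hypothetical values; BLOCK_SIZE assumed to be 16):

BLOCK_SIZE = 16                # assumption: paged-attention block size
max_total_tokens = 4096        # configured server limit

# Old bound: derived from the warmup batch, not from the configuration.
warmup_batch_max_blocks = 64   # hypothetical batch.max_blocks
old_max_s = warmup_batch_max_blocks * BLOCK_SIZE  # 1024, too small for a 4096-token request

# New bound: taken directly from the configured limit, so the captured graphs
# cover every sequence length a real request may reach.
new_max_s = max_total_tokens   # 4096
print(old_max_s, new_max_s)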


@@ -55,7 +55,7 @@ def block_tables_to_ragged(
     cache_lengths: List[int],
     input_lengths_tensor: torch.Tensor,
     cache_lengths_tensor: torch.Tensor,
-    max_current_length: int
+    max_current_length: int,
) -> torch.Tensor:
     """Convert block table to ragged format compatible with FlashInfer."""
     assert len(input_lengths) == len(cache_lengths)
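This last change only adds a trailing comma to the signature. For readers unfamiliar with the term, a standalone illustration of the "ragged" layout the function name refers to (an assumption-laden sketch, not TGI's implementation): instead of a padded [batch, max_blocks] table, each sequence contributes only the blocks it actually uses, concatenated into one 1-D tensor.

import torch

def ragged_block_tables_example(block_tables, blocks_per_seq):
    # block_tables: padded [batch, max_blocks] tensor of KV-cache block ids
    # blocks_per_seq: number of blocks each sequence actually uses
    parts = [block_tables[i, :n] for i, n in enumerate(blocks_per_seq)]
    return torch.cat(parts)

padded = torch.tensor([[3, 7, 0, 0], [5, 0, 0, 0]])
print(ragged_block_tables_example(padded, [2, 1]))  # tensor([3, 7, 5])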