diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 19a79115..64f4f515 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -1687,13 +1687,6 @@ fn main() -> Result<(), LauncherError> {
     let max_position_embeddings = if let Some(config) = &config {
         if let Some(max_position_embeddings) = config.max_position_embeddings {
             if max_position_embeddings > max_default {
-                let max = max_position_embeddings;
-                if args.max_input_tokens.is_none()
-                    && args.max_total_tokens.is_none()
-                    && args.max_batch_prefill_tokens.is_none()
-                {
-                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
-                }
                 max_default
             } else {
                 max_position_embeddings
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 52ab5d6a..6e905b4a 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1532,8 +1532,6 @@ class FlashCausalLM(Model):
             self.kv_cache_dtype,
             self.device,
         )
-        max_bt = batch.max_blocks
-        max_s = max_bt * BLOCK_SIZE
         batch_num_blocks = batch.num_blocks
 
         if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
@@ -1651,7 +1649,7 @@ class FlashCausalLM(Model):
                 # Warmup cuda graphs
                 for bs in CUDA_GRAPHS:
                     if self.speculate is None or self.speculate + 1 <= bs:
-                        self.cuda_graph_warmup(bs, max_s, max_bt)
+                        self.cuda_graph_warmup(bs, max_total_tokens, max_total_tokens)
             except torch.cuda.OutOfMemoryError:
                 logger.exception("Decode cuda graph warmup failed")
         else:
diff --git a/server/text_generation_server/models/metadata_kernels.py b/server/text_generation_server/models/metadata_kernels.py
index b3e2160d..783aab80 100644
--- a/server/text_generation_server/models/metadata_kernels.py
+++ b/server/text_generation_server/models/metadata_kernels.py
@@ -55,7 +55,7 @@ def block_tables_to_ragged(
     cache_lengths: List[int],
     input_lengths_tensor: torch.Tensor,
     cache_lengths_tensor: torch.Tensor,
-    max_current_length: int
+    max_current_length: int,
 ) -> torch.Tensor:
     """Convert block table to ragged format compatible with FlashInfer."""
     assert len(input_lengths) == len(cache_lengths)