Hotfixing auto length (warmup max_s was wrong). (#2716)
parent 08c4184eb2
commit a5593ba83e

@@ -1687,13 +1687,6 @@ fn main() -> Result<(), LauncherError> {
     let max_position_embeddings = if let Some(config) = &config {
         if let Some(max_position_embeddings) = config.max_position_embeddings {
             if max_position_embeddings > max_default {
-                let max = max_position_embeddings;
-                if args.max_input_tokens.is_none()
-                    && args.max_total_tokens.is_none()
-                    && args.max_batch_prefill_tokens.is_none()
-                {
-                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
-                }
                 max_default
             } else {
                 max_position_embeddings
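
The hunk above only drops the launcher's informational log; the clamping of the default context length itself is unchanged. A minimal Python sketch of that remaining behavior (the launcher itself is Rust, and resolve_default_length is a hypothetical helper for illustration only):

# Minimal sketch, not the launcher's actual code: the default length kept is
# the model's max_position_embeddings capped at max_default.
def resolve_default_length(max_position_embeddings: int, max_default: int) -> int:
    if max_position_embeddings > max_default:
        return max_default
    return max_position_embeddings

assert resolve_default_length(32768, 4096) == 4096  # large model gets capped
assert resolve_default_length(2048, 4096) == 2048   # small model is untouched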

@@ -1532,8 +1532,6 @@ class FlashCausalLM(Model):
             self.kv_cache_dtype,
             self.device,
         )
-        max_bt = batch.max_blocks
-        max_s = max_bt * BLOCK_SIZE
         batch_num_blocks = batch.num_blocks

         if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
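
This is the value the commit title calls out: the warmup sequence length was derived from the warmup batch's block allocation (max_bt * BLOCK_SIZE) rather than from the configured token budget, so the two could disagree. A rough illustration of the mismatch, with an assumed BLOCK_SIZE of 16 and made-up numbers:

BLOCK_SIZE = 16                       # assumed block size for this example only
max_total_tokens = 4096               # what the server is configured to serve
max_bt = 300                          # blocks allocated for the warmup batch (made up)

max_s_old = max_bt * BLOCK_SIZE       # 4800: block-derived length, not the real limit
print(max_s_old != max_total_tokens)  # True: warmup bound and serving bound diverge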

@@ -1651,7 +1649,7 @@ class FlashCausalLM(Model):
                 # Warmup cuda graphs
                 for bs in CUDA_GRAPHS:
                     if self.speculate is None or self.speculate + 1 <= bs:
-                        self.cuda_graph_warmup(bs, max_s, max_bt)
+                        self.cuda_graph_warmup(bs, max_total_tokens, max_total_tokens)
             except torch.cuda.OutOfMemoryError:
                 logger.exception("Decode cuda graph warmup failed")
         else:
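
With the block-derived values gone, the CUDA-graph warmup is bounded by the configured max_total_tokens for both the sequence-length and block arguments. A small sketch of the loop's shape under assumed values (the CUDA_GRAPHS sizes and token budget below are made up):

CUDA_GRAPHS = [1, 2, 4, 8, 16, 32]    # assumed example capture sizes
speculate = None                      # no speculative tokens in this example
max_total_tokens = 4096               # configured serving budget

for bs in CUDA_GRAPHS:
    if speculate is None or speculate + 1 <= bs:
        # After this commit both bounds passed to cuda_graph_warmup come from
        # max_total_tokens instead of the warmup batch's block allocation.
        warmup_args = (bs, max_total_tokens, max_total_tokens)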

@@ -55,7 +55,7 @@ def block_tables_to_ragged(
     cache_lengths: List[int],
     input_lengths_tensor: torch.Tensor,
     cache_lengths_tensor: torch.Tensor,
-    max_current_length: int
+    max_current_length: int,
 ) -> torch.Tensor:
     """Convert block table to ragged format compatible with FlashInfer."""
     assert len(input_lengths) == len(cache_lengths)