Hotfixing auto length (warmup max_s was wrong). (#2716)
This commit is contained in:
parent
08c4184eb2
commit
a5593ba83e
|
@ -1687,13 +1687,6 @@ fn main() -> Result<(), LauncherError> {
|
|||
let max_position_embeddings = if let Some(config) = &config {
|
||||
if let Some(max_position_embeddings) = config.max_position_embeddings {
|
||||
if max_position_embeddings > max_default {
|
||||
let max = max_position_embeddings;
|
||||
if args.max_input_tokens.is_none()
|
||||
&& args.max_total_tokens.is_none()
|
||||
&& args.max_batch_prefill_tokens.is_none()
|
||||
{
|
||||
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
|
||||
}
|
||||
max_default
|
||||
} else {
|
||||
max_position_embeddings
|
||||
|
|
|
@ -1532,8 +1532,6 @@ class FlashCausalLM(Model):
|
|||
self.kv_cache_dtype,
|
||||
self.device,
|
||||
)
|
||||
max_bt = batch.max_blocks
|
||||
max_s = max_bt * BLOCK_SIZE
|
||||
batch_num_blocks = batch.num_blocks
|
||||
|
||||
if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
|
||||
|
@ -1651,7 +1649,7 @@ class FlashCausalLM(Model):
|
|||
# Warmup cuda graphs
|
||||
for bs in CUDA_GRAPHS:
|
||||
if self.speculate is None or self.speculate + 1 <= bs:
|
||||
self.cuda_graph_warmup(bs, max_s, max_bt)
|
||||
self.cuda_graph_warmup(bs, max_total_tokens, max_total_tokens)
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
logger.exception("Decode cuda graph warmup failed")
|
||||
else:
|
||||
|
|
|
@ -55,7 +55,7 @@ def block_tables_to_ragged(
|
|||
cache_lengths: List[int],
|
||||
input_lengths_tensor: torch.Tensor,
|
||||
cache_lengths_tensor: torch.Tensor,
|
||||
max_current_length: int
|
||||
max_current_length: int,
|
||||
) -> torch.Tensor:
|
||||
"""Convert block table to ragged format compatible with FlashInfer."""
|
||||
assert len(input_lengths) == len(cache_lengths)
|
||||
|
|
Loading…
Reference in New Issue