diff --git a/launcher/src/main.rs b/launcher/src/main.rs index ca6aa8dd..23944f40 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1284,7 +1284,7 @@ fn main() -> Result<(), LauncherError> { tracing::info!("{}", env_runtime); } - tracing::info!("{:?}", args); + tracing::info!("{:#?}", args); let get_max_position_embeddings = || -> Result> { let model_id = args.model_id.clone(); @@ -1317,7 +1317,12 @@ fn main() -> Result<(), LauncherError> { (Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => { if max_position_embeddings > max_default { let max = max_position_embeddings; - tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1); + if args.max_input_tokens.is_none() + && args.max_total_tokens.is_none() + && args.max_batch_prefill_tokens.is_none() + { + tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1); + } max_default } else { max_position_embeddings @@ -1389,8 +1394,7 @@ fn main() -> Result<(), LauncherError> { } let cuda_graphs = match (&args.cuda_graphs, &args.quantize) { - (Some(cuda_graphs), Some(_q)) => cuda_graphs.clone(), - (Some(cuda_graphs), None) => cuda_graphs.clone(), + (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(), #[allow(deprecated)] ( None, diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index b92aa65b..6f8d1017 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -4,7 +4,7 @@ import os MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli cuda_graphs = os.getenv("CUDA_GRAPHS") -if torch.cuda.is_available() and cuda_graphs is not None and cuda_graphs != "0": +if cuda_graphs is not None: try: cuda_graphs = [int(item) for item in cuda_graphs.split(",")] except Exception as e: