From 04d4765bad5707458955189fbf39e8b485de5cbd Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 30 Apr 2024 11:39:38 +0200
Subject: [PATCH] Small CI cleanup. (#1801)

# What does this PR do?

Just unifying some branches and making intentions clearer (no cuda graph
when 0 all the way in the launcher)

Fixes # (issue)

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other
      checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
      [forum](https://discuss.huggingface.co/)? Please add a link to it if
      that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are
      the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs),
      and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed.
Feel free to tag members/contributors who may be interested in your PR.
---
 launcher/src/main.rs                             | 12 ++++++++----
 server/text_generation_server/models/globals.py  |  2 +-
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index ca6aa8dd..23944f40 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -1284,7 +1284,7 @@ fn main() -> Result<(), LauncherError> {
         tracing::info!("{}", env_runtime);
     }
 
-    tracing::info!("{:?}", args);
+    tracing::info!("{:#?}", args);
 
     let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> {
         let model_id = args.model_id.clone();
@@ -1317,7 +1317,12 @@ fn main() -> Result<(), LauncherError> {
             (Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => {
                 if max_position_embeddings > max_default {
                     let max = max_position_embeddings;
-                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
+                    if args.max_input_tokens.is_none()
+                        && args.max_total_tokens.is_none()
+                        && args.max_batch_prefill_tokens.is_none()
+                    {
+                        tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
+                    }
                     max_default
                 } else {
                     max_position_embeddings
@@ -1389,8 +1394,7 @@ fn main() -> Result<(), LauncherError> {
     }
 
     let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
-        (Some(cuda_graphs), Some(_q)) => cuda_graphs.clone(),
-        (Some(cuda_graphs), None) => cuda_graphs.clone(),
+        (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
         #[allow(deprecated)]
         (
             None,
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index b92aa65b..6f8d1017 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -4,7 +4,7 @@ import os
 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 # This is overridden by the cli
 cuda_graphs = os.getenv("CUDA_GRAPHS")
-if torch.cuda.is_available() and cuda_graphs is not None and cuda_graphs != "0":
+if cuda_graphs is not None:
    try:
        cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
    except Exception as e:
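For readers skimming the diff, here is a minimal, standalone Rust sketch of the behavior the launcher change aims for: a single match arm filters any `0` entries out of the requested cuda graph sizes, so passing `--cuda-graphs 0` ends up disabling cuda graphs whether or not quantization is enabled. The function name `resolve_cuda_graphs` and the fallback list are made up for illustration; only the filtering arm mirrors the patch.

```rust
// Sketch only: `cli_cuda_graphs` stands in for `args.cuda_graphs`,
// `quantize` for `args.quantize`. Types are assumed, not taken from TGI.
fn resolve_cuda_graphs(cli_cuda_graphs: &Option<Vec<usize>>, quantize: &Option<String>) -> Vec<usize> {
    match (cli_cuda_graphs, quantize) {
        // One arm for both the quantized and non-quantized cases:
        // keep only positive sizes, so `--cuda-graphs 0` yields an empty list.
        (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
        // Placeholder fallback; the launcher's real default arm is not shown in this diff.
        (None, _) => vec![1, 2, 4],
    }
}

fn main() {
    // `0` all the way in the launcher means "no cuda graphs", regardless of quantization.
    assert_eq!(resolve_cuda_graphs(&Some(vec![0]), &None), Vec::<usize>::new());
    assert_eq!(resolve_cuda_graphs(&Some(vec![1, 2, 0]), &Some("awq".to_string())), vec![1, 2]);
    println!("cuda graph sizes: {:?}", resolve_cuda_graphs(&None, &None));
}
```

Collapsing the quantized and non-quantized arms into one keeps the "zero means off" convention in a single place in the launcher, instead of re-checking the `"0"` string later in the Python server, which is what the `globals.py` hunk removes.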