From 26b3916612bb85067b8328d988138f67453a89e1 Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Mon, 22 Apr 2024 16:09:19 +0200
Subject: [PATCH] Make `--cuda-graphs` work as expected (bis) (#1768)

The `--cuda-graphs` argument was ignored until now, even when passing
`--cuda-graphs 0` explicitly. With this fix, `--cuda-graphs` is respected.
---
 launcher/src/main.rs                                     | 1 +
 server/text_generation_server/models/flash_causal_lm.py | 2 ++
 server/text_generation_server/models/globals.py         | 5 ++++-
 server/text_generation_server/models/mamba.py           | 2 ++
 4 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index d904f91b..40e7364f 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -1379,6 +1379,7 @@ fn main() -> Result<(), LauncherError> {
 
     let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
         (Some(cuda_graphs), Some(_q)) => cuda_graphs.clone(),
+        (Some(cuda_graphs), None) => cuda_graphs.clone(),
         #[allow(deprecated)]
         (
             None,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 2c440083..1189ccdd 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -816,6 +816,8 @@ class FlashCausalLM(Model):
                         self.cuda_graph_warmup(bs, max_s, max_bt)
             except torch.cuda.OutOfMemoryError:
                 logger.exception(f"Decode cuda graph warmup failed")
+        else:
+            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
 
         return int(num_blocks * BLOCK_SIZE)
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index 6f554049..91b4225a 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -4,11 +4,14 @@ import os
 MEM_POOL = torch.cuda.graph_pool_handle()
 # This is overridden by the cli
 cuda_graphs = os.getenv("CUDA_GRAPHS")
-if cuda_graphs is not None:
+if cuda_graphs is not None and cuda_graphs != "0":
     try:
         cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
     except Exception as e:
         raise RuntimeError(
             f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}"
         )
+else:
+    cuda_graphs = None
+
 CUDA_GRAPHS = cuda_graphs
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index 07a81491..0884317e 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -474,6 +474,8 @@ class Mamba(Model):
                     self.cuda_graph_warmup(bs)
                 except Exception:
                     logger.exception(f"Decode cuda graph warmup failed")
+        else:
+            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
 
         return None
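
For reference, the server-side semantics this patch establishes in
server/text_generation_server/models/globals.py boil down to: an unset
CUDA_GRAPHS environment variable and the literal value "0" both leave
CUDA_GRAPHS = None, while a comma-separated list is parsed into batch sizes.
A minimal standalone sketch of that logic (the parse_cuda_graphs helper is
hypothetical, introduced here only to make the branches easy to exercise):

def parse_cuda_graphs(raw):
    # Hypothetical helper mirroring the post-patch logic in globals.py:
    # None (unset) and "0" both mean "CUDA graphs disabled".
    if raw is not None and raw != "0":
        try:
            return [int(item) for item in raw.split(",")]
        except Exception as e:
            raise RuntimeError(
                f"Could not parse cuda graphs {raw}, expected comma separated "
                f"list for batch sizes to run on: {e}"
            )
    return None

assert parse_cuda_graphs(None) is None          # env var unset
assert parse_cuda_graphs("0") is None           # explicitly disabled: the point of this fix
assert parse_cuda_graphs("1,2,4") == [1, 2, 4]  # warm up graphs for these batch sizes

Per the "overridden by the cli" comment in globals.py, the launcher presumably
forwards --cuda-graphs to the server through this CUDA_GRAPHS environment
variable, which is how --cuda-graphs 0 now actually disables graph capture.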
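
The launcher half of the fix is the new
(Some(cuda_graphs), None) => cuda_graphs.clone() match arm: before it, a
user-supplied --cuda-graphs was only honored when --quantize was also set,
and otherwise fell through to the defaulting arms. On the server side, the
warmup paths touched in flash_causal_lm.py and mamba.py then reduce to the
control flow sketched below (a simplification with a stand-in for
self.cuda_graph_warmup, not the actual method bodies):

from loguru import logger

def warmup_cuda_graphs(cuda_graphs, cuda_graph_warmup):
    # Simplified shape of the branch this patch completes in
    # FlashCausalLM.warmup and Mamba.warmup: capture a graph per requested
    # batch size, and log explicitly when capture is disabled instead of
    # skipping it silently.
    if cuda_graphs:
        for bs in cuda_graphs:
            cuda_graph_warmup(bs)  # stand-in for self.cuda_graph_warmup(...)
    else:
        logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={cuda_graphs}).")

# CUDA_GRAPHS=None (e.g. from --cuda-graphs 0) hits the new "disabled" log:
warmup_cuda_graphs(None, lambda bs: None)
# CUDA_GRAPHS=[1, 2, 4] captures one graph per batch size:
warmup_cuda_graphs([1, 2, 4], lambda bs: logger.info(f"captured graph for bs={bs}"))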