Make `--cuda-graphs` work as expected (bis) (#1768)

Until now, the `--cuda-graphs` flag was ignored, even when passing `--cuda-graphs 0`.

With this fix, `--cuda-graphs` is honored.
This commit is contained in:
fxmarty 2024-04-22 16:09:19 +02:00 committed by GitHub
parent 2d0a7173d4
commit 26b3916612
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 9 additions and 1 deletions

View File

@ -1379,6 +1379,7 @@ fn main() -> Result<(), LauncherError> {
let cuda_graphs = match (&args.cuda_graphs, &args.quantize) { let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
(Some(cuda_graphs), Some(_q)) => cuda_graphs.clone(), (Some(cuda_graphs), Some(_q)) => cuda_graphs.clone(),
(Some(cuda_graphs), None) => cuda_graphs.clone(),
#[allow(deprecated)] #[allow(deprecated)]
( (
None, None,

View File

@ -816,6 +816,8 @@ class FlashCausalLM(Model):
self.cuda_graph_warmup(bs, max_s, max_bt) self.cuda_graph_warmup(bs, max_s, max_bt)
except torch.cuda.OutOfMemoryError: except torch.cuda.OutOfMemoryError:
logger.exception(f"Decode cuda graph warmup failed") logger.exception(f"Decode cuda graph warmup failed")
else:
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
return int(num_blocks * BLOCK_SIZE) return int(num_blocks * BLOCK_SIZE)

View File

@ -4,11 +4,14 @@ import os
MEM_POOL = torch.cuda.graph_pool_handle() MEM_POOL = torch.cuda.graph_pool_handle()
# This is overridden by the cli # This is overridden by the cli
cuda_graphs = os.getenv("CUDA_GRAPHS") cuda_graphs = os.getenv("CUDA_GRAPHS")
if cuda_graphs is not None: if cuda_graphs is not None and cuda_graphs != "0":
try: try:
cuda_graphs = [int(item) for item in cuda_graphs.split(",")] cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
except Exception as e: except Exception as e:
raise RuntimeError( raise RuntimeError(
f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}" f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}"
) )
else:
cuda_graphs = None
CUDA_GRAPHS = cuda_graphs CUDA_GRAPHS = cuda_graphs

View File

@ -474,6 +474,8 @@ class Mamba(Model):
self.cuda_graph_warmup(bs) self.cuda_graph_warmup(bs)
except Exception: except Exception:
logger.exception(f"Decode cuda graph warmup failed") logger.exception(f"Decode cuda graph warmup failed")
else:
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
return None return None