Make `--cuda-graphs` work as expected (bis) (#1768)
This was ignored up to now, even with `--cuda-graphs 0`. With this fix, `--cuda-graphs` is obeyed to.
This commit is contained in:
parent
2d0a7173d4
commit
26b3916612
|
@ -1379,6 +1379,7 @@ fn main() -> Result<(), LauncherError> {
|
|||
|
||||
let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
|
||||
(Some(cuda_graphs), Some(_q)) => cuda_graphs.clone(),
|
||||
(Some(cuda_graphs), None) => cuda_graphs.clone(),
|
||||
#[allow(deprecated)]
|
||||
(
|
||||
None,
|
||||
|
|
|
@ -816,6 +816,8 @@ class FlashCausalLM(Model):
|
|||
self.cuda_graph_warmup(bs, max_s, max_bt)
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
logger.exception(f"Decode cuda graph warmup failed")
|
||||
else:
|
||||
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
|
||||
|
||||
return int(num_blocks * BLOCK_SIZE)
|
||||
|
||||
|
|
|
@ -4,11 +4,14 @@ import os
|
|||
MEM_POOL = torch.cuda.graph_pool_handle()
|
||||
# This is overridden by the cli
|
||||
cuda_graphs = os.getenv("CUDA_GRAPHS")
|
||||
if cuda_graphs is not None:
|
||||
if cuda_graphs is not None and cuda_graphs != "0":
|
||||
try:
|
||||
cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}"
|
||||
)
|
||||
else:
|
||||
cuda_graphs = None
|
||||
|
||||
CUDA_GRAPHS = cuda_graphs
|
||||
|
|
|
@ -474,6 +474,8 @@ class Mamba(Model):
|
|||
self.cuda_graph_warmup(bs)
|
||||
except Exception:
|
||||
logger.exception(f"Decode cuda graph warmup failed")
|
||||
else:
|
||||
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
|
||||
|
||||
return None
|
||||
|
||||
|
|
Loading…
Reference in New Issue