Make `--cuda-graphs` work as expected (bis) (#1768)
This was ignored until now, even with `--cuda-graphs 0`. With this fix, `--cuda-graphs` is obeyed: an explicit value is forwarded to the server, and `--cuda-graphs 0` (equivalently `CUDA_GRAPHS=0`) now disables CUDA graphs instead of being parsed as a batch size.
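For reference, a minimal sketch of how a `CUDA_GRAPHS` value is interpreted after this change. The helper name `resolve_cuda_graphs` is illustrative, not part of the codebase; the actual logic is the globals-module hunk further down.

```python
import os
from typing import List, Optional


def resolve_cuda_graphs(value: Optional[str]) -> Optional[List[int]]:
    # Illustrative mirror of the updated guard: "0" is an explicit opt-out.
    if value is not None and value != "0":
        try:
            return [int(item) for item in value.split(",")]
        except Exception as e:
            raise RuntimeError(
                f"Could not parse cuda graphs {value}, expected comma separated "
                f"list for batch sizes to run on: {e}"
            )
    return None


# "0"      -> None         (CUDA graphs disabled, and the server logs it)
# "2,4,8"  -> [2, 4, 8]    (capture graphs for these batch sizes)
# unset    -> None         (nothing forwarded from the environment)
print(resolve_cuda_graphs(os.getenv("CUDA_GRAPHS")))
```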
This commit is contained in:
parent 2d0a7173d4, commit 26b3916612
@@ -1379,6 +1379,7 @@ fn main() -> Result<(), LauncherError> {
     let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
         (Some(cuda_graphs), Some(_q)) => cuda_graphs.clone(),
+        (Some(cuda_graphs), None) => cuda_graphs.clone(),
         #[allow(deprecated)]
         (
             None,
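The new arm covers the case where `--cuda-graphs` is given but no quantization is configured: before this change only the `(Some(cuda_graphs), Some(_q))` arm forwarded a user-provided value, so without quantization an explicit `--cuda-graphs` setting never took effect.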
@@ -816,6 +816,8 @@ class FlashCausalLM(Model):
                     self.cuda_graph_warmup(bs, max_s, max_bt)
             except torch.cuda.OutOfMemoryError:
                 logger.exception(f"Decode cuda graph warmup failed")
+        else:
+            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
 
         return int(num_blocks * BLOCK_SIZE)
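The added `else:` lines up with an outer `if CUDA_GRAPHS:` guard around the warmup loop. A schematic, self-contained sketch of that control flow (structure and names assumed from the diff, not the actual `FlashCausalLM.warmup` body):

```python
import logging

logger = logging.getLogger("warmup-sketch")
logging.basicConfig(level=logging.INFO)

CUDA_GRAPHS = None  # e.g. [1, 2, 4, 8] when graph capture is enabled


def cuda_graph_warmup(bs: int) -> None:
    """Stand-in for the real per-batch-size graph capture."""
    pass


if CUDA_GRAPHS:
    try:
        for bs in CUDA_GRAPHS:
            cuda_graph_warmup(bs)
    except RuntimeError:
        logger.exception("Decode cuda graph warmup failed")
else:
    # New in this commit: make the disabled state visible in the logs.
    logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
```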
@@ -4,11 +4,14 @@ import os
 MEM_POOL = torch.cuda.graph_pool_handle()
 # This is overridden by the cli
 cuda_graphs = os.getenv("CUDA_GRAPHS")
-if cuda_graphs is not None:
+if cuda_graphs is not None and cuda_graphs != "0":
     try:
         cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
     except Exception as e:
         raise RuntimeError(
             f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}"
         )
+else:
+    cuda_graphs = None
 
 CUDA_GRAPHS = cuda_graphs
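The practical difference for an explicit `0`, derived from the old and new guards above (illustrative snippet, not server output):

```python
# Old guard: any non-None value was parsed, so "0" became [0], a truthy
# one-element list, and CUDA graphs were not actually turned off.
old_result = [int(item) for item in "0".split(",")]  # -> [0]

# New guard: "0" short-circuits the parse and falls into the else branch,
# so CUDA_GRAPHS ends up as None and the disabled path is taken.
value = "0"
new_result = (
    [int(item) for item in value.split(",")]
    if value is not None and value != "0"
    else None
)  # -> None

print(old_result, new_result)  # [0] None
```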
@@ -474,6 +474,8 @@ class Mamba(Model):
                     self.cuda_graph_warmup(bs)
             except Exception:
                 logger.exception(f"Decode cuda graph warmup failed")
+        else:
+            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
 
         return None