diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md
index 36fa1241..86394ff7 100644
--- a/docs/source/basic_tutorials/launcher.md
+++ b/docs/source/basic_tutorials/launcher.md
@@ -206,12 +206,13 @@ Options:
           [env: MAX_BATCH_SIZE=]
 
 ```
-## ENABLE_CUDA_GRAPHS
+## CUDA_GRAPHS
 ```shell
-      --enable-cuda-graphs
-          Enable experimental support for cuda graphs
+      --cuda-graphs
+          Specify the batch sizes to compute cuda graphs for. Use "0" to disable
 
-          [env: ENABLE_CUDA_GRAPHS=]
+          [env: CUDA_GRAPHS=]
+          [default: 1,2,4,8,16,32,64,96,128]
 
 ```
 ## HOSTNAME
diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index 32bf4e54..022b2298 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -383,7 +383,6 @@ def launcher(event_loop):
 
         env = {
             "LOG_LEVEL": "info,text_generation_router=debug",
-            "ENABLE_CUDA_GRAPHS": "true",
         }
         if not use_flash_attention:
             env["USE_FLASH_ATTENTION"] = "false"
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 990eade4..63676392 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -284,9 +284,15 @@ struct Args {
     #[clap(long, env)]
     max_batch_size: Option<usize>,
 
-    /// Enable experimental support for cuda graphs
-    #[clap(long, env)]
-    enable_cuda_graphs: bool,
+    /// Specify the batch sizes to compute cuda graphs for.
+    /// Use "0" to disable.
+    #[clap(
+        long,
+        env,
+        value_delimiter = ',',
+        default_value = "1,2,4,8,16,32,64,96,128"
+    )]
+    cuda_graphs: Vec<usize>,
 
     /// The IP address to listen on
     #[clap(default_value = "0.0.0.0", long, env)]
@@ -416,7 +422,7 @@ fn shard_manager(
     disable_custom_kernels: bool,
     watermark_gamma: Option<f32>,
     watermark_delta: Option<f32>,
-    enable_cuda_graphs: bool,
+    cuda_graphs: Vec<usize>,
     cuda_memory_fraction: f32,
     rope_scaling: Option<RopeScaling>,
     rope_factor: Option<f32>,
@@ -549,8 +555,16 @@
     };
 
     // Enable experimental support for cuda graphs
-    if enable_cuda_graphs {
-        envs.push(("ENABLE_CUDA_GRAPHS".into(), "True".into()))
+    if !cuda_graphs.is_empty() {
+        envs.push((
+            "CUDA_GRAPHS".into(),
+            cuda_graphs
+                .into_iter()
+                .map(|c| c.to_string())
+                .collect::<Vec<String>>()
+                .join(",")
+                .into(),
+        ));
     }
 
     // If disable_custom_kernels is true, pass it to the shard as an env var
@@ -941,7 +955,11 @@ fn spawn_shards(
         let disable_custom_kernels = args.disable_custom_kernels;
         let watermark_gamma = args.watermark_gamma;
         let watermark_delta = args.watermark_delta;
-        let enable_cuda_graphs = args.enable_cuda_graphs;
+        let cuda_graphs: Vec<usize> = args
+            .cuda_graphs
+            .iter()
+            .filter_map(|&c| if c > 0 { Some(c) } else { None })
+            .collect();
         let cuda_memory_fraction = args.cuda_memory_fraction;
         let rope_scaling = args.rope_scaling;
         let rope_factor = args.rope_factor;
@@ -963,7 +981,7 @@
                 disable_custom_kernels,
                 watermark_gamma,
                 watermark_delta,
-                enable_cuda_graphs,
+                cuda_graphs,
                 cuda_memory_fraction,
                 rope_scaling,
                 rope_factor,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 20f93820..5c25f341 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -28,7 +28,7 @@ from text_generation_server.models.cache_manager import (
     BLOCK_SIZE,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.globals import MEM_POOL, ENABLE_CUDA_GRAPHS
+from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
 
@@ -798,11 +798,11 @@ class FlashCausalLM(Model):
             self.device,
         )
 
-        if ENABLE_CUDA_GRAPHS:
+        if CUDA_GRAPHS:
             try:
-                logger.info("Experimental support for Cuda Graphs is enabled")
+                logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
                 # Warmup cuda graphs
-                for bs in [1, 2, 4] + [8 * i for i in range(1, 9)]:
+                for bs in CUDA_GRAPHS:
                     if self.speculate is None or self.speculate + 1 <= bs:
                         self.cuda_graph_warmup(bs, max_s, max_bt)
             except Exception:
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index 3b8a70bc..6f554049 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -3,4 +3,12 @@ import os
 
 MEM_POOL = torch.cuda.graph_pool_handle()
 # This is overridden by the cli
-ENABLE_CUDA_GRAPHS = os.getenv("ENABLE_CUDA_GRAPHS", "false").lower() in {"1", "true"}
+cuda_graphs = os.getenv("CUDA_GRAPHS")
+if cuda_graphs is not None:
+    try:
+        cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
+    except Exception as e:
+        raise RuntimeError(
+            f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}"
+        )
+CUDA_GRAPHS = cuda_graphs
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index 2500d454..07a81491 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -13,7 +13,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.models.globals import ENABLE_CUDA_GRAPHS, MEM_POOL
+from text_generation_server.models.globals import CUDA_GRAPHS, MEM_POOL
 import time
 from text_generation_server.models.custom_modeling.mamba_modeling import (
     MambaModel,
@@ -465,12 +465,12 @@ class Mamba(Model):
 
     def warmup(self, batch) -> Optional[int]:
         # TODO: implement warmup for Mamba if needed
-        if ENABLE_CUDA_GRAPHS:
+        if CUDA_GRAPHS:
             if self.speculate is None or self.speculate == 0:
                 try:
-                    logger.info("Experimental support for Cuda Graphs is enabled")
+                    logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
                     # Warmup cuda graphs
-                    for bs in [1, 2, 4] + [8 * i for i in range(1, 9)]:
+                    for bs in CUDA_GRAPHS:
                         self.cuda_graph_warmup(bs)
                 except Exception:
                     logger.exception(f"Decode cuda graph warmup failed")
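As an illustration of how the pieces of this diff fit together (not part of the change itself), here is a minimal Python sketch; `resolve_cuda_graphs` is a hypothetical helper, not a function in the repository. It mirrors the launcher's `filter_map` over `--cuda-graphs` and the parsing added to `globals.py`: non-positive sizes are dropped, and if nothing remains the `CUDA_GRAPHS` environment variable is never exported, so the server-side warmup loop is skipped.

```python
from typing import List, Optional


def resolve_cuda_graphs(flag_values: List[int]) -> Optional[List[int]]:
    """Hypothetical sketch of the combined launcher + server behaviour above."""
    # Launcher side: keep only positive batch sizes (the filter_map in spawn_shards).
    kept = [c for c in flag_values if c > 0]
    if not kept:
        # "--cuda-graphs 0" leaves nothing, so CUDA_GRAPHS is never exported
        # and the cuda graph warmup loop never runs.
        return None
    # shard_manager exports the sizes as a comma-separated CUDA_GRAPHS value ...
    env_value = ",".join(str(c) for c in kept)
    # ... which globals.py parses back into a list of batch sizes.
    return [int(item) for item in env_value.split(",")]


assert resolve_cuda_graphs([1, 2, 4, 8]) == [1, 2, 4, 8]
assert resolve_cuda_graphs([0]) is None  # cuda graphs disabled
```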