diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index e04a9719..886fe486 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -28,7 +28,7 @@ from text_generation_server.models.cache_manager import (
     BLOCK_SIZE,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.globals import MEM_POOL
+from text_generation_server.models.globals import MEM_POOL, ENABLE_CUDA_GRAPHS
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
@@ -793,7 +793,7 @@ class FlashCausalLM(Model):
             self.device,
         )
 
-        if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
+        if ENABLE_CUDA_GRAPHS:
             try:
                 logger.info("Experimental support for Cuda Graphs is enabled")
                 # Warmup cuda graphs
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index b0dca376..3b8a70bc 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -1,3 +1,6 @@
 import torch
+import os
 
 MEM_POOL = torch.cuda.graph_pool_handle()
+# This is overridden by the cli
+ENABLE_CUDA_GRAPHS = os.getenv("ENABLE_CUDA_GRAPHS", "false").lower() in {"1", "true"}
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index 8f18e475..868db6aa 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -13,7 +13,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.models.globals import MEM_POOL
+from text_generation_server.models.globals import ENABLE_CUDA_GRAPHS, MEM_POOL
 import time
 from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel, InferenceParams
 from text_generation_server.models import Model
@@ -377,7 +377,9 @@ class Mamba(Model):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.process_group, _rank, _world_size = initialize_torch_distributed()
+        self.process_group, _rank, world_size = initialize_torch_distributed()
+        if world_size > 1:
+            raise RuntimeError("Mamba does not support Tensor Parallelism (TP)")
         self.cuda_graphs = {}
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -427,7 +429,7 @@ class Mamba(Model):
 
     def warmup(self, batch) -> Optional[int]:
         # TODO: implement warmup for Mamba if needed
-        if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
+        if ENABLE_CUDA_GRAPHS:
             if self.speculate is None or self.speculate == 0:
                 try:
                     logger.info("Experimental support for Cuda Graphs is enabled")
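
Note, outside the patch itself: the sketch below restates the flag-parsing pattern that `globals.py` now centralizes, and adds a hypothetical caller (`maybe_warmup_cuda_graphs`, not a function from the repository) standing in for the guarded warmup paths in `flash_causal_lm.py` and `mamba.py`.

```python
# Standalone sketch of the pattern introduced in globals.py: parse the
# environment variable once at import time into a module-level boolean,
# so call sites test a constant instead of re-reading os.environ.
import os

# Accepts "1" or any casing of "true"; anything else, or an unset variable,
# leaves CUDA graph warmup disabled. This is more permissive than the old
# per-call-site check, which required the exact string "True".
ENABLE_CUDA_GRAPHS = os.getenv("ENABLE_CUDA_GRAPHS", "false").lower() in {"1", "true"}


def maybe_warmup_cuda_graphs() -> None:
    # Hypothetical caller mirroring the guards in flash_causal_lm.py and
    # mamba.py; the real warmup logic is elided.
    if not ENABLE_CUDA_GRAPHS:
        return
    print("Experimental support for Cuda Graphs is enabled")
```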