diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index dd4203e0..2cdc49a0 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1231,6 +1231,13 @@ class FlashCausalLM(Model):
         torch.cuda.synchronize()
 
     def warmup(self, batch: FlashCausalLMBatch):
+        initial_free_memory = get_free_memory(self.device, MEMORY_FRACTION)
+
+        log_master(
+            logger.info,
+            f"Free memory before the warmup: {initial_free_memory/1024/1024:.2f} MB",
+        )
+
         # The warmup batch is the biggest batch we could ever receive
         empty_cache()
 
@@ -1284,6 +1291,15 @@ class FlashCausalLM(Model):
             self.device,
         )
 
+        # cuda graphs must fit within the new memory limit. In order to avoid an OOM, we
+        # need to exit early if there is not enough memory to fit a particular cuda graph
+        free_memory_post_alloc = get_free_memory(self.device, MEMORY_FRACTION)
+
+        log_master(
+            logger.info,
+            f"Free memory after allocating the cache: {free_memory_post_alloc/1024/1024:.2f} MB",
+        )
+
         if SYSTEM == "rocm":
             if (
                 os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
@@ -1341,9 +1357,37 @@ class FlashCausalLM(Model):
                     logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}"
                 )
                 # Warmup cuda graphs
+                last_allocation_amount = 0
+                last_available_memory = free_memory_post_alloc
+                last_bs = 0
                 for bs in CUDA_GRAPHS:
                     if self.speculate is None or self.speculate + 1 <= bs:
+                        expected_memory = int(
+                            last_allocation_amount * (bs / last_bs if last_bs else 2)
+                        )
+                        if expected_memory > last_available_memory:
+                            skipped_graphs = [str(k) for k in CUDA_GRAPHS if k >= bs]
+                            log_master(
+                                logger.warning,
+                                f"Avoiding CUDA graph warmup for sizes {', '.join(skipped_graphs)} due to insufficient memory.",
+                            )
+                            break
+
                         self.cuda_graph_warmup(bs, max_s, max_bt)
+                        current_available_memory = get_free_memory(
+                            self.device, MEMORY_FRACTION
+                        )
+                        last_allocation_amount = (
+                            last_available_memory - current_available_memory
+                        )
+                        last_available_memory = current_available_memory
+                        last_bs = bs
+                # report the total memory used
+                total_cuda_graph_memory = free_memory_post_alloc - last_available_memory
+                log_master(
+                    logger.info,
+                    f"Total memory used for CUDA graphs: {total_cuda_graph_memory/1024/1024:.2f} MB",
+                )
             except torch.cuda.OutOfMemoryError:
                 logger.exception("Decode cuda graph warmup failed")
             else:
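
For readers who want to see the skip heuristic in isolation, the following is a minimal, self-contained sketch of the logic the patch adds to the warmup loop. Everything in it is hypothetical: plan_graph_warmups, the simulated graph_cost table, and the memory figures are stand-ins for cuda_graph_warmup and get_free_memory, not APIs from the TGI codebase.

    # Standalone sketch of the CUDA graph warmup skip heuristic added above.
    # Hypothetical example only: the planner, the cost table, and all numbers
    # are made up and are not part of text-generation-inference.
    CUDA_GRAPHS = [1, 2, 4, 8, 16, 32]

    def plan_graph_warmups(free_memory: int, graph_cost: dict[int, int]) -> list[int]:
        """Return the batch sizes whose graphs fit, estimating each next graph's
        footprint by scaling the previous allocation by the batch-size ratio."""
        warmed = []
        last_allocation_amount = 0
        last_available_memory = free_memory
        last_bs = 0

        for bs in CUDA_GRAPHS:
            # Same guess as the patch: the first size is never blocked (no history),
            # later sizes are expected to cost last_allocation * (bs / last_bs).
            expected_memory = int(
                last_allocation_amount * (bs / last_bs if last_bs else 2)
            )
            if expected_memory > last_available_memory:
                skipped = [k for k in CUDA_GRAPHS if k >= bs]
                print(
                    f"skipping sizes {skipped}: "
                    f"need ~{expected_memory}, have {last_available_memory}"
                )
                break

            # "Capture" the graph by charging its simulated cost, then record how
            # much memory that size actually consumed for the next estimate.
            current_available_memory = last_available_memory - graph_cost[bs]
            last_allocation_amount = last_available_memory - current_available_memory
            last_available_memory = current_available_memory
            last_bs = bs
            warmed.append(bs)

        return warmed

    if __name__ == "__main__":
        # Per-graph costs (bytes) roughly proportional to batch size, 500 MB free.
        costs = {bs: bs * 40_000_000 for bs in CUDA_GRAPHS}
        print(plan_graph_warmups(free_memory=500_000_000, graph_cost=costs))

With these made-up numbers the planner warms sizes 1, 2 and 4, then predicts the size-8 graph needs roughly 320 MB against 220 MB still free and skips 8, 16 and 32. That is the early-exit behaviour the patch aims for: stop before attempting a capture that would fail, rather than relying on the surrounding except torch.cuda.OutOfMemoryError.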