diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 86d9b4c8..2795c7bd 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -808,12 +808,17 @@ class FlashCausalLM(Model):
 
         free_memory = get_free_memory(self.device, MEMORY_FRACTION)
 
-        num_blocks = (
-            # Leave 5% for some wiggle room
-            int((free_memory * 0.95) // total_cache_size)
-            # Add batch.blocks as we allocated it above, so it is included in the peak memory.
-            + cache_manager.num_blocks
-        )
+        if os.environ.get("NUM_BLOCKS") is None:
+            num_blocks = (
+                # Leave 5% for some wiggle room
+                int((free_memory * 0.95) // total_cache_size)
+                # Add batch.blocks as we allocated it above, so it is included in the peak memory.
+                + cache_manager.num_blocks
+            )
+        else:
+            num_blocks = int(os.environ["NUM_BLOCKS"])
+
+        logger.debug(f"Paged attention num_blocks: {num_blocks}")
 
         del batch
         del cache_manager
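
For reference, a minimal standalone sketch of the override pattern this patch introduces: an explicit NUM_BLOCKS environment variable takes precedence, otherwise the block count is derived from free memory. The helper name `compute_num_blocks` and the example arguments are hypothetical, not part of the PR.

import os


def compute_num_blocks(free_memory: int, total_cache_size: int, cache_blocks: int) -> int:
    override = os.environ.get("NUM_BLOCKS")
    if override is not None:
        # Explicit override wins, e.g. launched with NUM_BLOCKS=2048.
        return int(override)
    # Otherwise mirror the patched logic: leave 5% of free memory as
    # wiggle room, then add the blocks already allocated for the warmup
    # batch so they count toward the peak.
    return int((free_memory * 0.95) // total_cache_size) + cache_blocks


if __name__ == "__main__":
    # Assumed example values: 20 GiB free, 2 MiB per KV-cache block,
    # 64 blocks pre-allocated during warmup.
    print(compute_num_blocks(20 * 1024**3, 2 * 1024**2, 64))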