From bb37321b9fcff91f18c4678809821a5e3ba66b3b Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Wed, 5 Jun 2024 10:05:04 +0000
Subject: [PATCH] allow to fix paged attention num blocks

---
 .../models/flash_causal_lm.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 86d9b4c8..2795c7bd 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -808,12 +808,17 @@ class FlashCausalLM(Model):

         free_memory = get_free_memory(self.device, MEMORY_FRACTION)

-        num_blocks = (
-            # Leave 5% for some wiggle room
-            int((free_memory * 0.95) // total_cache_size)
-            # Add batch.blocks as we allocated it above, so it is included in the peak memory.
-            + cache_manager.num_blocks
-        )
+        if os.environ.get("NUM_BLOCKS") is None:
+            num_blocks = (
+                # Leave 5% for some wiggle room
+                int((free_memory * 0.95) // total_cache_size)
+                # Add batch.blocks as we allocated it above, so it is included in the peak memory.
+                + cache_manager.num_blocks
+            )
+        else:
+            num_blocks = int(os.environ["NUM_BLOCKS"])
+
+        logger.debug(f"Paged attention num_blocks: {num_blocks}")

         del batch
         del cache_manager
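
For context, a minimal sketch of the behavior this patch introduces, under the assumption that the surrounding warmup logic is unchanged. The function name `resolve_num_blocks` and its parameters are hypothetical stand-ins for the values already computed in `flash_causal_lm.py` (`free_memory`, `total_cache_size`, `cache_manager.num_blocks`); only the override pattern itself mirrors the diff:

```python
import os


def resolve_num_blocks(free_memory: int, total_cache_size: int, allocated_blocks: int) -> int:
    """Hypothetical helper mirroring the patched logic: when NUM_BLOCKS is set,
    it overrides the memory-based estimate of paged attention KV cache blocks."""
    override = os.environ.get("NUM_BLOCKS")
    if override is not None:
        return int(override)
    # Leave 5% headroom on free memory, then add the blocks already allocated
    # for the warmup batch so they count toward the peak.
    return int((free_memory * 0.95) // total_cache_size) + allocated_blocks
```

With this in place, exporting e.g. `NUM_BLOCKS=2048` in the server's environment before warmup would pin the cache to a fixed block count instead of deriving it from measured free memory, which is presumably useful for reproducible sizing across runs or devices.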