allow to fix paged attention num blocks

This commit is contained in:
fxmarty 2024-06-05 10:05:04 +00:00
parent 9a59ebcec3
commit bb37321b9f
1 changed files with 11 additions and 6 deletions

View File

@ -808,12 +808,17 @@ class FlashCausalLM(Model):
free_memory = get_free_memory(self.device, MEMORY_FRACTION)
num_blocks = (
# Leave 5% for some wiggle room
int((free_memory * 0.95) // total_cache_size)
# Add batch.blocks as we allocated it above, so it is included in the peak memory.
+ cache_manager.num_blocks
)
if os.environ.get("NUM_BLOCKS") is None:
num_blocks = (
# Leave 5% for some wiggle room
int((free_memory * 0.95) // total_cache_size)
# Add batch.blocks as we allocated it above, so it is included in the peak memory.
+ cache_manager.num_blocks
)
else:
num_blocks = int(os.environ["NUM_BLOCKS"])
logger.debug(f"Paged attention num_blocks: {num_blocks}")
del batch
del cache_manager