Allow fixing the paged attention num_blocks
parent 9a59ebcec3
commit bb37321b9f
@@ -808,12 +808,17 @@ class FlashCausalLM(Model):
         free_memory = get_free_memory(self.device, MEMORY_FRACTION)
 
-        num_blocks = (
-            # Leave 5% for some wiggle room
-            int((free_memory * 0.95) // total_cache_size)
-            # Add batch.blocks as we allocated it above, so it is included in the peak memory.
-            + cache_manager.num_blocks
-        )
+        if os.environ.get("NUM_BLOCKS") is None:
+            num_blocks = (
+                # Leave 5% for some wiggle room
+                int((free_memory * 0.95) // total_cache_size)
+                # Add batch.blocks as we allocated it above, so it is included in the peak memory.
+                + cache_manager.num_blocks
+            )
+        else:
+            num_blocks = int(os.environ["NUM_BLOCKS"])
+
+        logger.debug(f"Paged attention num_blocks: {num_blocks}")
 
         del batch
         del cache_manager
 
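The total_cache_size divisor is the per-block KV-cache footprint in bytes. A back-of-the-envelope sketch of the auto-sizing arithmetic, assuming the conventional paged-attention layout (one K and one V tensor per layer, a fixed number of tokens per block) and hypothetical Llama-2-7B-like model dimensions, not the exact TGI code:

# Hypothetical numbers for a Llama-2-7B-like model in float16.
BLOCK_SIZE = 16     # tokens per paged-attention block (assumed)
num_layers = 32
num_kv_heads = 32
head_size = 128
dtype_size = 2      # bytes per element for float16

# K and V each hold BLOCK_SIZE * num_kv_heads * head_size elements per layer.
total_cache_size = 2 * num_layers * BLOCK_SIZE * num_kv_heads * head_size * dtype_size
assert total_cache_size == 8 * 1024 * 1024  # 8 MiB per block

# With ~20 GiB free after warmup, the auto-sizing branch above yields roughly:
free_memory = 20 * 1024**3
print(int((free_memory * 0.95) // total_cache_size))  # 2432 blocks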