allow to fix paged attention num blocks
This commit is contained in:
parent
9a59ebcec3
commit
bb37321b9f
|
@ -808,12 +808,17 @@ class FlashCausalLM(Model):
|
||||||
|
|
||||||
free_memory = get_free_memory(self.device, MEMORY_FRACTION)
|
free_memory = get_free_memory(self.device, MEMORY_FRACTION)
|
||||||
|
|
||||||
|
if os.environ.get("NUM_BLOCKS") is None:
|
||||||
num_blocks = (
|
num_blocks = (
|
||||||
# Leave 5% for some wiggle room
|
# Leave 5% for some wiggle room
|
||||||
int((free_memory * 0.95) // total_cache_size)
|
int((free_memory * 0.95) // total_cache_size)
|
||||||
# Add batch.blocks as we allocated it above, so it is included in the peak memory.
|
# Add batch.blocks as we allocated it above, so it is included in the peak memory.
|
||||||
+ cache_manager.num_blocks
|
+ cache_manager.num_blocks
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
num_blocks = int(os.environ["NUM_BLOCKS"])
|
||||||
|
|
||||||
|
logger.debug(f"Paged attention num_blocks: {num_blocks}")
|
||||||
|
|
||||||
del batch
|
del batch
|
||||||
del cache_manager
|
del cache_manager
|
||||||
|
|
Loading…
Reference in New Issue