Much simpler logic after the overhead.

This commit is contained in:
Nicolas Patry 2024-10-24 06:55:25 +02:00
parent 849d8821ab
commit 10534511ea
No known key found for this signature in database
GPG Key ID: D2920555C90F704C
1 changed files with 19 additions and 28 deletions

View File

@ -1397,34 +1397,6 @@ class FlashCausalLM(Model):
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
if max_total_tokens is None:
if get_support_chunking():
model_max_length = self.tokenizer.model_max_length
free_memory = get_free_memory(self.device, MEMORY_FRACTION)
spare_blocks = (
# Leave 5% for some wiggle room
int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size)
+ batch.num_blocks
)
spare_blocks = small_power_of_2(spare_blocks)
available_blocks = min(model_max_length, spare_blocks)
batch.num_blocks = available_blocks
batch.max_blocks = available_blocks
max_input_tokens = (
available_blocks - 1
if max_input_tokens is None
else max_input_tokens
)
max_total_tokens = available_blocks
else:
max_total_tokens = sum(len(input_ids) for input_ids in batch.input_ids)
max_input_tokens = (
max_total_tokens - 1
if max_input_tokens is None
else max_input_tokens
)
try: try:
self.init_kv_cache( self.init_kv_cache(
batch.num_blocks, batch.num_blocks,
@ -1459,8 +1431,27 @@ class FlashCausalLM(Model):
) )
log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}") log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
if max_total_tokens is None:
if get_support_chunking():
model_max_length = self.tokenizer.model_max_length
max_input_tokens = (
min((num_blocks * BLOCK_SIZE - 1), model_max_length)
if max_input_tokens is None
else max_input_tokens
)
max_total_tokens = num_blocks * BLOCK_SIZE
else:
max_total_tokens = sum(len(input_ids) for input_ids in batch.input_ids)
max_input_tokens = (
max_total_tokens - 1
if max_input_tokens is None
else max_input_tokens
)
del batch del batch
self.kv_cache = []
empty_cache()
self.init_kv_cache( self.init_kv_cache(
num_blocks, num_blocks,