From 10534511ea69d289d67afec84aee65b674d3c3c5 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 24 Oct 2024 06:55:25 +0200
Subject: [PATCH] Much simpler logic after the overhead.

---
 .../models/flash_causal_lm.py                 | 47 ++++++++-----------
 1 file changed, 19 insertions(+), 28 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index ca45020c..bbff0243 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1397,34 +1397,6 @@ class FlashCausalLM(Model):
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
 
-        if max_total_tokens is None:
-            if get_support_chunking():
-                model_max_length = self.tokenizer.model_max_length
-                free_memory = get_free_memory(self.device, MEMORY_FRACTION)
-                spare_blocks = (
-                    # Leave 5% for some wiggle room
-                    int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size)
-                    + batch.num_blocks
-                )
-                spare_blocks = small_power_of_2(spare_blocks)
-
-                available_blocks = min(model_max_length, spare_blocks)
-                batch.num_blocks = available_blocks
-                batch.max_blocks = available_blocks
-                max_input_tokens = (
-                    available_blocks - 1
-                    if max_input_tokens is None
-                    else max_input_tokens
-                )
-                max_total_tokens = available_blocks
-            else:
-                max_total_tokens = sum(len(input_ids) for input_ids in batch.input_ids)
-                max_input_tokens = (
-                    max_total_tokens - 1
-                    if max_input_tokens is None
-                    else max_input_tokens
-                )
-
         try:
             self.init_kv_cache(
                 batch.num_blocks,
@@ -1459,8 +1431,27 @@ class FlashCausalLM(Model):
         )
         log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
 
+        if max_total_tokens is None:
+            if get_support_chunking():
+                model_max_length = self.tokenizer.model_max_length
+                max_input_tokens = (
+                    min((num_blocks * BLOCK_SIZE - 1), model_max_length)
+                    if max_input_tokens is None
+                    else max_input_tokens
+                )
+                max_total_tokens = num_blocks * BLOCK_SIZE
+
+            else:
+                max_total_tokens = sum(len(input_ids) for input_ids in batch.input_ids)
+                max_input_tokens = (
+                    max_total_tokens - 1
+                    if max_input_tokens is None
+                    else max_input_tokens
+                )
         del batch
 
+        self.kv_cache = []
+        empty_cache()
+
         self.init_kv_cache(
             num_blocks,