Much simpler logic after the overhead.
parent 849d8821ab
commit 10534511ea
@@ -1397,34 +1397,6 @@ class FlashCausalLM(Model):
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
-
-        if max_total_tokens is None:
-            if get_support_chunking():
-                model_max_length = self.tokenizer.model_max_length
-                free_memory = get_free_memory(self.device, MEMORY_FRACTION)
-                spare_blocks = (
-                    # Leave 5% for some wiggle room
-                    int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size)
-                    + batch.num_blocks
-                )
-                spare_blocks = small_power_of_2(spare_blocks)
-
-                available_blocks = min(model_max_length, spare_blocks)
-                batch.num_blocks = available_blocks
-                batch.max_blocks = available_blocks
-                max_input_tokens = (
-                    available_blocks - 1
-                    if max_input_tokens is None
-                    else max_input_tokens
-                )
-                max_total_tokens = available_blocks
-            else:
-                max_total_tokens = sum(len(input_ids) for input_ids in batch.input_ids)
-                max_input_tokens = (
-                    max_total_tokens - 1
-                    if max_input_tokens is None
-                    else max_input_tokens
-                )

         try:
             self.init_kv_cache(
                 batch.num_blocks,
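For reference, the sizing arithmetic in the removed branch works like this: one KV-cache block stores BLOCK_SIZE tokens, its memory cost is num_layers * (BLOCK_SIZE * num_kv_heads * head_size) * 2 (keys and values) * dtype_size bytes, and dividing the free-memory budget by that cost gives the spare block count. A minimal sketch with made-up model dimensions (every constant below is an assumed example value, not taken from this commit):

# Illustrative sketch of the removed spare-block estimate.
# All dimensions and the free-memory figure are assumed example values.
BLOCK_SIZE = 16          # tokens per KV-cache block (assumed)
num_kv_heads = 8
head_size = 128
num_layers = 32
dtype_size = 2           # fp16/bf16 bytes per element
TGI_WIGGLE_ROOM = 0.95   # "Leave 5% for some wiggle room" (assumed value)

# elements one block stores per layer, per K or V tensor
cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
# bytes one block costs across all layers, keys and values
total_cache_size = num_layers * cache_block_size * 2 * dtype_size

free_memory = 20 * 1024**3  # pretend 20 GiB remain after loading weights
spare_blocks = int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size)

print(total_cache_size)           # 2097152 bytes, i.e. 2 MiB per block
print(spare_blocks)               # 9728 blocks
print(spare_blocks * BLOCK_SIZE)  # 155648 cacheable tokens

The removed code then added batch.num_blocks to this estimate, rounded it with small_power_of_2(), and took min(model_max_length, spare_blocks); the simpler path added further down derives the limits directly from the num_blocks figure measured after warmup instead.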
@@ -1459,8 +1431,27 @@ class FlashCausalLM(Model):
         )

         log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")

+        if max_total_tokens is None:
+            if get_support_chunking():
+                model_max_length = self.tokenizer.model_max_length
+                max_input_tokens = (
+                    min((num_blocks * BLOCK_SIZE - 1), model_max_length)
+                    if max_input_tokens is None
+                    else max_input_tokens
+                )
+                max_total_tokens = num_blocks * BLOCK_SIZE
+            else:
+                max_total_tokens = sum(len(input_ids) for input_ids in batch.input_ids)
+                max_input_tokens = (
+                    max_total_tokens - 1
+                    if max_input_tokens is None
+                    else max_input_tokens
+                )
+
         del batch
+        self.kv_cache = []
+        empty_cache()

         self.init_kv_cache(
             num_blocks,
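The added defaulting logic is small enough to read as a pure function. A sketch of the same decision, pulled out into a hypothetical helper (resolve_token_limits is an invented name; in the commit this runs inline in FlashCausalLM.warmup):

def resolve_token_limits(
    num_blocks,
    block_size,
    model_max_length,
    support_chunking,
    prompt_lengths,
    max_input_tokens,
    max_total_tokens,
):
    # Mirrors the branch added above: only fills in values the user left unset.
    if max_total_tokens is None:
        if support_chunking:
            # With chunked prefill, the allocated KV cache bounds the totals:
            # every block holds block_size tokens.
            max_input_tokens = (
                min(num_blocks * block_size - 1, model_max_length)
                if max_input_tokens is None
                else max_input_tokens
            )
            max_total_tokens = num_blocks * block_size
        else:
            # Without chunking, fall back to the size of the warmup batch.
            max_total_tokens = sum(prompt_lengths)
            max_input_tokens = (
                max_total_tokens - 1
                if max_input_tokens is None
                else max_input_tokens
            )
    return max_input_tokens, max_total_tokens

# Example: 9728 blocks of 16 tokens, a 32k-context tokenizer, chunking
# supported, and no limits configured by the user.
print(resolve_token_limits(9728, 16, 32768, True, [128, 256], None, None))
# -> (32768, 155648)

Compared with the removed branch, the only inputs are the block count actually allocated (num_blocks) and the tokenizer's model_max_length; there is no separate free-memory estimate to keep in sync with the allocation.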