Trying to fix non-chunking targets.
parent a31db04709
commit 0a01dde986
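The hunk below wraps the free-memory based KV-cache block resizing in a get_support_chunking() guard, so targets without chunking support fall back to the blocks already allocated at warmup (max_total_tokens = batch.num_blocks).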
@@ -1398,22 +1398,32 @@ class FlashCausalLM(Model):
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
 
         if max_total_tokens is None:
-            model_max_length = self.tokenizer.model_max_length
-            free_memory = get_free_memory(self.device, MEMORY_FRACTION)
-            spare_blocks = (
-                # Leave 5% for some wiggle room
-                int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size)
-                + batch.num_blocks
-            )
-            spare_blocks = small_power_of_2(spare_blocks)
+            if get_support_chunking():
+                model_max_length = self.tokenizer.model_max_length
+                free_memory = get_free_memory(self.device, MEMORY_FRACTION)
+                spare_blocks = (
+                    # Leave 5% for some wiggle room
+                    int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size)
+                    + batch.num_blocks
+                )
+                spare_blocks = small_power_of_2(spare_blocks)
 
-            available_blocks = min(model_max_length, spare_blocks)
-            batch.num_blocks = available_blocks
-            batch.max_blocks = available_blocks
-            max_input_tokens = (
-                available_blocks - 1 if max_input_tokens is None else max_input_tokens
-            )
-            max_total_tokens = available_blocks
+                available_blocks = min(model_max_length, spare_blocks)
+                batch.num_blocks = available_blocks
+                batch.max_blocks = available_blocks
+                max_input_tokens = (
+                    available_blocks - 1
+                    if max_input_tokens is None
+                    else max_input_tokens
+                )
+                max_total_tokens = available_blocks
+            else:
+                max_total_tokens = batch.num_blocks
+                max_input_tokens = (
+                    batch.num_blocks - 1
+                    if max_input_tokens is None
+                    else max_input_tokens
+                )
 
         try:
             self.init_kv_cache(
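For readers tracing the arithmetic, the chunking branch boils down to a block-budget computation. The following is a minimal sketch, not the repository's implementation: it assumes TGI_WIGGLE_ROOM = 0.95 (matching the "leave 5%" comment), assumes small_power_of_2 rounds down to the nearest power of two, and replaces the get_free_memory device probe with plain integers.

# Minimal sketch of the chunking-branch arithmetic. Assumed, not taken from
# this commit: TGI_WIGGLE_ROOM = 0.95 (per the "leave 5%" comment) and a
# small_power_of_2 that rounds down to the nearest power of two.

TGI_WIGGLE_ROOM = 0.95  # assumed value


def small_power_of_2(n: int) -> int:
    # Assumed behavior: largest power of two <= n (returns 1 for n < 2).
    return 1 << (max(n, 1).bit_length() - 1)


def spare_block_budget(
    free_memory: int, total_cache_size: int, batch_num_blocks: int
) -> int:
    # Blocks the remaining free memory can hold, keeping 5% headroom, plus
    # the blocks the warmup batch already occupies, rounded down to a power
    # of two.
    spare_blocks = (
        int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size) + batch_num_blocks
    )
    return small_power_of_2(spare_blocks)


# Example: 20 GiB free, 2 MiB of KV cache per block, 64 blocks in flight.
print(spare_block_budget(20 * 2**30, 2 * 2**20, 64))  # 8192

The diff then caps this budget at the tokenizer's model_max_length via min(model_max_length, spare_blocks), and max_input_tokens defaults to available_blocks - 1, presumably so that at least one generated token always fits within max_total_tokens.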