feat(server): improve max tokens calculation

OlivierDehaene 2023-04-26 13:07:25 +02:00
parent 7de8a377b0
commit 4f460e5bfe
3 changed files with 124 additions and 40 deletions
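The change replaces the old per-batch token budget, which summed every request's max_new_tokens, with a bound built on max_decode_steps: the smallest number of decode steps after which at least one request is guaranteed to finish and the batch to be filtered. A rough worked example (hypothetical numbers, not taken from the commit):

# Two requests are batched together (illustrative values only).
# Request A: max_new_tokens = 5, Request B: max_new_tokens = 1000.
batch_size = 2
max_input_length = 12  # longest prompt in the batch

# Old bound: every request is assumed to decode to completion inside this batch.
old_max_tokens = batch_size * max_input_length + (5 + 1000)          # 1029

# New bound: the batch only lives until its most constrained request finishes.
max_decode_steps = min(5, 1000)                                       # 5
new_max_tokens = batch_size * (max_input_length + max_decode_steps)   # 34

The tighter bound means far less kv_cache headroom has to be reserved for batches that mix short and long max_new_tokens; how max_tokens is consumed upstream is outside this diff.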

server/text_generation_server/models/causal_lm.py

@ -48,6 +48,8 @@ class CausalLMBatch(Batch):
# Maximum number of tokens this batch will grow to
max_tokens: int
# Maximum number of decode steps before at least one request finishes
max_decode_steps: int
# Past metadata
keys_head_dim_last: bool = True
@ -77,7 +79,7 @@ class CausalLMBatch(Batch):
# Parse batch
max_truncation = 0
padding_right_offset = 0
max_decode_tokens = 0
max_decode_steps = None
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
@ -89,7 +91,15 @@ class CausalLMBatch(Batch):
)
stopping_criterias.append(stopping_criteria)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
# Maximum number of decode steps before one request finishes
if max_decode_steps is None:
max_decode_steps = stopping_criteria.max_new_tokens
else:
max_decode_steps = min(
max_decode_steps, stopping_criteria.max_new_tokens
)
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
)
@ -118,7 +128,10 @@ class CausalLMBatch(Batch):
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
max_tokens = len(inputs) * max_input_length + max_decode_tokens
# Since we are sure that at least one request will be dropped within max_decode_steps steps,
# we know the kv_cache will only grow to cumulative_length + batch_size * max_decode_steps
# before getting filtered and decreasing in size
max_tokens = len(inputs) * (max_input_length + max_decode_steps)
return cls(
batch_id=pb.id,
@ -137,6 +150,7 @@ class CausalLMBatch(Batch):
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
max_decode_steps=max_decode_steps,
)
@tracer.start_as_current_span("filter")
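Condensed, the new from_pb bookkeeping for the padded CausalLMBatch amounts to the following standalone sketch (the dict-based requests list and the concrete numbers are placeholders for the protobuf requests used above):

# Placeholder for pb.requests; only max_new_tokens matters for this computation.
requests = [{"max_new_tokens": 20}, {"max_new_tokens": 7}, {"max_new_tokens": 64}]
max_input_length = 32  # longest tokenized (and possibly truncated) prompt, assumed

max_decode_steps = None
padding_right_offset = 0
for r in requests:
    # The batch can only decode until its most constrained request finishes...
    if max_decode_steps is None:
        max_decode_steps = r["max_new_tokens"]
    else:
        max_decode_steps = min(max_decode_steps, r["max_new_tokens"])
    # ...but right padding must still cover the longest-running request.
    padding_right_offset = max(padding_right_offset, r["max_new_tokens"])

# Every sequence occupies max_input_length slots in the padded batch, and the
# batch is filtered after at most max_decode_steps additional tokens each.
max_tokens = len(requests) * (max_input_length + max_decode_steps)
# max_decode_steps == 7, padding_right_offset == 64, max_tokens == 3 * 39 == 117

Note that padding_right_offset still tracks the maximum: the attention mask has to leave room for the longest request even though the token budget only counts up to the shortest one.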
@ -159,8 +173,8 @@ class CausalLMBatch(Batch):
next_token_choosers = []
stopping_criterias = []
total_remaining_decode_tokens = 0
new_padding_right_offset = 0
max_decode_steps = None
for i, r in enumerate(requests):
idx = self.requests_idx_mapping[r.id]
@ -178,13 +192,17 @@ class CausalLMBatch(Batch):
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
remaining_decode_tokens = (
# Remaining decode steps for this request
remaining_decode = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
total_remaining_decode_tokens += remaining_decode_tokens
new_padding_right_offset = max(
new_padding_right_offset, remaining_decode_tokens
)
if max_decode_steps is None:
max_decode_steps = remaining_decode
else:
max_decode_steps = min(max_decode_steps, remaining_decode)
new_padding_right_offset = max(new_padding_right_offset, remaining_decode)
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
input_ids = self.input_ids[keep_indices]
@ -217,7 +235,10 @@ class CausalLMBatch(Batch):
layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
del past_values
max_tokens = len(requests) * max_input_length + total_remaining_decode_tokens
# Since we are sure that at least one request will be dropped within max_decode_steps steps,
# we know the kv_cache will only grow to cumulative_length + batch_size * max_decode_steps
# before getting filtered and decreasing in size
max_tokens = len(requests) * (max_input_length + max_decode_steps)
self.requests = requests
self.requests_idx_mapping = requests_idx_mapping
@ -232,6 +253,7 @@ class CausalLMBatch(Batch):
self.max_input_length = max_input_length
self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens
self.max_decode_steps = max_decode_steps
return self
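filter() recomputes the same two values from whatever requests survive, using each request's remaining (rather than total) decode allowance. A minimal sketch, with stopping criteria reduced to (max_new_tokens, current_tokens) pairs:

# Hypothetical surviving requests after a filter pass.
survivors = [(20, 7), (64, 7)]   # (max_new_tokens, current_tokens)
max_input_length = 32            # assumed unchanged for brevity

max_decode_steps = None
new_padding_right_offset = 0
for max_new_tokens, current_tokens in survivors:
    remaining_decode = max_new_tokens - current_tokens
    if max_decode_steps is None:
        max_decode_steps = remaining_decode
    else:
        max_decode_steps = min(max_decode_steps, remaining_decode)
    new_padding_right_offset = max(new_padding_right_offset, remaining_decode)

max_tokens = len(survivors) * (max_input_length + max_decode_steps)
# max_decode_steps == 13, new_padding_right_offset == 57, max_tokens == 2 * 45 == 90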
@ -256,7 +278,6 @@ class CausalLMBatch(Batch):
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
max_tokens = 0
# Batch tensors
input_ids = None
@ -264,6 +285,8 @@ class CausalLMBatch(Batch):
position_ids = None
past_key_values = []
max_decode_steps = None
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
start_index = 0
@ -341,10 +364,11 @@ class CausalLMBatch(Batch):
layer[k] = t.view(len(batch), -1, *t.shape[-2:])
start_index = end_index
# Add eventual padding tokens that were added while concatenating
max_tokens += batch.max_tokens + (
max_input_length - batch.max_input_length
) * len(batch)
if max_decode_steps is None:
max_decode_steps = batch.max_decode_steps
else:
max_decode_steps = min(max_decode_steps, batch.max_decode_steps)
first_past_kvs = batches[0].past_key_values
_, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
@ -417,6 +441,8 @@ class CausalLMBatch(Batch):
past_key_values.append([padded_past_keys, padded_past_values])
max_tokens = len(requests) * (max_input_length + max_decode_steps)
return cls(
batch_id=batches[0].batch_id,
requests=requests,
@ -435,6 +461,7 @@ class CausalLMBatch(Batch):
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
max_tokens=max_tokens,
max_decode_steps=max_decode_steps,
)
def __len__(self):
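concatenate no longer carries each batch's max_tokens over and patches it for the extra padding; it takes the minimum max_decode_steps across batches and rebuilds the bound from the merged shape. A sketch where each incoming batch is reduced to a (size, max_input_length, max_decode_steps) triple:

# Hypothetical batches being merged.
incoming = [(2, 32, 13), (3, 48, 5)]   # (batch_size, max_input_length, max_decode_steps)

total_size = 0
merged_max_input_length = 0
max_decode_steps = None
for size, input_length, decode_steps in incoming:
    total_size += size
    merged_max_input_length = max(merged_max_input_length, input_length)
    if max_decode_steps is None:
        max_decode_steps = decode_steps
    else:
        max_decode_steps = min(max_decode_steps, decode_steps)

# Every sequence is re-padded to the merged max_input_length, so no separate
# padding-correction term is needed any more.
max_tokens = total_size * (merged_max_input_length + max_decode_steps)
# max_decode_steps == 5, max_tokens == 5 * 53 == 265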
@ -636,6 +663,8 @@ class CausalLM(Model):
batch.attention_mask[:, -batch.padding_right_offset] = 1
# Decrease right offset
batch.padding_right_offset -= 1
# Decrease max_decode_steps
batch.max_decode_steps -= 1
# Update position_ids
batch.position_ids = batch.position_ids[:, -1:] + 1
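During generation, max_decode_steps is decremented once per step, mirroring padding_right_offset. When it reaches zero, at least one request has produced its max_new_tokens tokens (unless it stopped earlier on its own), so the next filter() call is guaranteed to shrink the batch and free kv_cache. Roughly:

# Hypothetical decode loop illustrating why the bound is safe.
max_decode_steps = 5
steps_taken = 0
while True:
    # ... one forward pass, one new token per request, would happen here ...
    steps_taken += 1
    max_decode_steps -= 1
    if max_decode_steps == 0:
        # Some stopping criterion has now been met, so the batch gets filtered
        # and its kv_cache footprint can only decrease from here.
        break
assert steps_taken == 5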

server/text_generation_server/models/flash_causal_lm.py

@ -58,6 +58,8 @@ class FlashCausalLMBatch(Batch):
# Maximum number of tokens this batch will grow to
max_tokens: int
# Maximum number of decode steps before at least one request finishes
max_decode_steps: int
def to_pb(self) -> generate_pb2.Batch:
return generate_pb2.Batch(
@ -92,7 +94,7 @@ class FlashCausalLMBatch(Batch):
# Cumulative length
cumulative_length = 0
max_tokens = 0
max_decode_steps = None
# Parse batch
for i, r in enumerate(pb.requests):
@ -127,7 +129,15 @@ class FlashCausalLMBatch(Batch):
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
max_new_tokens = stopping_criteria.max_new_tokens
# Maximum number of decode steps before one request finishes
if max_decode_steps is None:
max_decode_steps = stopping_criteria.max_new_tokens
else:
max_decode_steps = min(
max_decode_steps, stopping_criteria.max_new_tokens
)
stopping_criterias.append(stopping_criteria)
all_input_ids_tensor.append(
@ -136,7 +146,11 @@ class FlashCausalLMBatch(Batch):
# Update
cumulative_length += input_length
max_tokens += input_length + max_new_tokens
# Since we are sure that at least one request will be dropped within max_decode_steps steps,
# we know the kv_cache will only grow to cumulative_length + batch_size * max_decode_steps
# before getting filtered and decreasing in size
max_tokens = cumulative_length + max_decode_steps * len(pb.requests)
return cls(
batch_id=pb.id,
@ -156,6 +170,7 @@ class FlashCausalLMBatch(Batch):
stopping_criterias=stopping_criterias,
past_pad=None,
max_tokens=max_tokens,
max_decode_steps=max_decode_steps,
)
@tracer.start_as_current_span("filter")
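FlashCausalLMBatch packs sequences back to back instead of padding them to a common length, so the prefill part of the bound is the exact cumulative_length rather than batch_size * max_input_length; only the decode headroom is multiplied by the batch size. A sketch with made-up lengths:

# Hypothetical packed batch: (input_length, max_new_tokens) per request.
requests = [(10, 20), (25, 7), (40, 64)]

cumulative_length = 0
max_decode_steps = None
for input_length, max_new_tokens in requests:
    cumulative_length += input_length
    if max_decode_steps is None:
        max_decode_steps = max_new_tokens
    else:
        max_decode_steps = min(max_decode_steps, max_new_tokens)

# Prompts contribute exactly cumulative_length tokens (no padding), and each
# sequence adds at most max_decode_steps tokens before the batch is filtered.
max_tokens = cumulative_length + max_decode_steps * len(requests)
# cumulative_length == 75, max_decode_steps == 7, max_tokens == 75 + 7 * 3 == 96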
@ -190,7 +205,7 @@ class FlashCausalLMBatch(Batch):
next_token_choosers = []
stopping_criterias = []
max_tokens = 0
max_decode_steps = None
for i, r in enumerate(requests):
idx = self.requests_idx_mapping[r.id]
@ -221,11 +236,21 @@ class FlashCausalLMBatch(Batch):
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
cumulative_length += request_input_length
max_tokens += request_input_length + (
# Remaining decode steps for this request
remaining_decode = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
if max_decode_steps is None:
max_decode_steps = remaining_decode
else:
max_decode_steps = min(max_decode_steps, remaining_decode)
cumulative_length += request_input_length
# Since we are sure that at least one request will be dropped within max_decode_steps steps,
# we know the kv_cache will only grow to cumulative_length + batch_size * max_decode_steps
# before getting filtered and decreasing in size
max_tokens = cumulative_length + max_decode_steps * len(requests)
if single_request:
# Preallocate tensor for bs = 1 case
@ -290,7 +315,8 @@ class FlashCausalLMBatch(Batch):
# Cumulative length
cumulative_batch_size = 0
cumulative_length = 0
max_tokens = 0
max_decode_steps = None
for i, batch in enumerate(batches):
requests.extend(batch.requests)
@ -329,10 +355,16 @@ class FlashCausalLMBatch(Batch):
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
if max_decode_steps is None:
max_decode_steps = batch.max_decode_steps
else:
max_decode_steps = min(max_decode_steps, batch.max_decode_steps)
# Update
cumulative_length += batch.cu_seqlens[-1]
cumulative_batch_size += len(batch)
max_tokens += batch.max_tokens
max_tokens = cumulative_length + max_decode_steps * cumulative_batch_size
return FlashCausalLMBatch(
batch_id=batches[0].batch_id,
@ -352,6 +384,7 @@ class FlashCausalLMBatch(Batch):
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_tokens=max_tokens,
max_decode_steps=max_decode_steps,
)
def __len__(self):
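Concatenating packed batches follows the same pattern: cumulative_length is the sum of each batch's cu_seqlens[-1], and the decode headroom is again the minimum across the batches being merged. With stand-in numbers:

# Hypothetical flash batches: (packed_token_count, batch_size, max_decode_steps).
flash_batches = [(75, 3, 7), (120, 4, 30)]

cumulative_length = 0
cumulative_batch_size = 0
max_decode_steps = None
for packed_tokens, size, decode_steps in flash_batches:
    cumulative_length += packed_tokens   # stands in for batch.cu_seqlens[-1]
    cumulative_batch_size += size
    if max_decode_steps is None:
        max_decode_steps = decode_steps
    else:
        max_decode_steps = min(max_decode_steps, decode_steps)

max_tokens = cumulative_length + max_decode_steps * cumulative_batch_size
# cumulative_length == 195, max_decode_steps == 7, max_tokens == 195 + 7 * 7 == 244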
@ -617,6 +650,7 @@ class FlashCausalLM(Model):
batch.all_input_ids[i] = all_input_ids
batch.all_input_ids_tensor[i] = all_input_ids_tensor
batch.max_seqlen = max(batch.max_seqlen, new_input_length)
batch.max_decode_steps -= 1
if len(batch) != 1:
# Add each sequence before its padding
batch.past_key_values[i * 2] = present[:, start_index:end_index]

server/text_generation_server/models/seq2seq_lm.py

@ -56,6 +56,8 @@ class Seq2SeqLMBatch(Batch):
# Maximum number of tokens this batch will grow to
max_tokens: int
# Maximum number of decode steps before at least one request finishes
max_decode_steps: int
def to_pb(self) -> generate_pb2.Batch:
"""Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
@ -86,7 +88,7 @@ class Seq2SeqLMBatch(Batch):
# Parse batch
max_truncation = 0
padding_right_offset = 0
max_decode_tokens = 0
max_decode_steps = None
for i, r in enumerate(pb.requests):
inputs.append(r.inputs)
requests_idx_mapping[r.id] = i
@ -99,7 +101,15 @@ class Seq2SeqLMBatch(Batch):
)
stopping_criterias.append(stopping_criteria)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
# Maximum number of decode steps before one request finishes
if max_decode_steps is None:
max_decode_steps = stopping_criteria.max_new_tokens
else:
max_decode_steps = min(
max_decode_steps, stopping_criteria.max_new_tokens
)
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
)
@ -125,7 +135,7 @@ class Seq2SeqLMBatch(Batch):
)
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
max_tokens = len(inputs) * max_input_length + max_decode_tokens
max_tokens = len(inputs) * (max_input_length + max_decode_steps)
return cls(
batch_id=pb.id,
@ -148,6 +158,7 @@ class Seq2SeqLMBatch(Batch):
max_decoder_input_length=1,
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
max_decode_steps=max_decode_steps,
)
@tracer.start_as_current_span("filter")
@ -177,7 +188,7 @@ class Seq2SeqLMBatch(Batch):
max_decoder_input_length = 0
padding_right_offset = 0
remaining_decode_tokens = 0
max_decode_steps = None
for i, r in enumerate(requests):
idx = self.requests_idx_mapping[r.id]
@ -207,9 +218,15 @@ class Seq2SeqLMBatch(Batch):
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
remaining_decode_tokens += (
# Remaining decode steps for this request
remaining_decode = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
if max_decode_steps is None:
max_decode_steps = remaining_decode
else:
max_decode_steps = min(max_decode_steps, remaining_decode)
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
self.decoder_input_ids = self.decoder_input_ids[keep_indices]
@ -240,9 +257,8 @@ class Seq2SeqLMBatch(Batch):
layer[2] = layer[2][keep_indices, :, -max_input_length:]
layer[3] = layer[3][keep_indices, :, -max_input_length:]
max_tokens = (
len(requests) * (max_input_length + max_decoder_input_length)
+ remaining_decode_tokens
max_tokens = len(requests) * (
max_input_length + max_decoder_input_length + max_decode_steps
)
self.requests = requests
@ -259,6 +275,7 @@ class Seq2SeqLMBatch(Batch):
self.max_decoder_input_length = max_decoder_input_length
self.padding_right_offset = padding_right_offset
self.max_tokens = max_tokens
self.max_decode_steps = max_decode_steps
return self
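For the encoder-decoder Seq2SeqLMBatch the bound also has to cover the decoder side: each sequence reserves max_input_length encoder slots, max_decoder_input_length already-generated decoder slots, and at most max_decode_steps further decoder tokens. A sketch of the filter/concatenate-time formula with hypothetical values:

# Hypothetical filtered encoder-decoder batch.
num_requests = 2
max_input_length = 32           # longest remaining encoder input
max_decoder_input_length = 9    # longest decoder prefix generated so far
max_decode_steps = 13           # min remaining decode steps across requests

# Encoder tokens + decoder tokens already produced + decode headroom,
# all padded to the per-batch maxima.
max_tokens = num_requests * (
    max_input_length + max_decoder_input_length + max_decode_steps
)
# max_tokens == 2 * 54 == 108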
@ -290,7 +307,7 @@ class Seq2SeqLMBatch(Batch):
token_offsets = []
next_token_choosers = []
stopping_criterias = []
max_tokens = 0
max_decode_steps = None
# Batch tensors
attention_mask = None
@ -398,13 +415,11 @@ class Seq2SeqLMBatch(Batch):
]
start_index = end_index
# Add eventual padding tokens that were added while concatenating
max_tokens += batch.max_tokens + (
max_input_length
- batch.max_input_length
+ max_decoder_input_length
- batch.max_decoder_input_length
) * len(batch)
if max_decode_steps is None:
max_decode_steps = batch.max_decode_steps
else:
max_decode_steps = min(max_decode_steps, batch.max_decode_steps)
# Determine shapes for new past kv tensors
first_past_kvs = batches[0].past_key_values
@ -471,6 +486,10 @@ class Seq2SeqLMBatch(Batch):
start_index = end_index
max_tokens = len(requests) * (
max_input_length + max_decoder_input_length + max_decode_steps
)
return cls(
batch_id=batches[0].batch_id,
requests=requests,
@ -492,6 +511,7 @@ class Seq2SeqLMBatch(Batch):
max_decoder_input_length=max_decoder_input_length,
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
max_decode_steps=max_decode_steps,
)
def __len__(self):
@ -717,5 +737,6 @@ class Seq2SeqLM(Model):
if batch.decoder_attention_mask is not None:
batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
batch.padding_right_offset -= 1
batch.max_decode_steps -= 1
return generations, batch