feat(server): improve max tokens calculation
parent 7de8a377b0
commit 4f460e5bfe
@@ -48,6 +48,8 @@ class CausalLMBatch(Batch):
 
     # Maximum number of tokens this batch will grow to
     max_tokens: int
+    # Maximum number of decode steps before at least one request finish
+    max_decode_steps: int
 
     # Past metadata
     keys_head_dim_last: bool = True
@@ -77,7 +79,7 @@ class CausalLMBatch(Batch):
         # Parse batch
         max_truncation = 0
         padding_right_offset = 0
-        max_decode_tokens = 0
+        max_decode_steps = None
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
@@ -89,7 +91,15 @@ class CausalLMBatch(Batch):
             )
             stopping_criterias.append(stopping_criteria)
             max_truncation = max(max_truncation, r.truncate)
-            max_decode_tokens += stopping_criteria.max_new_tokens
+
+            # Maximum number of decode steps before one request finish
+            if max_decode_steps is None:
+                max_decode_steps = stopping_criteria.max_new_tokens
+            else:
+                max_decode_steps = min(
+                    max_decode_steps, stopping_criteria.max_new_tokens
+                )
+
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
@@ -118,7 +128,10 @@ class CausalLMBatch(Batch):
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
 
-        max_tokens = len(inputs) * max_input_length + max_decode_tokens
+        # Since we are sure that at least one request will be dropped in max_decode_steps,
+        # we know the kv_cache will only grow to cumulative_length + batch_size * max_decode_steps
+        # before getting filtered and decreasing in size
+        max_tokens = len(inputs) * (max_input_length + max_decode_steps)
 
         return cls(
             batch_id=pb.id,
@@ -137,6 +150,7 @@ class CausalLMBatch(Batch):
             max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
             max_tokens=max_tokens,
+            max_decode_steps=max_decode_steps,
         )
 
     @tracer.start_as_current_span("filter")
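The hunks above swap the old bound, batch_size * max_input_length plus the sum of every request's decode budget, for one driven by max_decode_steps, the smallest max_new_tokens in the batch: after that many steps at least one request is guaranteed to finish and the batch is filtered. A minimal standalone sketch of the two formulas, with hypothetical helper names and sample numbers that are not taken from the diff:

# Hypothetical sketch of the two bounds computed in CausalLMBatch.from_pb above;
# `budgets` stands in for each request's stopping_criteria.max_new_tokens.

def old_max_tokens(input_lengths: list[int], budgets: list[int]) -> int:
    # Old bound: padded prompts plus the sum of every request's decode budget.
    return len(input_lengths) * max(input_lengths) + sum(budgets)


def new_max_tokens(input_lengths: list[int], budgets: list[int]) -> int:
    # New bound: the batch shrinks once the smallest budget (max_decode_steps)
    # is exhausted, so only that many decode steps are counted per request.
    return len(input_lengths) * (max(input_lengths) + min(budgets))


print(old_max_tokens([5, 12], [512, 16]))  # 2 * 12 + 528 = 552
print(new_max_tokens([5, 12], [512, 16]))  # 2 * (12 + 16) = 56

The new bound is much tighter whenever one request in the batch has a small remaining budget.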
@@ -159,8 +173,8 @@ class CausalLMBatch(Batch):
         next_token_choosers = []
         stopping_criterias = []
 
-        total_remaining_decode_tokens = 0
         new_padding_right_offset = 0
+        max_decode_steps = None
 
         for i, r in enumerate(requests):
             idx = self.requests_idx_mapping[r.id]
@@ -178,13 +192,17 @@ class CausalLMBatch(Batch):
             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
-            remaining_decode_tokens = (
+
+            # Remaining decode steps for this request
+            remaining_decode = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
-            total_remaining_decode_tokens += remaining_decode_tokens
-            new_padding_right_offset = max(
-                new_padding_right_offset, remaining_decode_tokens
-            )
+            if max_decode_steps is None:
+                max_decode_steps = remaining_decode
+            else:
+                max_decode_steps = min(max_decode_steps, remaining_decode)
+
+            new_padding_right_offset = max(new_padding_right_offset, remaining_decode)
 
         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
         input_ids = self.input_ids[keep_indices]
@@ -217,7 +235,10 @@ class CausalLMBatch(Batch):
             layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
         del past_values
 
-        max_tokens = len(requests) * max_input_length + total_remaining_decode_tokens
+        # Since we are sure that at least one request will be dropped in max_decode_steps,
+        # we know the kv_cache will only grow to cumulative_length + batch_size * max_decode_steps
+        # before getting filtered and decreasing in size
+        max_tokens = len(requests) * (max_input_length + max_decode_steps)
 
         self.requests = requests
         self.requests_idx_mapping = requests_idx_mapping
@@ -232,6 +253,7 @@ class CausalLMBatch(Batch):
         self.max_input_length = max_input_length
         self.padding_right_offset = new_padding_right_offset
         self.max_tokens = max_tokens
+        self.max_decode_steps = max_decode_steps
 
         return self
 
@@ -256,7 +278,6 @@ class CausalLMBatch(Batch):
         all_input_ids = []
         next_token_choosers = []
         stopping_criterias = []
-        max_tokens = 0
 
         # Batch tensors
         input_ids = None
@@ -264,6 +285,8 @@ class CausalLMBatch(Batch):
         position_ids = None
         past_key_values = []
 
+        max_decode_steps = None
+
         # Used for slicing correctly inside the tensors
         # Equivalent to a cumsum on batch sizes
        start_index = 0
@@ -341,10 +364,11 @@ class CausalLMBatch(Batch):
                    layer[k] = t.view(len(batch), -1, *t.shape[-2:])
 
            start_index = end_index
-           # Add eventual padding tokens that were added while concatenating
-           max_tokens += batch.max_tokens + (
-               max_input_length - batch.max_input_length
-           ) * len(batch)
+
+           if max_decode_steps is None:
+               max_decode_steps = batch.max_decode_steps
+           else:
+               max_decode_steps = min(max_decode_steps, batch.max_decode_steps)
 
        first_past_kvs = batches[0].past_key_values
        _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
@@ -417,6 +441,8 @@ class CausalLMBatch(Batch):
 
            past_key_values.append([padded_past_keys, padded_past_values])
 
+       max_tokens = len(requests) * (max_input_length + max_decode_steps)
+
        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
@@ -435,6 +461,7 @@ class CausalLMBatch(Batch):
            padding_right_offset=padding_right_offset,
            keys_head_dim_last=batches[0].keys_head_dim_last,
            max_tokens=max_tokens,
+           max_decode_steps=max_decode_steps,
        )
 
    def __len__(self):
@@ -636,6 +663,8 @@ class CausalLM(Model):
         batch.attention_mask[:, -batch.padding_right_offset] = 1
         # Decrease right offset
         batch.padding_right_offset -= 1
+        # Decrease max_decode_steps
+        batch.max_decode_steps -= 1
 
         # Update position_ids
         batch.position_ids = batch.position_ids[:, -1:] + 1
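The generate_token hunk above decrements max_decode_steps alongside padding_right_offset on every decode step. A rough sketch of the bookkeeping this implies, with made-up budgets rather than values from the commit: when the counter reaches zero, the shortest request has exhausted its budget, and filter() rebuilds the bound from the survivors' max_new_tokens - current_tokens.

# Hypothetical illustration of the per-step bookkeeping; the budget list is invented.
remaining_budgets = [4, 7, 9]              # max_new_tokens - current_tokens per request
max_decode_steps = min(remaining_budgets)  # 4

for _ in range(min(remaining_budgets)):
    # one decode step for the whole batch
    remaining_budgets = [b - 1 for b in remaining_budgets]
    max_decode_steps -= 1

# at least one request hit its budget; filter() drops it and recomputes the bound
remaining_budgets = [b for b in remaining_budgets if b > 0]
max_decode_steps = min(remaining_budgets)

print(remaining_budgets, max_decode_steps)  # [3, 5] 3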
@@ -58,6 +58,8 @@ class FlashCausalLMBatch(Batch):
 
     # Maximum number of tokens this batch will grow to
     max_tokens: int
+    # Maximum number of decode steps before at least one request finish
+    max_decode_steps: int
 
     def to_pb(self) -> generate_pb2.Batch:
         return generate_pb2.Batch(
@@ -92,7 +94,7 @@ class FlashCausalLMBatch(Batch):
         # Cumulative length
         cumulative_length = 0
 
-        max_tokens = 0
+        max_decode_steps = None
 
         # Parse batch
         for i, r in enumerate(pb.requests):
@@ -127,7 +129,15 @@ class FlashCausalLMBatch(Batch):
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
             max_new_tokens = stopping_criteria.max_new_tokens
+
+            # Maximum number of decode steps before one request finish
+            if max_decode_steps is None:
+                max_decode_steps = stopping_criteria.max_new_tokens
+            else:
+                max_decode_steps = min(
+                    max_decode_steps, stopping_criteria.max_new_tokens
+                )
+
             stopping_criterias.append(stopping_criteria)
 
             all_input_ids_tensor.append(
@@ -136,7 +146,11 @@ class FlashCausalLMBatch(Batch):
 
             # Update
             cumulative_length += input_length
-            max_tokens += input_length + max_new_tokens
+
+        # Since we are sure that at least one request will be dropped in max_decode_steps,
+        # we know the kv_cache will only grow to cumulative_length + batch_size * max_decode_steps
+        # before getting filtered and decreasing in size
+        max_tokens = cumulative_length + max_decode_steps * len(pb.requests)
 
         return cls(
             batch_id=pb.id,
@@ -156,6 +170,7 @@ class FlashCausalLMBatch(Batch):
             stopping_criterias=stopping_criterias,
             past_pad=None,
             max_tokens=max_tokens,
+            max_decode_steps=max_decode_steps,
         )
 
     @tracer.start_as_current_span("filter")
@@ -190,7 +205,7 @@ class FlashCausalLMBatch(Batch):
         next_token_choosers = []
         stopping_criterias = []
 
-        max_tokens = 0
+        max_decode_steps = None
 
         for i, r in enumerate(requests):
             idx = self.requests_idx_mapping[r.id]
@@ -221,11 +236,21 @@ class FlashCausalLMBatch(Batch):
 
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
 
-            cumulative_length += request_input_length
-            max_tokens += request_input_length + (
+            # Remaining decode steps for this request
+            remaining_decode = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
+            if max_decode_steps is None:
+                max_decode_steps = remaining_decode
+            else:
+                max_decode_steps = min(max_decode_steps, remaining_decode)
+            cumulative_length += request_input_length
 
+        # Since we are sure that at least one request will be dropped in max_decode_steps,
+        # we know the kv_cache will only grow to cumulative_length + batch_size * max_decode_steps
+        # before getting filtered and decreasing in size
+        max_tokens = cumulative_length + max_decode_steps * len(requests)
+
         if single_request:
             # Preallocate tensor for bs = 1 case
@@ -290,7 +315,8 @@ class FlashCausalLMBatch(Batch):
         # Cumulative length
         cumulative_batch_size = 0
         cumulative_length = 0
-        max_tokens = 0
+
+        max_decode_steps = None
 
         for i, batch in enumerate(batches):
             requests.extend(batch.requests)
@@ -329,10 +355,16 @@ class FlashCausalLMBatch(Batch):
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
 
+            if max_decode_steps is None:
+                max_decode_steps = batch.max_decode_steps
+            else:
+                max_decode_steps = min(max_decode_steps, batch.max_decode_steps)
+
             # Update
             cumulative_length += batch.cu_seqlens[-1]
             cumulative_batch_size += len(batch)
-            max_tokens += batch.max_tokens
+
+        max_tokens = cumulative_length + max_decode_steps * cumulative_batch_size
 
         return FlashCausalLMBatch(
             batch_id=batches[0].batch_id,
@@ -352,6 +384,7 @@ class FlashCausalLMBatch(Batch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             max_tokens=max_tokens,
+            max_decode_steps=max_decode_steps,
         )
 
     def __len__(self):
@@ -617,6 +650,7 @@ class FlashCausalLM(Model):
             batch.all_input_ids[i] = all_input_ids
             batch.all_input_ids_tensor[i] = all_input_ids_tensor
             batch.max_seqlen = max(batch.max_seqlen, new_input_length)
+            batch.max_decode_steps -= 1
             if len(batch) != 1:
                 # Add each sequence before its padding
                 batch.past_key_values[i * 2] = present[:, start_index:end_index]
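FlashCausalLMBatch does not pad prompts, so the hunks above bound the kv cache by the cumulative prompt length plus one token per request for each remaining decode step. A short illustrative sketch; the function name and sample values are invented for this note:

# Hypothetical sketch of the flash bound: cumulative_length + max_decode_steps * batch_size.

def flash_max_tokens(input_lengths: list[int], budgets: list[int]) -> int:
    cumulative_length = sum(input_lengths)   # no padding across requests
    max_decode_steps = min(budgets)          # steps until the first request finishes
    return cumulative_length + max_decode_steps * len(input_lengths)


print(flash_max_tokens([5, 12], [512, 16]))  # 17 + 16 * 2 = 49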
@@ -56,6 +56,8 @@ class Seq2SeqLMBatch(Batch):
 
     # Maximum number of tokens this batch will grow to
     max_tokens: int
+    # Maximum number of decode steps before at least one request finish
+    max_decode_steps: int
 
     def to_pb(self) -> generate_pb2.Batch:
         """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
@@ -86,7 +88,7 @@ class Seq2SeqLMBatch(Batch):
         # Parse batch
         max_truncation = 0
         padding_right_offset = 0
-        max_decode_tokens = 0
+        max_decode_steps = None
         for i, r in enumerate(pb.requests):
             inputs.append(r.inputs)
             requests_idx_mapping[r.id] = i
@@ -99,7 +101,15 @@ class Seq2SeqLMBatch(Batch):
             )
             stopping_criterias.append(stopping_criteria)
             max_truncation = max(max_truncation, r.truncate)
-            max_decode_tokens += stopping_criteria.max_new_tokens
+
+            # Maximum number of decode steps before one request finish
+            if max_decode_steps is None:
+                max_decode_steps = stopping_criteria.max_new_tokens
+            else:
+                max_decode_steps = min(
+                    max_decode_steps, stopping_criteria.max_new_tokens
+                )
+
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
@@ -125,7 +135,7 @@ class Seq2SeqLMBatch(Batch):
         )
         all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
 
-        max_tokens = len(inputs) * max_input_length + max_decode_tokens
+        max_tokens = len(inputs) * (max_input_length + max_decode_steps)
 
         return cls(
             batch_id=pb.id,
@@ -148,6 +158,7 @@ class Seq2SeqLMBatch(Batch):
             max_decoder_input_length=1,
             padding_right_offset=padding_right_offset,
             max_tokens=max_tokens,
+            max_decode_steps=max_decode_steps,
         )
 
     @tracer.start_as_current_span("filter")
@@ -177,7 +188,7 @@ class Seq2SeqLMBatch(Batch):
         max_decoder_input_length = 0
         padding_right_offset = 0
 
-        remaining_decode_tokens = 0
+        max_decode_steps = None
 
         for i, r in enumerate(requests):
             idx = self.requests_idx_mapping[r.id]
@@ -207,9 +218,15 @@ class Seq2SeqLMBatch(Batch):
             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
-            remaining_decode_tokens += (
+
+            # Remaining decode steps for this request
+            remaining_decode = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
+            if max_decode_steps is None:
+                max_decode_steps = remaining_decode
+            else:
+                max_decode_steps = min(max_decode_steps, remaining_decode)
 
         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
         self.decoder_input_ids = self.decoder_input_ids[keep_indices]
@@ -240,9 +257,8 @@ class Seq2SeqLMBatch(Batch):
             layer[2] = layer[2][keep_indices, :, -max_input_length:]
             layer[3] = layer[3][keep_indices, :, -max_input_length:]
 
-        max_tokens = (
-            len(requests) * (max_input_length + max_decoder_input_length)
-            + remaining_decode_tokens
+        max_tokens = len(requests) * (
+            max_input_length + max_decoder_input_length + max_decode_steps
         )
 
         self.requests = requests
@@ -259,6 +275,7 @@ class Seq2SeqLMBatch(Batch):
         self.max_decoder_input_length = max_decoder_input_length
         self.padding_right_offset = padding_right_offset
         self.max_tokens = max_tokens
+        self.max_decode_steps = max_decode_steps
 
         return self
 
@@ -290,7 +307,7 @@ class Seq2SeqLMBatch(Batch):
         token_offsets = []
         next_token_choosers = []
         stopping_criterias = []
-        max_tokens = 0
+        max_decode_steps = 0
 
         # Batch tensors
         attention_mask = None
@@ -398,13 +415,11 @@ class Seq2SeqLMBatch(Batch):
            ]
 
            start_index = end_index
-           # Add eventual padding tokens that were added while concatenating
-           max_tokens += batch.max_tokens + (
-               max_input_length
-               - batch.max_input_length
-               + max_decoder_input_length
-               - batch.max_decoder_input_length
-           ) * len(batch)
+
+           if max_decode_steps is None:
+               max_decode_steps = batch.max_decode_steps
+           else:
+               max_decode_steps = min(max_decode_steps, batch.max_decode_steps)
 
        # Determine shapes for new past kv tensors
        first_past_kvs = batches[0].past_key_values
@@ -471,6 +486,10 @@ class Seq2SeqLMBatch(Batch):
 
            start_index = end_index
 
+       max_tokens = len(requests) * (
+           max_input_length + max_decoder_input_length + max_decode_steps
+       )
+
        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
@@ -492,6 +511,7 @@ class Seq2SeqLMBatch(Batch):
            max_decoder_input_length=max_decoder_input_length,
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
+           max_decode_steps=max_decode_steps,
        )
 
    def __len__(self):
@@ -717,5 +737,6 @@ class Seq2SeqLM(Model):
        if batch.decoder_attention_mask is not None:
            batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
        batch.padding_right_offset -= 1
+       batch.max_decode_steps -= 1
 
        return generations, batch
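Seq2SeqLMBatch pads both the encoder inputs and the decoder inputs, so the bound above adds the padded encoder length, the padded decoder length, and the shared decode-step budget. A hedged sketch with an invented helper name and sample values:

# Hypothetical sketch of the seq2seq bound: every request is padded to the batch
# maxima on both the encoder and decoder side, then decoded for max_decode_steps.

def seq2seq_max_tokens(
    input_lengths: list[int],        # encoder prompt lengths
    decoder_lengths: list[int],      # decoder tokens generated so far
    budgets: list[int],              # remaining max_new_tokens per request
) -> int:
    return len(input_lengths) * (
        max(input_lengths) + max(decoder_lengths) + min(budgets)
    )


print(seq2seq_max_tokens([5, 12], [1, 3], [512, 16]))  # 2 * (12 + 3 + 16) = 62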