parent
3fef90d50f
commit
b1485e18c5
@@ -96,6 +96,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         input_lengths = []

         # Parse batch
+        max_sequence_length = 0
+        padding_right_offset = 0
         for r in pb.requests:
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
@@ -103,8 +105,13 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             next_token_choosers.append(
                 NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
             )
-            stopping_criterias.append(
-                StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
-            )
+            stopping_criteria = StoppingCriteria.from_pb(
+                r.stopping_parameters, tokenizer
+            )
+            stopping_criterias.append(stopping_criteria)
+            max_sequence_length = max(max_sequence_length, r.input_length)
+            padding_right_offset = max(
+                padding_right_offset, stopping_criteria.max_new_tokens
+            )

         # Tokenize batch
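Note: the rewritten loop keeps a reference to each StoppingCriteria so the batch can track two bounds in a single pass: max_sequence_length, the longest prompt in the batch, and padding_right_offset, the largest max_new_tokens any request may still generate. Together they give the final width the attention mask can ever need. A minimal sketch of the idea (toy values and plain PyTorch, not the repository's code):

    import torch

    # Hypothetical per-request data standing in for r.input_length and
    # stopping_criteria.max_new_tokens
    input_lengths = [3, 5]
    max_new_tokens = [4, 2]

    max_sequence_length = max(input_lengths)    # 5: longest prompt
    padding_right_offset = max(max_new_tokens)  # 4: most tokens any request can add

    # One allocation covers the whole generation: prompt columns on the left,
    # room for every future token on the right.
    attention_mask = torch.zeros(
        (len(input_lengths), max_sequence_length + padding_right_offset),
        dtype=torch.int64,
    )

Because the mask is sized once up front, each decode step only has to flip one more column to 1 instead of concatenating a new column onto the mask every iteration.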
@@ -114,6 +121,14 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             padding=True,
             return_token_type_ids=False,
         ).to(device)
+        input_ids = tokenized_inputs["input_ids"]
+        # Allocate maximum attention_mask
+        attention_mask = input_ids.new_zeros(
+            (pb.size, max_sequence_length + padding_right_offset)
+        )
+        # Copy tokenizer attention_mask into fully allocated attention_mask
+        attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
+
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
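The tokenizer only returns a mask as wide as the longest prompt, so the new code copies it into the left-hand max_sequence_length columns of the fully allocated mask and leaves the padding_right_offset columns zeroed for future tokens. The position_ids lines reuse a standard trick: assuming left padding (which the mask layout here implies), a cumulative sum over the attention mask numbers the real tokens from 0, while padded slots get a harmless dummy index. A self-contained illustration (toy mask, not the repository's code):

    import torch

    # Two left-padded prompts of lengths 3 and 5
    attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                                   [1, 1, 1, 1, 1]])

    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    # tensor([[1, 1, 0, 1, 2],
    #         [0, 1, 2, 3, 4]])

The dummy value 1 on masked positions never influences the output, since those slots are excluded by the attention mask; it only keeps the indices valid for the embedding lookup.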
@@ -121,8 +136,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
-            input_ids=tokenized_inputs["input_ids"],
-            attention_mask=tokenized_inputs["attention_mask"],
+            input_ids=input_ids,
+            attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=None,
             all_input_ids=all_input_ids,
@@ -130,7 +145,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,
-            max_sequence_length=max(input_lengths),
+            max_sequence_length=max_sequence_length,
+            padding_right_offset=padding_right_offset,
         )
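With these fields stored on the batch, max_sequence_length no longer has to be recomputed from input_lengths at construction time, and the decoding loop can grow into the pre-allocated region instead of reallocating the mask. A hedged sketch of how a decode step might consume the offset (the batch and loop names here are hypothetical, not the repository's API):

    # After generating `step` tokens for a batch `b` (hypothetical names):
    # unmask one more column in the pre-allocated region ...
    b.attention_mask[:, b.max_sequence_length + step] = 1
    # ... rather than paying for a reallocation on every step:
    # b.attention_mask = torch.cat([b.attention_mask, ones_column], dim=1)

Since padding_right_offset is the maximum max_new_tokens over the batch, no request can run past the allocated width.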