parent
3fef90d50f
commit
b1485e18c5
|
@ -96,6 +96,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
||||||
input_lengths = []
|
input_lengths = []
|
||||||
|
|
||||||
# Parse batch
|
# Parse batch
|
||||||
|
max_sequence_length = 0
|
||||||
|
padding_right_offset = 0
|
||||||
for r in pb.requests:
|
for r in pb.requests:
|
||||||
# Add escape_custom_split_sequence to the CausalLMBatch logic
|
# Add escape_custom_split_sequence to the CausalLMBatch logic
|
||||||
inputs.append(escape_custom_split_sequence(r.inputs))
|
inputs.append(escape_custom_split_sequence(r.inputs))
|
||||||
|
@ -103,8 +105,13 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
||||||
next_token_choosers.append(
|
next_token_choosers.append(
|
||||||
NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
|
NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
|
||||||
)
|
)
|
||||||
stopping_criterias.append(
|
stopping_criteria = StoppingCriteria.from_pb(
|
||||||
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
|
r.stopping_parameters, tokenizer
|
||||||
|
)
|
||||||
|
stopping_criterias.append(stopping_criteria)
|
||||||
|
max_sequence_length = max(max_sequence_length, r.input_length)
|
||||||
|
padding_right_offset = max(
|
||||||
|
padding_right_offset, stopping_criteria.max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
# Tokenize batch
|
# Tokenize batch
|
||||||
|
@ -114,6 +121,14 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
||||||
padding=True,
|
padding=True,
|
||||||
return_token_type_ids=False,
|
return_token_type_ids=False,
|
||||||
).to(device)
|
).to(device)
|
||||||
|
input_ids = tokenized_inputs["input_ids"]
|
||||||
|
# Allocate maximum attention_mask
|
||||||
|
attention_mask = input_ids.new_zeros(
|
||||||
|
(pb.size, max_sequence_length + padding_right_offset)
|
||||||
|
)
|
||||||
|
# Copy tokenizer attention_mask into fully allocated attention_mask
|
||||||
|
attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
|
||||||
|
|
||||||
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
|
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
|
||||||
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
|
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
|
||||||
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
|
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
|
||||||
|
@ -121,8 +136,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
||||||
return cls(
|
return cls(
|
||||||
batch_id=pb.id,
|
batch_id=pb.id,
|
||||||
requests=pb.requests,
|
requests=pb.requests,
|
||||||
input_ids=tokenized_inputs["input_ids"],
|
input_ids=input_ids,
|
||||||
attention_mask=tokenized_inputs["attention_mask"],
|
attention_mask=attention_mask,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
all_input_ids=all_input_ids,
|
all_input_ids=all_input_ids,
|
||||||
|
@ -130,7 +145,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
||||||
next_token_choosers=next_token_choosers,
|
next_token_choosers=next_token_choosers,
|
||||||
stopping_criterias=stopping_criterias,
|
stopping_criterias=stopping_criterias,
|
||||||
size=pb.size,
|
size=pb.size,
|
||||||
max_sequence_length=max(input_lengths),
|
max_sequence_length=max_sequence_length,
|
||||||
|
padding_right_offset=padding_right_offset,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue