fix(server): fix galactica batch (#106)

closes #105
This commit is contained in:
OlivierDehaene 2023-03-07 20:05:21 +01:00 committed by GitHub
parent 3fef90d50f
commit b1485e18c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 21 additions and 5 deletions

View File

@ -96,6 +96,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
input_lengths = []
# Parse batch
max_sequence_length = 0
padding_right_offset = 0
for r in pb.requests:
# Add escape_custom_split_sequence to the CausalLMBatch logic
inputs.append(escape_custom_split_sequence(r.inputs))
@ -103,8 +105,13 @@ class GalacticaCausalLMBatch(CausalLMBatch):
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
)
stopping_criterias.append(
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
max_sequence_length = max(max_sequence_length, r.input_length)
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
)
# Tokenize batch
@ -114,6 +121,14 @@ class GalacticaCausalLMBatch(CausalLMBatch):
padding=True,
return_token_type_ids=False,
).to(device)
input_ids = tokenized_inputs["input_ids"]
# Allocate maximum attention_mask
attention_mask = input_ids.new_zeros(
(pb.size, max_sequence_length + padding_right_offset)
)
# Copy tokenizer attention_mask into fully allocated attention_mask
attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
@ -121,8 +136,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
return cls(
batch_id=pb.id,
requests=pb.requests,
input_ids=tokenized_inputs["input_ids"],
attention_mask=tokenized_inputs["attention_mask"],
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=None,
all_input_ids=all_input_ids,
@ -130,7 +145,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=pb.size,
max_sequence_length=max(input_lengths),
max_sequence_length=max_sequence_length,
padding_right_offset=padding_right_offset,
)