diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index b4d1c55..484621d 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -96,6 +96,8 @@ class GalacticaCausalLMBatch(CausalLMBatch): input_lengths = [] # Parse batch + max_sequence_length = 0 + padding_right_offset = 0 for r in pb.requests: # Add escape_custom_split_sequence to the CausalLMBatch logic inputs.append(escape_custom_split_sequence(r.inputs)) @@ -103,8 +105,13 @@ class GalacticaCausalLMBatch(CausalLMBatch): next_token_choosers.append( NextTokenChooser.from_pb(r.parameters, len(tokenizer), device) ) - stopping_criterias.append( - StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) + stopping_criteria = StoppingCriteria.from_pb( + r.stopping_parameters, tokenizer + ) + stopping_criterias.append(stopping_criteria) + max_sequence_length = max(max_sequence_length, r.input_length) + padding_right_offset = max( + padding_right_offset, stopping_criteria.max_new_tokens ) # Tokenize batch @@ -114,6 +121,14 @@ class GalacticaCausalLMBatch(CausalLMBatch): padding=True, return_token_type_ids=False, ).to(device) + input_ids = tokenized_inputs["input_ids"] + # Allocate maximum attention_mask + attention_mask = input_ids.new_zeros( + (pb.size, max_sequence_length + padding_right_offset) + ) + # Copy tokenizer attention_mask into fully allocated attention_mask + attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"] + position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1) @@ -121,8 +136,8 @@ class GalacticaCausalLMBatch(CausalLMBatch): return cls( batch_id=pb.id, requests=pb.requests, - input_ids=tokenized_inputs["input_ids"], - attention_mask=tokenized_inputs["attention_mask"], + input_ids=input_ids, + attention_mask=attention_mask, position_ids=position_ids, past_key_values=None, all_input_ids=all_input_ids, @@ -130,7 +145,8 @@ class GalacticaCausalLMBatch(CausalLMBatch): next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, size=pb.size, - max_sequence_length=max(input_lengths), + max_sequence_length=max_sequence_length, + padding_right_offset=padding_right_offset, )