diff --git a/server/text_generation/models/galactica.py b/server/text_generation/models/galactica.py
index 780a94f1..fad094dd 100644
--- a/server/text_generation/models/galactica.py
+++ b/server/text_generation/models/galactica.py
@@ -2,7 +2,7 @@ import re
 import torch
 import torch.distributed
 
-from typing import List, Optional, Type
+from typing import List, Optional, Type, Tuple
 
 from accelerate import init_empty_weights
 from safetensors import safe_open
@@ -145,6 +145,20 @@ class Galactica(CausalLM):
             generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False
         )
 
+    def forward(
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+        """Overwrite forward to ignore position_ids"""
+
+        # Model Forward
+        outputs = self.model.forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=True,
+        )
+        return outputs.logits, outputs.past_key_values
+
 
 class GalacticaSharded(Galactica):
     def __init__(
@@ -322,7 +336,6 @@ class GalacticaSharded(Galactica):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             past_key_values=past_key_values,
             use_cache=True,
         )