From 78063c05698a8555a617168b9b126910e9f720c7 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Mon, 20 Feb 2023 19:28:57 +0100 Subject: [PATCH] fix(server): remove position_ids from galactica forward (#82) closes #80 --- server/text_generation/models/galactica.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/server/text_generation/models/galactica.py b/server/text_generation/models/galactica.py index 780a94f..fad094d 100644 --- a/server/text_generation/models/galactica.py +++ b/server/text_generation/models/galactica.py @@ -2,7 +2,7 @@ import re import torch import torch.distributed -from typing import List, Optional, Type +from typing import List, Optional, Type, Tuple from accelerate import init_empty_weights from safetensors import safe_open @@ -145,6 +145,20 @@ class Galactica(CausalLM): generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False ) + def forward( + self, input_ids, attention_mask, position_ids, past_key_values: Optional = None + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + """Overwrite forward to ignore position_ids""" + + # Model Forward + outputs = self.model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=True, + ) + return outputs.logits, outputs.past_key_values + class GalacticaSharded(Galactica): def __init__( @@ -322,7 +336,6 @@ class GalacticaSharded(Galactica): outputs = self.model.forward( input_ids=input_ids, attention_mask=attention_mask, - position_ids=position_ids, past_key_values=past_key_values, use_cache=True, )