fix(server): remove position_ids from galactica forward (#82)

closes #80
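
For context: Galactica uses the OPT architecture, and transformers' `OPTForCausalLM.forward()` (as of early 2023) computed positions internally from the attention mask and did not accept a `position_ids` keyword, so the shared `CausalLM.forward()`, which passes one through, failed. A minimal reproduction sketch, assuming the public `facebook/galactica-125m` checkpoint (illustrative, not taken from this commit):

    # Hedged reproduction sketch of the failure this commit works around.
    # Assumption: OPTForCausalLM (transformers, early 2023) rejects the
    # `position_ids` keyword because it derives positions from the mask.
    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("facebook/galactica-125m")

    input_ids = torch.tensor([[2, 18, 27]])  # illustrative token ids
    attention_mask = torch.ones_like(input_ids)
    position_ids = attention_mask.cumsum(-1) - 1

    # Raises: TypeError: forward() got an unexpected keyword
    # argument 'position_ids'
    model.forward(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        use_cache=True,
    )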
OlivierDehaene 2023-02-20 19:28:57 +01:00 committed by GitHub
parent 17bc841b1b
commit 78063c0569
1 changed file with 15 additions and 2 deletions


@@ -2,7 +2,7 @@ import re
 import torch
 import torch.distributed
 
-from typing import List, Optional, Type
+from typing import List, Optional, Type, Tuple
 
 from accelerate import init_empty_weights
 from safetensors import safe_open
@@ -145,6 +145,20 @@ class Galactica(CausalLM):
             generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False
         )
 
+    def forward(
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+        """Overwrite forward to ignore position_ids"""
+
+        # Model Forward
+        outputs = self.model.forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=True,
+        )
+        return outputs.logits, outputs.past_key_values
+
 
 class GalacticaSharded(Galactica):
     def __init__(
@@ -322,7 +336,6 @@ class GalacticaSharded(Galactica):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             past_key_values=past_key_values,
             use_cache=True,
         )
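
Note that the override keeps `position_ids` in its signature even though it is never used: the shared generation loop in `CausalLM` still supplies it, so Galactica accepts the argument and simply drops it before calling the underlying OPT model. A rough sketch of the call site (batch field names are illustrative, not copied from the repo):

    # Sketch (hypothetical batch fields): the caller still passes
    # position_ids; Galactica's forward accepts it but never forwards
    # it to the underlying model.
    logits, past = model.forward(
        input_ids=batch.input_ids,
        attention_mask=batch.attention_mask,
        position_ids=batch.position_ids,  # accepted, then ignored
        past_key_values=batch.past_key_values,
    )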