fix(server): remove position_ids from galactica forward (#82)
closes #80
This commit is contained in:
parent
17bc841b1b
commit
78063c0569
|
@@ -2,7 +2,7 @@ import re
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
|
||||||
from typing import List, Optional, Type
|
from typing import List, Optional, Type, Tuple
|
||||||
|
|
||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
|
@@ -145,6 +145,20 @@ class Galactica(CausalLM):
|
||||||
generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False
|
generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||||
|
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||||
|
"""Overwrite forward to ignore position_ids"""
|
||||||
|
|
||||||
|
# Model Forward
|
||||||
|
outputs = self.model.forward(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=True,
|
||||||
|
)
|
||||||
|
return outputs.logits, outputs.past_key_values
|
||||||
|
|
||||||
|
|
||||||
class GalacticaSharded(Galactica):
|
class GalacticaSharded(Galactica):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@@ -322,7 +336,6 @@ class GalacticaSharded(Galactica):
|
||||||
outputs = self.model.forward(
|
outputs = self.model.forward(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue