fix(server): fix has_position_ids (#395)

Fix #389
This commit is contained in:
OlivierDehaene 2023-06-01 11:41:35 +02:00 committed by GitHub
parent db2ebe3947
commit d69a0633be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 5 deletions

View File

@ -496,11 +496,6 @@ class CausalLM(Model):
else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
self.has_position_ids = (
inspect.signature(model.forward).parameters.get("position_ids", None)
is not None
)
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,

View File

@ -1,3 +1,4 @@
import inspect
import torch
from abc import ABC, abstractmethod
@ -29,6 +30,12 @@ class Model(ABC):
self.device = device
self.rank = rank
self.world_size = world_size
self.has_position_ids = (
inspect.signature(model.forward).parameters.get("position_ids", None)
is not None
)
self.check_initialized()
@property