parent
db2ebe3947
commit
d69a0633be
|
@ -496,11 +496,6 @@ class CausalLM(Model):
|
||||||
else:
|
else:
|
||||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||||
|
|
||||||
self.has_position_ids = (
|
|
||||||
inspect.signature(model.forward).parameters.get("position_ids", None)
|
|
||||||
is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
super(CausalLM, self).__init__(
|
super(CausalLM, self).__init__(
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import inspect
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
@ -29,6 +30,12 @@ class Model(ABC):
|
||||||
self.device = device
|
self.device = device
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
self.world_size = world_size
|
self.world_size = world_size
|
||||||
|
|
||||||
|
self.has_position_ids = (
|
||||||
|
inspect.signature(model.forward).parameters.get("position_ids", None)
|
||||||
|
is not None
|
||||||
|
)
|
||||||
|
|
||||||
self.check_initialized()
|
self.check_initialized()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
Loading…
Reference in New Issue