diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index a20a6143..92622350 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -496,11 +496,6 @@ class CausalLM(Model):
         else:
             tokenizer.add_special_tokens({"pad_token": "[PAD]"})
 
-        self.has_position_ids = (
-            inspect.signature(model.forward).parameters.get("position_ids", None)
-            is not None
-        )
-
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 29bad321..6b8472a5 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -1,3 +1,4 @@
+import inspect
 import torch
 
 from abc import ABC, abstractmethod
@@ -29,6 +30,12 @@ class Model(ABC):
         self.device = device
         self.rank = rank
         self.world_size = world_size
+
+        self.has_position_ids = (
+            inspect.signature(model.forward).parameters.get("position_ids", None)
+            is not None
+        )
+
         self.check_initialized()
 
     @property