From d69a0633bee6f8a665a1f7d258fceaa4475c102f Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Thu, 1 Jun 2023 11:41:35 +0200 Subject: [PATCH] fix(server): fix has_position_ids (#395) Fix #389 --- server/text_generation_server/models/causal_lm.py | 5 ----- server/text_generation_server/models/model.py | 7 +++++++ 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index a20a6143..92622350 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -496,11 +496,6 @@ class CausalLM(Model): else: tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - self.has_position_ids = ( - inspect.signature(model.forward).parameters.get("position_ids", None) - is not None - ) - super(CausalLM, self).__init__( model=model, tokenizer=tokenizer, diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 29bad321..6b8472a5 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -1,3 +1,4 @@ +import inspect import torch from abc import ABC, abstractmethod @@ -29,6 +30,12 @@ class Model(ABC): self.device = device self.rank = rank self.world_size = world_size + + self.has_position_ids = ( + inspect.signature(model.forward).parameters.get("position_ids", None) + is not None + ) + self.check_initialized() @property