From b9633c46d0b64e4c3d52e2e29746abdbfd339dbe Mon Sep 17 00:00:00 2001 From: Jae-Won Chung Date: Mon, 31 Jul 2023 08:35:14 -0400 Subject: [PATCH] Fix typing in `Model.generate_token` (#733) ## What does this PR do? This PR fixes a minor type annotation issue in the signature of `Model.generate_token`. All existing overrides of `Model.generate_token` return `Tuple[List[Generation], Optional[B]]`: https://github.com/huggingface/text-generation-inference/blob/3ef5ffbc6400370ff2e1546550a6bad3ac61b079/server/text_generation_server/models/causal_lm.py#L535-L537 https://github.com/huggingface/text-generation-inference/blob/3ef5ffbc6400370ff2e1546550a6bad3ac61b079/server/text_generation_server/models/flash_causal_lm.py#L802-L804 https://github.com/huggingface/text-generation-inference/blob/3ef5ffbc6400370ff2e1546550a6bad3ac61b079/server/text_generation_server/models/seq2seq_lm.py#L589-L591 I suspect that back in 017a2a8c when `GeneratedText` and `Generation` were separated, the function signature was not updated. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? CC @OlivierDehaene --- server/text_generation_server/models/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 9d74247c..806e9833 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from typing import List, Tuple, Optional, TypeVar, Type from transformers import PreTrainedTokenizerBase, PretrainedConfig -from text_generation_server.models.types import Batch, GeneratedText +from text_generation_server.models.types import Batch, Generation from text_generation_server.pb.generate_pb2 import InfoResponse B = TypeVar("B", bound=Batch) @@ -52,7 +52,7 @@ class Model(ABC): raise NotImplementedError @abstractmethod - def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]: + def generate_token(self, batch: B) -> Tuple[List[Generation], Optional[B]]: raise NotImplementedError def warmup(self, batch: B) -> Optional[int]: