Fix typing in `Model.generate_token` (#733)
## What does this PR do? This PR fixes a minor type annotation issue in the signature of `Model.generate_token`. All existing overrides of `Model.generate_token` return `Tuple[List[Generation], Optional[B]]`:3ef5ffbc64/server/text_generation_server/models/causal_lm.py (L535-L537)
3ef5ffbc64/server/text_generation_server/models/flash_causal_lm.py (L802-L804)
3ef5ffbc64/server/text_generation_server/models/seq2seq_lm.py (L589-L591)
I suspect that back in017a2a8c
when `GeneratedText` and `Generation` were separated, the function signature was not updated. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? CC @OlivierDehaene
This commit is contained in:
parent
92bb56b0c1
commit
b9633c46d0
|
@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
|
||||||
from typing import List, Tuple, Optional, TypeVar, Type
|
from typing import List, Tuple, Optional, TypeVar, Type
|
||||||
from transformers import PreTrainedTokenizerBase, PretrainedConfig
|
from transformers import PreTrainedTokenizerBase, PretrainedConfig
|
||||||
|
|
||||||
from text_generation_server.models.types import Batch, GeneratedText
|
from text_generation_server.models.types import Batch, Generation
|
||||||
from text_generation_server.pb.generate_pb2 import InfoResponse
|
from text_generation_server.pb.generate_pb2 import InfoResponse
|
||||||
|
|
||||||
B = TypeVar("B", bound=Batch)
|
B = TypeVar("B", bound=Batch)
|
||||||
|
@ -52,7 +52,7 @@ class Model(ABC):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
|
def generate_token(self, batch: B) -> Tuple[List[Generation], Optional[B]]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def warmup(self, batch: B) -> Optional[int]:
|
def warmup(self, batch: B) -> Optional[int]:
|
||||||
|
|
Loading…
Reference in New Issue