Fix typing in `Model.generate_token` (#733)

## What does this PR do?

This PR fixes a minor type annotation issue in the signature of
`Model.generate_token`.

All existing overrides of `Model.generate_token` return
`Tuple[List[Generation], Optional[B]]`:

3ef5ffbc64/server/text_generation_server/models/causal_lm.py (L535-L537)

3ef5ffbc64/server/text_generation_server/models/flash_causal_lm.py (L802-L804)

3ef5ffbc64/server/text_generation_server/models/seq2seq_lm.py (L589-L591)

I suspect that back in 017a2a8c when `GeneratedText` and `Generation`
were separated, the function signature was not updated.

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

CC @OlivierDehaene
This commit is contained in:
Jae-Won Chung 2023-07-31 08:35:14 -04:00 committed by GitHub
parent 92bb56b0c1
commit b9633c46d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase, PretrainedConfig from transformers import PreTrainedTokenizerBase, PretrainedConfig
from text_generation_server.models.types import Batch, GeneratedText from text_generation_server.models.types import Batch, Generation
from text_generation_server.pb.generate_pb2 import InfoResponse from text_generation_server.pb.generate_pb2 import InfoResponse
B = TypeVar("B", bound=Batch) B = TypeVar("B", bound=Batch)
@ -52,7 +52,7 @@ class Model(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]: def generate_token(self, batch: B) -> Tuple[List[Generation], Optional[B]]:
raise NotImplementedError raise NotImplementedError
def warmup(self, batch: B) -> Optional[int]: def warmup(self, batch: B) -> Optional[int]: