From 2d56f106a60c7b698705494e7539f8a7e4c85dd9 Mon Sep 17 00:00:00 2001 From: freitng <153592523+freitng@users.noreply.github.com> Date: Mon, 29 Jan 2024 17:02:57 +0100 Subject: [PATCH] Modify default for max_new_tokens in python client (#1336) # What does this PR do? Since ([#1097](https://github.com/huggingface/text-generation-inference/pull/1097)) the clients do not need to specify a max_length anymore. However, the python client in this repo had not yet been adapted to these changes. This PR makes it possible to use the python client and not provide max_new_tokens. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [x] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- clients/python/tests/test_client.py | 16 ++++++++++++++++ clients/python/text_generation/client.py | 8 ++++---- clients/python/text_generation/types.py | 2 +- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py index 1e25e1b1..775e7a6c 100644 --- a/clients/python/tests/test_client.py +++ b/clients/python/tests/test_client.py @@ -21,6 +21,22 @@ def test_generate(flan_t5_xxl_url, hf_headers): assert not response.details.tokens[0].special +def test_generate_max_new_tokens_not_set(flan_t5_xxl_url, hf_headers): + client = Client(flan_t5_xxl_url, hf_headers) + response = client.generate("test", decoder_input_details=True) + + assert response.generated_text != "" + assert response.details.finish_reason == FinishReason.EndOfSequenceToken + assert response.details.generated_tokens > 1 + assert response.details.seed is None + assert len(response.details.prefill) == 1 + assert response.details.prefill[0] == InputToken(id=0, text="", logprob=None) + assert len(response.details.tokens) > 1 + assert response.details.tokens[0].id == 3 + assert response.details.tokens[0].text == " " + assert not response.details.tokens[0].special + + def test_generate_best_of(flan_t5_xxl_url, hf_headers): client = Client(flan_t5_xxl_url, hf_headers) response = client.generate( diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index 0bf80f8c..63b5258d 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -62,7 +62,7 @@ class Client: self, prompt: str, do_sample: bool = False, - max_new_tokens: int = 20, + max_new_tokens: Optional[int] = None, best_of: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: bool = False, @@ -157,7 +157,7 @@ class Client: self, prompt: str, do_sample: bool = False, - max_new_tokens: int = 20, + max_new_tokens: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: bool = False, seed: Optional[int] = None, @@ -312,7 +312,7 @@ class AsyncClient: self, prompt: str, do_sample: bool = False, - max_new_tokens: int = 20, + max_new_tokens: Optional[int] = None, best_of: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: bool = False, @@ -405,7 +405,7 @@ class AsyncClient: self, prompt: str, do_sample: bool = False, - max_new_tokens: int = 20, + max_new_tokens: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: bool = False, seed: Optional[int] = None, diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index aa02d8d8..7fa8033e 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -9,7 +9,7 @@ class Parameters(BaseModel): # Activate logits sampling do_sample: bool = False # Maximum number of generated tokens - max_new_tokens: int = 20 + max_new_tokens: Optional[int] = None # The parameter for repetition penalty. 1.0 means no penalty. # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. repetition_penalty: Optional[float] = None