From 1e03b61b5c56e2ed5c723457df21cc18d48c1854 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 1 Feb 2024 14:36:10 +0000
Subject: [PATCH] Revert "Modify default for max_new_tokens in python client
 (#1336)"

This reverts commit 2d56f106a60c7b698705494e7539f8a7e4c85dd9.

It causes a breaking in our integrations-tests.
---
 clients/python/tests/test_client.py       | 16 ----------------
 clients/python/text_generation/client.py  |  8 ++++----
 clients/python/text_generation/types.py   |  2 +-
 3 files changed, 5 insertions(+), 21 deletions(-)

diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py
index 775e7a6c..1e25e1b1 100644
--- a/clients/python/tests/test_client.py
+++ b/clients/python/tests/test_client.py
@@ -21,22 +21,6 @@ def test_generate(flan_t5_xxl_url, hf_headers):
     assert not response.details.tokens[0].special
 
 
-def test_generate_max_new_tokens_not_set(flan_t5_xxl_url, hf_headers):
-    client = Client(flan_t5_xxl_url, hf_headers)
-    response = client.generate("test", decoder_input_details=True)
-
-    assert response.generated_text != ""
-    assert response.details.finish_reason == FinishReason.EndOfSequenceToken
-    assert response.details.generated_tokens > 1
-    assert response.details.seed is None
-    assert len(response.details.prefill) == 1
-    assert response.details.prefill[0] == InputToken(id=0, text="", logprob=None)
-    assert len(response.details.tokens) > 1
-    assert response.details.tokens[0].id == 3
-    assert response.details.tokens[0].text == " "
-    assert not response.details.tokens[0].special
-
-
 def test_generate_best_of(flan_t5_xxl_url, hf_headers):
     client = Client(flan_t5_xxl_url, hf_headers)
     response = client.generate(
diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py
index 63b5258d..0bf80f8c 100644
--- a/clients/python/text_generation/client.py
+++ b/clients/python/text_generation/client.py
@@ -62,7 +62,7 @@ class Client:
         self,
         prompt: str,
         do_sample: bool = False,
-        max_new_tokens: Optional[int] = None,
+        max_new_tokens: int = 20,
         best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
@@ -157,7 +157,7 @@ class Client:
         self,
         prompt: str,
         do_sample: bool = False,
-        max_new_tokens: Optional[int] = None,
+        max_new_tokens: int = 20,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
@@ -312,7 +312,7 @@ class AsyncClient:
         self,
         prompt: str,
         do_sample: bool = False,
-        max_new_tokens: Optional[int] = None,
+        max_new_tokens: int = 20,
         best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
@@ -405,7 +405,7 @@ class AsyncClient:
         self,
         prompt: str,
         do_sample: bool = False,
-        max_new_tokens: Optional[int] = None,
+        max_new_tokens: int = 20,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index 7fa8033e..aa02d8d8 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -9,7 +9,7 @@ class Parameters(BaseModel):
     # Activate logits sampling
     do_sample: bool = False
     # Maximum number of generated tokens
-    max_new_tokens: Optional[int] = None
+    max_new_tokens: int = 20
     # The parameter for repetition penalty. 1.0 means no penalty.
     # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
     repetition_penalty: Optional[float] = None