diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py
index 1e25e1b1..775e7a6c 100644
--- a/clients/python/tests/test_client.py
+++ b/clients/python/tests/test_client.py
@@ -21,6 +21,22 @@ def test_generate(flan_t5_xxl_url, hf_headers):
     assert not response.details.tokens[0].special
 
 
+def test_generate_max_new_tokens_not_set(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
+    response = client.generate("test", decoder_input_details=True)
+
+    assert response.generated_text != ""
+    assert response.details.finish_reason == FinishReason.EndOfSequenceToken
+    assert response.details.generated_tokens > 1
+    assert response.details.seed is None
+    assert len(response.details.prefill) == 1
+    assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
+    assert len(response.details.tokens) > 1
+    assert response.details.tokens[0].id == 3
+    assert response.details.tokens[0].text == " "
+    assert not response.details.tokens[0].special
+
+
 def test_generate_best_of(flan_t5_xxl_url, hf_headers):
     client = Client(flan_t5_xxl_url, hf_headers)
     response = client.generate(
diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py
index 0bf80f8c..63b5258d 100644
--- a/clients/python/text_generation/client.py
+++ b/clients/python/text_generation/client.py
@@ -62,7 +62,7 @@ class Client:
         self,
         prompt: str,
         do_sample: bool = False,
-        max_new_tokens: int = 20,
+        max_new_tokens: Optional[int] = None,
         best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
@@ -157,7 +157,7 @@ class Client:
         self,
         prompt: str,
         do_sample: bool = False,
-        max_new_tokens: int = 20,
+        max_new_tokens: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
@@ -312,7 +312,7 @@ class AsyncClient:
         self,
         prompt: str,
         do_sample: bool = False,
-        max_new_tokens: int = 20,
+        max_new_tokens: Optional[int] = None,
         best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
@@ -405,7 +405,7 @@ class AsyncClient:
         self,
         prompt: str,
         do_sample: bool = False,
-        max_new_tokens: int = 20,
+        max_new_tokens: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index aa02d8d8..7fa8033e 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -9,7 +9,7 @@ class Parameters(BaseModel):
     # Activate logits sampling
     do_sample: bool = False
     # Maximum number of generated tokens
-    max_new_tokens: int = 20
+    max_new_tokens: Optional[int] = None
     # The parameter for repetition penalty. 1.0 means no penalty.
     # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
     repetition_penalty: Optional[float] = None