Revert "Modify default for max_new_tokens in python client (#1336)"

This reverts commit 2d56f106a6.

It causes a breakage in our integration tests.
This commit is contained in:
Nicolas Patry 2024-02-01 14:36:10 +00:00
parent 9ad7b6a1a1
commit 1e03b61b5c
3 changed files with 5 additions and 21 deletions

View File

@@ -21,22 +21,6 @@ def test_generate(flan_t5_xxl_url, hf_headers):
assert not response.details.tokens[0].special assert not response.details.tokens[0].special
def test_generate_max_new_tokens_not_set(flan_t5_xxl_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers)
response = client.generate("test", decoder_input_details=True)
assert response.generated_text != ""
assert response.details.finish_reason == FinishReason.EndOfSequenceToken
assert response.details.generated_tokens > 1
assert response.details.seed is None
assert len(response.details.prefill) == 1
assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
assert len(response.details.tokens) > 1
assert response.details.tokens[0].id == 3
assert response.details.tokens[0].text == " "
assert not response.details.tokens[0].special
def test_generate_best_of(flan_t5_xxl_url, hf_headers): def test_generate_best_of(flan_t5_xxl_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers) client = Client(flan_t5_xxl_url, hf_headers)
response = client.generate( response = client.generate(

View File

@@ -62,7 +62,7 @@ class Client:
self, self,
prompt: str, prompt: str,
do_sample: bool = False, do_sample: bool = False,
max_new_tokens: Optional[int] = None, max_new_tokens: int = 20,
best_of: Optional[int] = None, best_of: Optional[int] = None,
repetition_penalty: Optional[float] = None, repetition_penalty: Optional[float] = None,
return_full_text: bool = False, return_full_text: bool = False,
@@ -157,7 +157,7 @@ class Client:
self, self,
prompt: str, prompt: str,
do_sample: bool = False, do_sample: bool = False,
max_new_tokens: Optional[int] = None, max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None, repetition_penalty: Optional[float] = None,
return_full_text: bool = False, return_full_text: bool = False,
seed: Optional[int] = None, seed: Optional[int] = None,
@@ -312,7 +312,7 @@ class AsyncClient:
self, self,
prompt: str, prompt: str,
do_sample: bool = False, do_sample: bool = False,
max_new_tokens: Optional[int] = None, max_new_tokens: int = 20,
best_of: Optional[int] = None, best_of: Optional[int] = None,
repetition_penalty: Optional[float] = None, repetition_penalty: Optional[float] = None,
return_full_text: bool = False, return_full_text: bool = False,
@@ -405,7 +405,7 @@ class AsyncClient:
self, self,
prompt: str, prompt: str,
do_sample: bool = False, do_sample: bool = False,
max_new_tokens: Optional[int] = None, max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None, repetition_penalty: Optional[float] = None,
return_full_text: bool = False, return_full_text: bool = False,
seed: Optional[int] = None, seed: Optional[int] = None,

View File

@@ -9,7 +9,7 @@ class Parameters(BaseModel):
# Activate logits sampling # Activate logits sampling
do_sample: bool = False do_sample: bool = False
# Maximum number of generated tokens # Maximum number of generated tokens
max_new_tokens: Optional[int] = None max_new_tokens: int = 20
# The parameter for repetition penalty. 1.0 means no penalty. # The parameter for repetition penalty. 1.0 means no penalty.
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
repetition_penalty: Optional[float] = None repetition_penalty: Optional[float] = None