Revert "Modify default for max_new_tokens in python client (#1336)"
This reverts commit 2d56f106a6.
It causes a breakage in our integration tests.
parent 9ad7b6a1a1
commit 1e03b61b5c
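For context, a minimal sketch of what the restored default means for callers of the Python client (the endpoint URL is a placeholder, not part of this commit):

    from text_generation import Client

    # Placeholder endpoint for illustration only.
    client = Client("http://localhost:8080")

    # After this revert, omitting max_new_tokens caps generation at 20 tokens
    # again, instead of leaving the limit unset for the server to decide.
    short = client.generate("What is Deep Learning?")

    # Callers that relied on the reverted behaviour can pass a limit explicitly.
    longer = client.generate("What is Deep Learning?", max_new_tokens=64)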
@@ -21,22 +21,6 @@ def test_generate(flan_t5_xxl_url, hf_headers):
     assert not response.details.tokens[0].special
 
 
-def test_generate_max_new_tokens_not_set(flan_t5_xxl_url, hf_headers):
-    client = Client(flan_t5_xxl_url, hf_headers)
-    response = client.generate("test", decoder_input_details=True)
-
-    assert response.generated_text != ""
-    assert response.details.finish_reason == FinishReason.EndOfSequenceToken
-    assert response.details.generated_tokens > 1
-    assert response.details.seed is None
-    assert len(response.details.prefill) == 1
-    assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
-    assert len(response.details.tokens) > 1
-    assert response.details.tokens[0].id == 3
-    assert response.details.tokens[0].text == " "
-    assert not response.details.tokens[0].special
-
-
 def test_generate_best_of(flan_t5_xxl_url, hf_headers):
     client = Client(flan_t5_xxl_url, hf_headers)
     response = client.generate(
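The test removed above encoded the reverted behaviour: with no max_new_tokens, generation ran until an end-of-sequence token. After the revert the default of 20 applies again, so a bare call presumably stops at the length limit and those assertions no longer hold. A hedged sketch of the post-revert expectation:

    response = client.generate("test", decoder_input_details=True)

    # With the default restored, no more than 20 tokens should be produced;
    # length-capped requests presumably finish with FinishReason.Length
    # rather than FinishReason.EndOfSequenceToken.
    assert response.details.generated_tokens <= 20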
@@ -62,7 +62,7 @@ class Client:
         self,
         prompt: str,
         do_sample: bool = False,
-        max_new_tokens: Optional[int] = None,
+        max_new_tokens: int = 20,
         best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
@@ -157,7 +157,7 @@ class Client:
         self,
         prompt: str,
         do_sample: bool = False,
-        max_new_tokens: Optional[int] = None,
+        max_new_tokens: int = 20,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
@@ -312,7 +312,7 @@ class AsyncClient:
         self,
         prompt: str,
         do_sample: bool = False,
-        max_new_tokens: Optional[int] = None,
+        max_new_tokens: int = 20,
         best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
@@ -405,7 +405,7 @@ class AsyncClient:
         self,
         prompt: str,
         do_sample: bool = False,
-        max_new_tokens: Optional[int] = None,
+        max_new_tokens: int = 20,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
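Each client reverts the same default in both of its generate-style signatures. A minimal sketch of the async side under the restored default (endpoint URL again a placeholder):

    import asyncio

    from text_generation import AsyncClient

    async def main():
        # Placeholder endpoint; AsyncClient.generate now falls back to
        # max_new_tokens=20 when the argument is omitted, like Client.generate.
        client = AsyncClient("http://localhost:8080")
        response = await client.generate("What is Deep Learning?")
        print(response.details.generated_tokens)  # at most 20

    asyncio.run(main())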
@@ -9,7 +9,7 @@ class Parameters(BaseModel):
     # Activate logits sampling
     do_sample: bool = False
     # Maximum number of generated tokens
-    max_new_tokens: Optional[int] = None
+    max_new_tokens: int = 20
     # The parameter for repetition penalty. 1.0 means no penalty.
     # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
     repetition_penalty: Optional[float] = None
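At the pydantic level, the revert means Parameters fills in 20 when the field is omitted, and None no longer validates, since the field is a plain int again. A small sketch, assuming the types module imports as shown in the diff:

    from text_generation.types import Parameters

    params = Parameters()
    assert params.max_new_tokens == 20  # default restored by this revert

    # None is no longer a valid value for a non-Optional int field.
    try:
        Parameters(max_new_tokens=None)
    except Exception as err:  # pydantic raises a ValidationError here
        print("rejected:", err)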