Modify default for max_new_tokens in python client (#1336)
# What does this PR do? Since ([#1097](https://github.com/huggingface/text-generation-inference/pull/1097)) the clients do not need to specify a max_length anymore. However, the python client in this repo had not yet been adapted to these changes. This PR makes it possible to use the python client and not provide max_new_tokens. <!-- Remove if not applicable --> ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [x] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
This commit is contained in:
parent
a9ea60684b
commit
2d56f106a6
|
@ -21,6 +21,22 @@ def test_generate(flan_t5_xxl_url, hf_headers):
|
||||||
assert not response.details.tokens[0].special
|
assert not response.details.tokens[0].special
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_max_new_tokens_not_set(flan_t5_xxl_url, hf_headers):
|
||||||
|
client = Client(flan_t5_xxl_url, hf_headers)
|
||||||
|
response = client.generate("test", decoder_input_details=True)
|
||||||
|
|
||||||
|
assert response.generated_text != ""
|
||||||
|
assert response.details.finish_reason == FinishReason.EndOfSequenceToken
|
||||||
|
assert response.details.generated_tokens > 1
|
||||||
|
assert response.details.seed is None
|
||||||
|
assert len(response.details.prefill) == 1
|
||||||
|
assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
|
||||||
|
assert len(response.details.tokens) > 1
|
||||||
|
assert response.details.tokens[0].id == 3
|
||||||
|
assert response.details.tokens[0].text == " "
|
||||||
|
assert not response.details.tokens[0].special
|
||||||
|
|
||||||
|
|
||||||
def test_generate_best_of(flan_t5_xxl_url, hf_headers):
|
def test_generate_best_of(flan_t5_xxl_url, hf_headers):
|
||||||
client = Client(flan_t5_xxl_url, hf_headers)
|
client = Client(flan_t5_xxl_url, hf_headers)
|
||||||
response = client.generate(
|
response = client.generate(
|
||||||
|
|
|
@ -62,7 +62,7 @@ class Client:
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
do_sample: bool = False,
|
do_sample: bool = False,
|
||||||
max_new_tokens: int = 20,
|
max_new_tokens: Optional[int] = None,
|
||||||
best_of: Optional[int] = None,
|
best_of: Optional[int] = None,
|
||||||
repetition_penalty: Optional[float] = None,
|
repetition_penalty: Optional[float] = None,
|
||||||
return_full_text: bool = False,
|
return_full_text: bool = False,
|
||||||
|
@ -157,7 +157,7 @@ class Client:
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
do_sample: bool = False,
|
do_sample: bool = False,
|
||||||
max_new_tokens: int = 20,
|
max_new_tokens: Optional[int] = None,
|
||||||
repetition_penalty: Optional[float] = None,
|
repetition_penalty: Optional[float] = None,
|
||||||
return_full_text: bool = False,
|
return_full_text: bool = False,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
|
@ -312,7 +312,7 @@ class AsyncClient:
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
do_sample: bool = False,
|
do_sample: bool = False,
|
||||||
max_new_tokens: int = 20,
|
max_new_tokens: Optional[int] = None,
|
||||||
best_of: Optional[int] = None,
|
best_of: Optional[int] = None,
|
||||||
repetition_penalty: Optional[float] = None,
|
repetition_penalty: Optional[float] = None,
|
||||||
return_full_text: bool = False,
|
return_full_text: bool = False,
|
||||||
|
@ -405,7 +405,7 @@ class AsyncClient:
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
do_sample: bool = False,
|
do_sample: bool = False,
|
||||||
max_new_tokens: int = 20,
|
max_new_tokens: Optional[int] = None,
|
||||||
repetition_penalty: Optional[float] = None,
|
repetition_penalty: Optional[float] = None,
|
||||||
return_full_text: bool = False,
|
return_full_text: bool = False,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
|
|
|
@ -9,7 +9,7 @@ class Parameters(BaseModel):
|
||||||
# Activate logits sampling
|
# Activate logits sampling
|
||||||
do_sample: bool = False
|
do_sample: bool = False
|
||||||
# Maximum number of generated tokens
|
# Maximum number of generated tokens
|
||||||
max_new_tokens: int = 20
|
max_new_tokens: Optional[int] = None
|
||||||
# The parameter for repetition penalty. 1.0 means no penalty.
|
# The parameter for repetition penalty. 1.0 means no penalty.
|
||||||
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||||
repetition_penalty: Optional[float] = None
|
repetition_penalty: Optional[float] = None
|
||||||
|
|
Loading…
Reference in New Issue