Modify default for max_new_tokens in python client (#1336)

# What does this PR do?
Since
([#1097](https://github.com/huggingface/text-generation-inference/pull/1097))
the clients do not need to specify a max_length anymore. However, the
python client in this repo had not yet been adapted to these changes.
This PR makes it possible to use the python client and not provide
max_new_tokens.

<!-- Remove if not applicable -->


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [x] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.
This commit is contained in:
freitng 2024-01-29 17:02:57 +01:00 committed by GitHub
parent a9ea60684b
commit 2d56f106a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 21 additions and 5 deletions

View File

@ -21,6 +21,22 @@ def test_generate(flan_t5_xxl_url, hf_headers):
assert not response.details.tokens[0].special assert not response.details.tokens[0].special
def test_generate_max_new_tokens_not_set(flan_t5_xxl_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers)
response = client.generate("test", decoder_input_details=True)
assert response.generated_text != ""
assert response.details.finish_reason == FinishReason.EndOfSequenceToken
assert response.details.generated_tokens > 1
assert response.details.seed is None
assert len(response.details.prefill) == 1
assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
assert len(response.details.tokens) > 1
assert response.details.tokens[0].id == 3
assert response.details.tokens[0].text == " "
assert not response.details.tokens[0].special
def test_generate_best_of(flan_t5_xxl_url, hf_headers): def test_generate_best_of(flan_t5_xxl_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers) client = Client(flan_t5_xxl_url, hf_headers)
response = client.generate( response = client.generate(

View File

@ -62,7 +62,7 @@ class Client:
self, self,
prompt: str, prompt: str,
do_sample: bool = False, do_sample: bool = False,
max_new_tokens: int = 20, max_new_tokens: Optional[int] = None,
best_of: Optional[int] = None, best_of: Optional[int] = None,
repetition_penalty: Optional[float] = None, repetition_penalty: Optional[float] = None,
return_full_text: bool = False, return_full_text: bool = False,
@ -157,7 +157,7 @@ class Client:
self, self,
prompt: str, prompt: str,
do_sample: bool = False, do_sample: bool = False,
max_new_tokens: int = 20, max_new_tokens: Optional[int] = None,
repetition_penalty: Optional[float] = None, repetition_penalty: Optional[float] = None,
return_full_text: bool = False, return_full_text: bool = False,
seed: Optional[int] = None, seed: Optional[int] = None,
@ -312,7 +312,7 @@ class AsyncClient:
self, self,
prompt: str, prompt: str,
do_sample: bool = False, do_sample: bool = False,
max_new_tokens: int = 20, max_new_tokens: Optional[int] = None,
best_of: Optional[int] = None, best_of: Optional[int] = None,
repetition_penalty: Optional[float] = None, repetition_penalty: Optional[float] = None,
return_full_text: bool = False, return_full_text: bool = False,
@ -405,7 +405,7 @@ class AsyncClient:
self, self,
prompt: str, prompt: str,
do_sample: bool = False, do_sample: bool = False,
max_new_tokens: int = 20, max_new_tokens: Optional[int] = None,
repetition_penalty: Optional[float] = None, repetition_penalty: Optional[float] = None,
return_full_text: bool = False, return_full_text: bool = False,
seed: Optional[int] = None, seed: Optional[int] = None,

View File

@ -9,7 +9,7 @@ class Parameters(BaseModel):
# Activate logits sampling # Activate logits sampling
do_sample: bool = False do_sample: bool = False
# Maximum number of generated tokens # Maximum number of generated tokens
max_new_tokens: int = 20 max_new_tokens: Optional[int] = None
# The parameter for repetition penalty. 1.0 means no penalty. # The parameter for repetition penalty. 1.0 means no penalty.
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
repetition_penalty: Optional[float] = None repetition_penalty: Optional[float] = None