From 08e91814180c5a737749f9deadfc45fd0968037a Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Fri, 22 Mar 2024 17:10:56 +0100
Subject: [PATCH] feat: update client to 0.7 (#1667)

Close #1652
---
 clients/python/README.md                  | 32 ++++++++++++++++--
 clients/python/pyproject.toml             |  2 +-
 clients/python/text_generation/client.py  | 42 ++++++++++++++++++++++--
 clients/python/text_generation/types.py   | 17 +++++++++-
 4 files changed, 86 insertions(+), 7 deletions(-)

diff --git a/clients/python/README.md b/clients/python/README.md
index 20243f4a..bf37508e 100644
--- a/clients/python/README.md
+++ b/clients/python/README.md
@@ -107,7 +107,19 @@ print(text)
 ### Types

 ```python
-# Request Parameters
+# Enum for grammar type
+class GrammarType(Enum):
+    Json = "json"
+    Regex = "regex"
+
+
+# Grammar type and value
+class Grammar:
+    # Grammar type
+    type: GrammarType
+    # Grammar value
+    value: Union[str, dict]
+
 class Parameters:
     # Activate logits sampling
     do_sample: bool
@@ -116,6 +128,10 @@ class Parameters:
     # The parameter for repetition penalty. 1.0 means no penalty.
     # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
     repetition_penalty: Optional[float]
+    # The parameter for frequency penalty. 1.0 means no penalty.
+    # Penalize new tokens based on their existing frequency in the text so far,
+    # decreasing the model's likelihood to repeat the same line verbatim.
+    frequency_penalty: Optional[float]
     # Whether to prepend the prompt to the generated text
     return_full_text: bool
     # Stop generating tokens if a member of `stop_sequences` is generated
@@ -138,10 +154,22 @@ class Parameters:
     best_of: Optional[int]
     # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
     watermark: bool
+    # Get generation details
+    details: bool
     # Get decoder input token logprobs and ids
     decoder_input_details: bool
     # Return the N most likely tokens at each step
     top_n_tokens: Optional[int]
+    # Grammar to use for generation
+    grammar: Optional[Grammar]
+
+class Request:
+    # Prompt
+    inputs: str
+    # Generation parameters
+    parameters: Optional[Parameters]
+    # Whether to stream output tokens
+    stream: bool

 # Decoder input tokens
 class InputToken:
@@ -161,7 +189,7 @@ class Token:
     # Token text
     text: str
     # Logprob
-    logprob: float
+    logprob: Optional[float]
     # Is the token a special token
     # Can be used to ignore tokens when concatenating
     special: bool
diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml
index c7a885ef..2925085b 100644
--- a/clients/python/pyproject.toml
+++ b/clients/python/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation"
-version = "0.6.1"
+version = "0.7.0"
 description = "Hugging Face Text Generation Python Client"
 license = "Apache-2.0"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]
diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py
index 465bd4fc..95d23901 100644
--- a/clients/python/text_generation/client.py
+++ b/clients/python/text_generation/client.py
@@ -67,6 +67,7 @@ class Client:
     def chat(
         self,
         messages: List[Message],
+        repetition_penalty: Optional[float] = None,
         frequency_penalty: Optional[float] = None,
         logit_bias: Optional[List[float]] = None,
         logprobs: Optional[bool] = None,
@@ -87,9 +88,13 @@ class Client:
         Args:
             messages (`List[Message]`):
                 List of messages
-            frequency_penalty (`float`):
-                The parameter for frequency penalty. 0.0 means no penalty. See [this
+            repetition_penalty (`float`):
+                The parameter for repetition penalty. 1.0 means no penalty. See [this
                 paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+            frequency_penalty (`float`):
+                The parameter for frequency penalty. 0.0 means no penalty.
+                Penalize new tokens based on their existing frequency in the text so far,
+                decreasing the model's likelihood to repeat the same line verbatim.
             logit_bias (`List[float]`):
                 Adjust the likelihood of specified tokens
             logprobs (`bool`):
@@ -121,6 +126,7 @@ class Client:
         request = ChatRequest(
             model="tgi",
             messages=messages,
+            repetition_penalty=repetition_penalty,
             frequency_penalty=frequency_penalty,
             logit_bias=logit_bias,
             logprobs=logprobs,
@@ -179,6 +185,7 @@ class Client:
         max_new_tokens: int = 20,
         best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
+        frequency_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
         stop_sequences: Optional[List[str]] = None,
@@ -207,6 +214,10 @@ class Client:
             repetition_penalty (`float`):
                 The parameter for repetition penalty. 1.0 means no penalty. See [this
                 paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+            frequency_penalty (`float`):
+                The parameter for frequency penalty. 1.0 means no penalty.
+                Penalize new tokens based on their existing frequency in the text so far,
+                decreasing the model's likelihood to repeat the same line verbatim.
             return_full_text (`bool`):
                 Whether to prepend the prompt to the generated text
             seed (`int`):
@@ -245,6 +256,7 @@ class Client:
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
+            frequency_penalty=frequency_penalty,
             return_full_text=return_full_text,
             seed=seed,
             stop=stop_sequences if stop_sequences is not None else [],
@@ -278,6 +290,7 @@ class Client:
         do_sample: bool = False,
         max_new_tokens: int = 20,
         repetition_penalty: Optional[float] = None,
+        frequency_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
         stop_sequences: Optional[List[str]] = None,
@@ -303,6 +316,10 @@ class Client:
             repetition_penalty (`float`):
                 The parameter for repetition penalty. 1.0 means no penalty. See [this
                 paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+            frequency_penalty (`float`):
+                The parameter for frequency penalty. 1.0 means no penalty.
+                Penalize new tokens based on their existing frequency in the text so far,
+                decreasing the model's likelihood to repeat the same line verbatim.
             return_full_text (`bool`):
                 Whether to prepend the prompt to the generated text
             seed (`int`):
@@ -340,6 +357,7 @@ class Client:
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
+            frequency_penalty=frequency_penalty,
             return_full_text=return_full_text,
             seed=seed,
             stop=stop_sequences if stop_sequences is not None else [],
@@ -435,6 +453,7 @@ class AsyncClient:
     async def chat(
         self,
         messages: List[Message],
+        repetition_penalty: Optional[float] = None,
         frequency_penalty: Optional[float] = None,
         logit_bias: Optional[List[float]] = None,
         logprobs: Optional[bool] = None,
@@ -455,9 +474,13 @@ class AsyncClient:
         Args:
             messages (`List[Message]`):
                 List of messages
-            frequency_penalty (`float`):
-                The parameter for frequency penalty. 0.0 means no penalty. See [this
+            repetition_penalty (`float`):
+                The parameter for repetition penalty. 1.0 means no penalty. See [this
                 paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+            frequency_penalty (`float`):
+                The parameter for frequency penalty. 0.0 means no penalty.
+                Penalize new tokens based on their existing frequency in the text so far,
+                decreasing the model's likelihood to repeat the same line verbatim.
             logit_bias (`List[float]`):
                 Adjust the likelihood of specified tokens
             logprobs (`bool`):
@@ -489,6 +512,7 @@ class AsyncClient:
         request = ChatRequest(
             model="tgi",
             messages=messages,
+            repetition_penalty=repetition_penalty,
             frequency_penalty=frequency_penalty,
             logit_bias=logit_bias,
             logprobs=logprobs,
@@ -546,6 +570,7 @@ class AsyncClient:
         max_new_tokens: int = 20,
         best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
+        frequency_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
         stop_sequences: Optional[List[str]] = None,
@@ -574,6 +599,10 @@ class AsyncClient:
             repetition_penalty (`float`):
                 The parameter for repetition penalty. 1.0 means no penalty. See [this
                 paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+            frequency_penalty (`float`):
+                The parameter for frequency penalty. 1.0 means no penalty.
+                Penalize new tokens based on their existing frequency in the text so far,
+                decreasing the model's likelihood to repeat the same line verbatim.
             return_full_text (`bool`):
                 Whether to prepend the prompt to the generated text
             seed (`int`):
@@ -614,6 +643,7 @@ class AsyncClient:
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
+            frequency_penalty=frequency_penalty,
             return_full_text=return_full_text,
             seed=seed,
             stop=stop_sequences if stop_sequences is not None else [],
@@ -644,6 +674,7 @@ class AsyncClient:
         do_sample: bool = False,
         max_new_tokens: int = 20,
         repetition_penalty: Optional[float] = None,
+        frequency_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
         stop_sequences: Optional[List[str]] = None,
@@ -669,6 +700,10 @@ class AsyncClient:
             repetition_penalty (`float`):
                 The parameter for repetition penalty. 1.0 means no penalty. See [this
                 paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+            frequency_penalty (`float`):
+                The parameter for frequency penalty. 1.0 means no penalty.
+                Penalize new tokens based on their existing frequency in the text so far,
+                decreasing the model's likelihood to repeat the same line verbatim.
             return_full_text (`bool`):
                 Whether to prepend the prompt to the generated text
             seed (`int`):
@@ -706,6 +741,7 @@ class AsyncClient:
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
+            frequency_penalty=frequency_penalty,
             return_full_text=return_full_text,
             seed=seed,
             stop=stop_sequences if stop_sequences is not None else [],
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index f0d859af..deb987c5 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -109,7 +109,12 @@ class ChatRequest(BaseModel):
     model: str
     # List of messages in the conversation
     messages: List[Message]
-    # Penalty for frequency of new tokens
+    # The parameter for repetition penalty. 1.0 means no penalty.
+    # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+    repetition_penalty: Optional[float] = None
+    # The parameter for frequency penalty. 0.0 means no penalty.
+    # Penalize new tokens based on their existing frequency in the text so far,
+    # decreasing the model's likelihood to repeat the same line verbatim.
     frequency_penalty: Optional[float] = None
     # Bias values for token selection
     logit_bias: Optional[List[float]] = None
@@ -145,6 +150,10 @@ class Parameters(BaseModel):
     # The parameter for repetition penalty. 1.0 means no penalty.
     # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
     repetition_penalty: Optional[float] = None
+    # The parameter for frequency penalty. 1.0 means no penalty.
+    # Penalize new tokens based on their existing frequency in the text so far,
+    # decreasing the model's likelihood to repeat the same line verbatim.
+    frequency_penalty: Optional[float] = None
     # Whether to prepend the prompt to the generated text
     return_full_text: bool = False
     # Stop generating tokens if a member of `stop_sequences` is generated
@@ -201,6 +210,12 @@ class Parameters(BaseModel):
             raise ValidationError("`repetition_penalty` must be strictly positive")
         return v
 
+    @field_validator("frequency_penalty")
+    def valid_frequency_penalty(cls, v):
+        if v is not None and v <= 0:
+            raise ValidationError("`frequency_penalty` must be strictly positive")
+        return v
+
     @field_validator("seed")
     def valid_seed(cls, v):
         if v is not None and v < 0:
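
A minimal usage sketch of the new `generate` parameters, not part of the diff itself. It assumes a TGI server is reachable at `http://127.0.0.1:8080` and that `Client.generate` exposes the `grammar` argument documented in the README above; the URL, prompts, and penalty values are illustrative:

```python
# Sketch only: assumes a local TGI server at http://127.0.0.1:8080.
from text_generation import Client
from text_generation.types import Grammar, GrammarType

client = Client("http://127.0.0.1:8080")

# frequency_penalty lowers the score of tokens proportionally to how often
# they already appear; the validator added above requires a strictly
# positive value.
response = client.generate(
    "Write a short story about a robot:",
    max_new_tokens=64,
    frequency_penalty=1.1,
)
print(response.generated_text)

# A grammar constrains decoding; here a regex forces a single
# capitalized word.
response = client.generate(
    "What is the capital of France?",
    max_new_tokens=10,
    grammar=Grammar(type=GrammarType.Regex, value="[A-Z][a-z]+"),
)
print(response.generated_text)
```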
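Similarly, a sketch of `chat` with the new `repetition_penalty` argument. The `Message` model comes from `text_generation.types`; the OpenAI-style response access (`choices[0].message.content`) is an assumption about the non-streaming return shape in 0.7:

```python
# Sketch only: chat with the new repetition_penalty argument.
from text_generation import Client
from text_generation.types import Message

client = Client("http://127.0.0.1:8080")

complete = client.chat(
    messages=[Message(role="user", content="What is deep learning?")],
    repetition_penalty=1.2,  # 1.0 would mean no penalty
    max_tokens=64,
)
print(complete.choices[0].message.content)
```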