feat: update client to 0.7 (#1667)

Close #1652
This commit is contained in:
OlivierDehaene 2024-03-22 17:10:56 +01:00 committed by GitHub
parent deb440b3a2
commit 08e9181418
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 86 additions and 7 deletions

View File

@ -107,7 +107,19 @@ print(text)
### Types ### Types
```python ```python
# Request Parameters # enum for grammar type
class GrammarType(Enum):
Json = "json"
Regex = "regex"
# Grammar type and value
class Grammar:
# Grammar type
type: GrammarType
# Grammar value
value: Union[str, dict]
class Parameters: class Parameters:
# Activate logits sampling # Activate logits sampling
do_sample: bool do_sample: bool
@ -116,6 +128,10 @@ class Parameters:
# The parameter for repetition penalty. 1.0 means no penalty. # The parameter for repetition penalty. 1.0 means no penalty.
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
repetition_penalty: Optional[float] repetition_penalty: Optional[float]
# The parameter for frequency penalty. 1.0 means no penalty
# Penalize new tokens based on their existing frequency in the text so far,
# decreasing the model's likelihood to repeat the same line verbatim.
frequency_penalty: Optional[float]
# Whether to prepend the prompt to the generated text # Whether to prepend the prompt to the generated text
return_full_text: bool return_full_text: bool
# Stop generating tokens if a member of `stop_sequences` is generated # Stop generating tokens if a member of `stop_sequences` is generated
@ -138,10 +154,22 @@ class Parameters:
best_of: Optional[int] best_of: Optional[int]
# Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
watermark: bool watermark: bool
# Get generation details
details: bool
# Get decoder input token logprobs and ids # Get decoder input token logprobs and ids
decoder_input_details: bool decoder_input_details: bool
# Return the N most likely tokens at each step # Return the N most likely tokens at each step
top_n_tokens: Optional[int] top_n_tokens: Optional[int]
# grammar to use for generation
grammar: Optional[Grammar]
class Request:
# Prompt
inputs: str
# Generation parameters
parameters: Optional[Parameters]
# Whether to stream output tokens
stream: bool
# Decoder input tokens # Decoder input tokens
class InputToken: class InputToken:
@ -161,7 +189,7 @@ class Token:
# Token text # Token text
text: str text: str
# Logprob # Logprob
logprob: float logprob: Optional[float]
# Is the token a special token # Is the token a special token
# Can be used to ignore tokens when concatenating # Can be used to ignore tokens when concatenating
special: bool special: bool

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "text-generation" name = "text-generation"
version = "0.6.1" version = "0.7.0"
description = "Hugging Face Text Generation Python Client" description = "Hugging Face Text Generation Python Client"
license = "Apache-2.0" license = "Apache-2.0"
authors = ["Olivier Dehaene <olivier@huggingface.co>"] authors = ["Olivier Dehaene <olivier@huggingface.co>"]

View File

@ -67,6 +67,7 @@ class Client:
def chat( def chat(
self, self,
messages: List[Message], messages: List[Message],
repetition_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None, frequency_penalty: Optional[float] = None,
logit_bias: Optional[List[float]] = None, logit_bias: Optional[List[float]] = None,
logprobs: Optional[bool] = None, logprobs: Optional[bool] = None,
@ -87,9 +88,13 @@ class Client:
Args: Args:
messages (`List[Message]`): messages (`List[Message]`):
List of messages List of messages
frequency_penalty (`float`): repetition_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty. See [this The parameter for repetition penalty. 0.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
frequency_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty
Penalize new tokens based on their existing frequency in the text so far,
decreasing the model's likelihood to repeat the same line verbatim.
logit_bias (`List[float]`): logit_bias (`List[float]`):
Adjust the likelihood of specified tokens Adjust the likelihood of specified tokens
logprobs (`bool`): logprobs (`bool`):
@ -121,6 +126,7 @@ class Client:
request = ChatRequest( request = ChatRequest(
model="tgi", model="tgi",
messages=messages, messages=messages,
repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty, frequency_penalty=frequency_penalty,
logit_bias=logit_bias, logit_bias=logit_bias,
logprobs=logprobs, logprobs=logprobs,
@ -179,6 +185,7 @@ class Client:
max_new_tokens: int = 20, max_new_tokens: int = 20,
best_of: Optional[int] = None, best_of: Optional[int] = None,
repetition_penalty: Optional[float] = None, repetition_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
return_full_text: bool = False, return_full_text: bool = False,
seed: Optional[int] = None, seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None, stop_sequences: Optional[List[str]] = None,
@ -207,6 +214,10 @@ class Client:
repetition_penalty (`float`): repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
frequency_penalty (`float`):
The parameter for frequency penalty. 1.0 means no penalty
Penalize new tokens based on their existing frequency in the text so far,
decreasing the model's likelihood to repeat the same line verbatim.
return_full_text (`bool`): return_full_text (`bool`):
Whether to prepend the prompt to the generated text Whether to prepend the prompt to the generated text
seed (`int`): seed (`int`):
@ -245,6 +256,7 @@ class Client:
do_sample=do_sample, do_sample=do_sample,
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty, repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty,
return_full_text=return_full_text, return_full_text=return_full_text,
seed=seed, seed=seed,
stop=stop_sequences if stop_sequences is not None else [], stop=stop_sequences if stop_sequences is not None else [],
@ -278,6 +290,7 @@ class Client:
do_sample: bool = False, do_sample: bool = False,
max_new_tokens: int = 20, max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None, repetition_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
return_full_text: bool = False, return_full_text: bool = False,
seed: Optional[int] = None, seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None, stop_sequences: Optional[List[str]] = None,
@ -303,6 +316,10 @@ class Client:
repetition_penalty (`float`): repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
frequency_penalty (`float`):
The parameter for frequency penalty. 1.0 means no penalty
Penalize new tokens based on their existing frequency in the text so far,
decreasing the model's likelihood to repeat the same line verbatim.
return_full_text (`bool`): return_full_text (`bool`):
Whether to prepend the prompt to the generated text Whether to prepend the prompt to the generated text
seed (`int`): seed (`int`):
@ -340,6 +357,7 @@ class Client:
do_sample=do_sample, do_sample=do_sample,
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty, repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty,
return_full_text=return_full_text, return_full_text=return_full_text,
seed=seed, seed=seed,
stop=stop_sequences if stop_sequences is not None else [], stop=stop_sequences if stop_sequences is not None else [],
@ -435,6 +453,7 @@ class AsyncClient:
async def chat( async def chat(
self, self,
messages: List[Message], messages: List[Message],
repetition_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None, frequency_penalty: Optional[float] = None,
logit_bias: Optional[List[float]] = None, logit_bias: Optional[List[float]] = None,
logprobs: Optional[bool] = None, logprobs: Optional[bool] = None,
@ -455,9 +474,13 @@ class AsyncClient:
Args: Args:
messages (`List[Message]`): messages (`List[Message]`):
List of messages List of messages
frequency_penalty (`float`): repetition_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty. See [this The parameter for frequency penalty. 0.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
frequency_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty
Penalize new tokens based on their existing frequency in the text so far,
decreasing the model's likelihood to repeat the same line verbatim.
logit_bias (`List[float]`): logit_bias (`List[float]`):
Adjust the likelihood of specified tokens Adjust the likelihood of specified tokens
logprobs (`bool`): logprobs (`bool`):
@ -489,6 +512,7 @@ class AsyncClient:
request = ChatRequest( request = ChatRequest(
model="tgi", model="tgi",
messages=messages, messages=messages,
repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty, frequency_penalty=frequency_penalty,
logit_bias=logit_bias, logit_bias=logit_bias,
logprobs=logprobs, logprobs=logprobs,
@ -546,6 +570,7 @@ class AsyncClient:
max_new_tokens: int = 20, max_new_tokens: int = 20,
best_of: Optional[int] = None, best_of: Optional[int] = None,
repetition_penalty: Optional[float] = None, repetition_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
return_full_text: bool = False, return_full_text: bool = False,
seed: Optional[int] = None, seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None, stop_sequences: Optional[List[str]] = None,
@ -574,6 +599,10 @@ class AsyncClient:
repetition_penalty (`float`): repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
frequency_penalty (`float`):
The parameter for frequency penalty. 1.0 means no penalty
Penalize new tokens based on their existing frequency in the text so far,
decreasing the model's likelihood to repeat the same line verbatim.
return_full_text (`bool`): return_full_text (`bool`):
Whether to prepend the prompt to the generated text Whether to prepend the prompt to the generated text
seed (`int`): seed (`int`):
@ -614,6 +643,7 @@ class AsyncClient:
do_sample=do_sample, do_sample=do_sample,
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty, repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty,
return_full_text=return_full_text, return_full_text=return_full_text,
seed=seed, seed=seed,
stop=stop_sequences if stop_sequences is not None else [], stop=stop_sequences if stop_sequences is not None else [],
@ -644,6 +674,7 @@ class AsyncClient:
do_sample: bool = False, do_sample: bool = False,
max_new_tokens: int = 20, max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None, repetition_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
return_full_text: bool = False, return_full_text: bool = False,
seed: Optional[int] = None, seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None, stop_sequences: Optional[List[str]] = None,
@ -669,6 +700,10 @@ class AsyncClient:
repetition_penalty (`float`): repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
frequency_penalty (`float`):
The parameter for frequency penalty. 1.0 means no penalty
Penalize new tokens based on their existing frequency in the text so far,
decreasing the model's likelihood to repeat the same line verbatim.
return_full_text (`bool`): return_full_text (`bool`):
Whether to prepend the prompt to the generated text Whether to prepend the prompt to the generated text
seed (`int`): seed (`int`):
@ -706,6 +741,7 @@ class AsyncClient:
do_sample=do_sample, do_sample=do_sample,
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty, repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty,
return_full_text=return_full_text, return_full_text=return_full_text,
seed=seed, seed=seed,
stop=stop_sequences if stop_sequences is not None else [], stop=stop_sequences if stop_sequences is not None else [],

View File

@ -109,7 +109,12 @@ class ChatRequest(BaseModel):
model: str model: str
# List of messages in the conversation # List of messages in the conversation
messages: List[Message] messages: List[Message]
# Penalty for frequency of new tokens # The parameter for repetition penalty. 1.0 means no penalty.
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
repetition_penalty: Optional[float] = None
# The parameter for frequency penalty. 1.0 means no penalty
# Penalize new tokens based on their existing frequency in the text so far,
# decreasing the model's likelihood to repeat the same line verbatim.
frequency_penalty: Optional[float] = None frequency_penalty: Optional[float] = None
# Bias values for token selection # Bias values for token selection
logit_bias: Optional[List[float]] = None logit_bias: Optional[List[float]] = None
@ -145,6 +150,10 @@ class Parameters(BaseModel):
# The parameter for repetition penalty. 1.0 means no penalty. # The parameter for repetition penalty. 1.0 means no penalty.
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
repetition_penalty: Optional[float] = None repetition_penalty: Optional[float] = None
# The parameter for frequency penalty. 1.0 means no penalty
# Penalize new tokens based on their existing frequency in the text so far,
# decreasing the model's likelihood to repeat the same line verbatim.
frequency_penalty: Optional[float] = None
# Whether to prepend the prompt to the generated text # Whether to prepend the prompt to the generated text
return_full_text: bool = False return_full_text: bool = False
# Stop generating tokens if a member of `stop_sequences` is generated # Stop generating tokens if a member of `stop_sequences` is generated
@ -201,6 +210,12 @@ class Parameters(BaseModel):
raise ValidationError("`repetition_penalty` must be strictly positive") raise ValidationError("`repetition_penalty` must be strictly positive")
return v return v
@field_validator("frequency_penalty")
def valid_frequency_penalty(cls, v):
if v is not None and v <= 0:
raise ValidationError("`frequency_penalty` must be strictly positive")
return v
@field_validator("seed") @field_validator("seed")
def valid_seed(cls, v): def valid_seed(cls, v):
if v is not None and v < 0: if v is not None and v < 0: