feat: update client to 0.7 (#1667)

Close #1652
OlivierDehaene 2024-03-22 17:10:56 +01:00 committed by GitHub
parent deb440b3a2
commit 08e9181418
4 changed files with 86 additions and 7 deletions

View File

@@ -107,7 +107,19 @@ print(text)
### Types
```python
# Request Parameters
# enum for grammar type
class GrammarType(Enum):
Json = "json"
Regex = "regex"
# Grammar type and value
class Grammar:
# Grammar type
type: GrammarType
# Grammar value
value: Union[str, dict]
class Parameters:
# Activate logits sampling
do_sample: bool
@@ -116,6 +128,10 @@ class Parameters:
# The parameter for repetition penalty. 1.0 means no penalty.
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
repetition_penalty: Optional[float]
# The parameter for frequency penalty. 0.0 means no penalty.
# Penalize new tokens based on their existing frequency in the text so far,
# decreasing the model's likelihood to repeat the same line verbatim.
frequency_penalty: Optional[float]
# Whether to prepend the prompt to the generated text
return_full_text: bool
# Stop generating tokens if a member of `stop_sequences` is generated
@@ -138,10 +154,22 @@ class Parameters:
best_of: Optional[int]
# Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
watermark: bool
# Get generation details
details: bool
# Get decoder input token logprobs and ids
decoder_input_details: bool
# Return the N most likely tokens at each step
top_n_tokens: Optional[int]
# grammar to use for generation
grammar: Optional[Grammar]
class Request:
# Prompt
inputs: str
# Generation parameters
parameters: Optional[Parameters]
# Whether to stream output tokens
stream: bool
# Decoder input tokens
class InputToken:
@@ -161,7 +189,7 @@ class Token:
# Token text
text: str
# Logprob
logprob: float
logprob: Optional[float]
# Is the token a special token
# Can be used to ignore tokens when concatenating
special: bool
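For context (not part of the diff): a minimal sketch of the new `Grammar` parameter in use, assuming the 0.7 client exposes `grammar` on `generate` as documented above and a grammar-capable TGI server is running; the endpoint URL and JSON schema are illustrative.

```python
from text_generation import Client
from text_generation.types import Grammar, GrammarType

client = Client("http://127.0.0.1:8080")  # illustrative endpoint

# Illustrative JSON schema the generated text must conform to
schema = {
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
    },
    "required": ["name", "age"],
}

response = client.generate(
    "Extract the fields: David is 42 years old.",
    max_new_tokens=64,
    grammar=Grammar(type=GrammarType.Json, value=schema),
)
print(response.generated_text)  # e.g. {"name": "David", "age": 42}
```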

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation"
version = "0.6.1"
version = "0.7.0"
description = "Hugging Face Text Generation Python Client"
license = "Apache-2.0"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
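Existing installs pick up the bump with `pip install --upgrade text-generation`; a quick sanity check against the distribution metadata (the distribution name is the `[tool.poetry]` name above):

```python
from importlib.metadata import version

# "text-generation" is the distribution name declared in pyproject.toml
print(version("text-generation"))  # expected: 0.7.0
```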

View File

@@ -67,6 +67,7 @@ class Client:
def chat(
self,
messages: List[Message],
repetition_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[List[float]] = None,
logprobs: Optional[bool] = None,
@@ -87,9 +88,13 @@
Args:
messages (`List[Message]`):
List of messages
frequency_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty. See [this
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
frequency_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty.
Penalize new tokens based on their existing frequency in the text so far,
decreasing the model's likelihood to repeat the same line verbatim.
logit_bias (`List[float]`):
Adjust the likelihood of specified tokens
logprobs (`bool`):
@@ -121,6 +126,7 @@ class Client:
request = ChatRequest(
model="tgi",
messages=messages,
repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
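To illustrate the two penalties side by side: a sketch, assuming a server at an illustrative `http://127.0.0.1:8080`; the response shape follows the `ChatComplete` model in `types.py`.

```python
from text_generation import Client
from text_generation.types import Message

client = Client("http://127.0.0.1:8080")  # illustrative endpoint

response = client.chat(
    messages=[Message(role="user", content="Name three prime numbers.")],
    repetition_penalty=1.1,  # multiplicative; 1.0 disables it
    frequency_penalty=0.5,   # additive on token frequency; 0.0 disables it
)
print(response.choices[0].message.content)
```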
@@ -179,6 +185,7 @@ class Client:
max_new_tokens: int = 20,
best_of: Optional[int] = None,
repetition_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
@@ -207,6 +214,10 @@ class Client:
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
frequency_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty.
Penalize new tokens based on their existing frequency in the text so far,
decreasing the model's likelihood to repeat the same line verbatim.
return_full_text (`bool`):
Whether to prepend the prompt to the generated text
seed (`int`):
@@ -245,6 +256,7 @@ class Client:
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop_sequences if stop_sequences is not None else [],
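The same knob on plain generation; a minimal sketch under the same assumed endpoint:

```python
from text_generation import Client

client = Client("http://127.0.0.1:8080")  # illustrative endpoint

response = client.generate(
    "The best thing about Rust is",
    max_new_tokens=40,
    repetition_penalty=1.1,
    frequency_penalty=0.7,  # discourage verbatim repetition
)
print(response.generated_text)
```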
@@ -278,6 +290,7 @@ class Client:
do_sample: bool = False,
max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
@@ -303,6 +316,10 @@ class Client:
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
frequency_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty.
Penalize new tokens based on their existing frequency in the text so far,
decreasing the model's likelihood to repeat the same line verbatim.
return_full_text (`bool`):
Whether to prepend the prompt to the generated text
seed (`int`):
@@ -340,6 +357,7 @@ class Client:
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop_sequences if stop_sequences is not None else [],
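And the streaming variant: tokens arrive as `StreamResponse` objects whose `token` field matches the `Token` type documented in the README (sketch, same assumed endpoint):

```python
from text_generation import Client

client = Client("http://127.0.0.1:8080")  # illustrative endpoint

text = ""
for event in client.generate_stream(
    "Write a haiku about caching.",
    max_new_tokens=48,
    frequency_penalty=0.5,
):
    if not event.token.special:  # skip special tokens when concatenating
        text += event.token.text
print(text)
```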
@@ -435,6 +453,7 @@ class AsyncClient:
async def chat(
self,
messages: List[Message],
repetition_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[List[float]] = None,
logprobs: Optional[bool] = None,
@@ -455,9 +474,13 @@ class AsyncClient:
Args:
messages (`List[Message]`):
List of messages
frequency_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty. See [this
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
frequency_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty.
Penalize new tokens based on their existing frequency in the text so far,
decreasing the model's likelihood to repeat the same line verbatim.
logit_bias (`List[float]`):
Adjust the likelihood of specified tokens
logprobs (`bool`):
@@ -489,6 +512,7 @@ class AsyncClient:
request = ChatRequest(
model="tgi",
messages=messages,
repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
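`AsyncClient` mirrors the sync API; a sketch of the chat call with both penalties, under the same assumed endpoint:

```python
import asyncio

from text_generation import AsyncClient
from text_generation.types import Message

async def main():
    client = AsyncClient("http://127.0.0.1:8080")  # illustrative endpoint
    response = await client.chat(
        messages=[Message(role="user", content="Summarize this commit in one line.")],
        repetition_penalty=1.1,
        frequency_penalty=0.5,
    )
    print(response.choices[0].message.content)

asyncio.run(main())
```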
@@ -546,6 +570,7 @@ class AsyncClient:
max_new_tokens: int = 20,
best_of: Optional[int] = None,
repetition_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
@@ -574,6 +599,10 @@ class AsyncClient:
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
frequency_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty.
Penalize new tokens based on their existing frequency in the text so far,
decreasing the model's likelihood to repeat the same line verbatim.
return_full_text (`bool`):
Whether to prepend the prompt to the generated text
seed (`int`):
@@ -614,6 +643,7 @@ class AsyncClient:
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop_sequences if stop_sequences is not None else [],
@@ -644,6 +674,7 @@ class AsyncClient:
do_sample: bool = False,
max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
@@ -669,6 +700,10 @@ class AsyncClient:
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
frequency_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty.
Penalize new tokens based on their existing frequency in the text so far,
decreasing the model's likelihood to repeat the same line verbatim.
return_full_text (`bool`):
Whether to prepend the prompt to the generated text
seed (`int`):
@@ -706,6 +741,7 @@ class AsyncClient:
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop_sequences if stop_sequences is not None else [],
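Async streaming iterates the same event shape with `async for` (sketch, same assumed endpoint):

```python
import asyncio

from text_generation import AsyncClient

async def main():
    client = AsyncClient("http://127.0.0.1:8080")  # illustrative endpoint
    async for event in client.generate_stream(
        "Once upon a time,",
        max_new_tokens=32,
        frequency_penalty=0.5,
    ):
        if not event.token.special:
            print(event.token.text, end="", flush=True)
    print()

asyncio.run(main())
```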

View File

@@ -109,7 +109,12 @@ class ChatRequest(BaseModel):
model: str
# List of messages in the conversation
messages: List[Message]
# Penalty for frequency of new tokens
# The parameter for repetition penalty. 1.0 means no penalty.
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
repetition_penalty: Optional[float] = None
# The parameter for frequency penalty. 0.0 means no penalty.
# Penalize new tokens based on their existing frequency in the text so far,
# decreasing the model's likelihood to repeat the same line verbatim.
frequency_penalty: Optional[float] = None
# Bias values for token selection
logit_bias: Optional[List[float]] = None
@@ -145,6 +150,10 @@ class Parameters(BaseModel):
# The parameter for repetition penalty. 1.0 means no penalty.
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
repetition_penalty: Optional[float] = None
# The parameter for frequency penalty. 0.0 means no penalty.
# Penalize new tokens based on their existing frequency in the text so far,
# decreasing the model's likelihood to repeat the same line verbatim.
frequency_penalty: Optional[float] = None
# Whether to prepend the prompt to the generated text
return_full_text: bool = False
# Stop generating tokens if a member of `stop_sequences` is generated
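The new fields can be exercised locally without a server by building the pydantic models directly; a sketch (both penalties default to `None`, meaning no penalty):

```python
from text_generation.types import ChatRequest, Message, Parameters

params = Parameters(repetition_penalty=1.1, frequency_penalty=0.5)
print(params.frequency_penalty)  # 0.5

request = ChatRequest(
    model="tgi",
    messages=[Message(role="user", content="Hello!")],
    repetition_penalty=1.1,
    frequency_penalty=0.5,
)
print(request.model_dump_json(exclude_none=True))
```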
@@ -201,6 +210,12 @@ class Parameters(BaseModel):
raise ValidationError("`repetition_penalty` must be strictly positive")
return v
@field_validator("frequency_penalty")
def valid_frequency_penalty(cls, v):
if v is not None and v <= 0:
raise ValidationError("`frequency_penalty` must be strictly positive")
return v
@field_validator("seed")
def valid_seed(cls, v):
if v is not None and v < 0: