parent deb440b3a2 · commit 08e9181418
@@ -107,7 +107,19 @@ print(text)
 ### Types
 
 ```python
 # Request Parameters
+# enum for grammar type
+class GrammarType(Enum):
+    Json = "json"
+    Regex = "regex"
+
+
+# Grammar type and value
+class Grammar:
+    # Grammar type
+    type: GrammarType
+    # Grammar value
+    value: Union[str, dict]
+
 class Parameters:
     # Activate logits sampling
     do_sample: bool
@@ -116,6 +128,10 @@ class Parameters:
     # The parameter for repetition penalty. 1.0 means no penalty.
     # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
     repetition_penalty: Optional[float]
+    # The parameter for frequency penalty. 0.0 means no penalty.
+    # Penalize new tokens based on their existing frequency in the text so far,
+    # decreasing the model's likelihood to repeat the same line verbatim.
+    frequency_penalty: Optional[float]
     # Whether to prepend the prompt to the generated text
     return_full_text: bool
     # Stop generating tokens if a member of `stop_sequences` is generated
@@ -138,10 +154,22 @@ class Parameters:
     best_of: Optional[int]
     # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
     watermark: bool
     # Get generation details
     details: bool
     # Get decoder input token logprobs and ids
     decoder_input_details: bool
+    # Return the N most likely tokens at each step
+    top_n_tokens: Optional[int]
+    # grammar to use for generation
+    grammar: Optional[Grammar]
 
 class Request:
     # Prompt
     inputs: str
     # Generation parameters
     parameters: Optional[Parameters]
     # Whether to stream output tokens
     stream: bool
 
 # Decoder input tokens
 class InputToken:
@@ -161,7 +189,7 @@ class Token:
     # Token text
     text: str
     # Logprob
-    logprob: float
+    logprob: Optional[float]
     # Is the token a special token
     # Can be used to ignore tokens when concatenating
     special: bool
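Together, the new `GrammarType` and `Grammar` types let `Parameters` carry a decoding constraint. A minimal sketch of constructing them, assuming the types are importable from `text_generation.types`; the JSON schema and values below are made-up examples, not part of this diff:

```python
# Sketch: constructing the new Grammar / Parameters types.
from text_generation.types import Grammar, GrammarType, Parameters

# A dict value pairs with GrammarType.Json; a pattern string pairs
# with GrammarType.Regex (value is Union[str, dict]).
schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}

json_grammar = Grammar(type=GrammarType.Json, value=schema)
regex_grammar = Grammar(type=GrammarType.Regex, value=r"\d{4}-\d{2}-\d{2}")

# Parameters now carries an optional grammar alongside the sampling knobs.
params = Parameters(max_new_tokens=64, grammar=json_grammar)
```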
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation"
-version = "0.6.1"
+version = "0.7.0"
 description = "Hugging Face Text Generation Python Client"
 license = "Apache-2.0"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]
@@ -67,6 +67,7 @@ class Client:
     def chat(
         self,
         messages: List[Message],
+        repetition_penalty: Optional[float] = None,
         frequency_penalty: Optional[float] = None,
         logit_bias: Optional[List[float]] = None,
         logprobs: Optional[bool] = None,
@@ -87,9 +88,13 @@ class Client:
         Args:
             messages (`List[Message]`):
                 List of messages
-            frequency_penalty (`float`):
-                The parameter for frequency penalty. 0.0 means no penalty. See [this
+            repetition_penalty (`float`):
+                The parameter for repetition penalty. 1.0 means no penalty. See [this
                 paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+            frequency_penalty (`float`):
+                The parameter for frequency penalty. 0.0 means no penalty.
+                Penalize new tokens based on their existing frequency in the text so far,
+                decreasing the model's likelihood to repeat the same line verbatim.
             logit_bias (`List[float]`):
                 Adjust the likelihood of specified tokens
             logprobs (`bool`):
@@ -121,6 +126,7 @@ class Client:
         request = ChatRequest(
             model="tgi",
             messages=messages,
+            repetition_penalty=repetition_penalty,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            logprobs=logprobs,
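With both penalties now plumbed through `chat()`, a call looks like the sketch below; the endpoint URL and message content are placeholders, and the `choices` access assumes the client's OpenAI-style chat response type:

```python
# Sketch: chat() with the new penalty arguments (placeholder endpoint).
from text_generation import Client
from text_generation.types import Message

client = Client("http://127.0.0.1:8080")
response = client.chat(
    messages=[Message(role="user", content="Name three prime numbers.")],
    repetition_penalty=1.1,  # multiplicative; 1.0 disables it
    frequency_penalty=0.5,   # additive, grows with a token's frequency so far
)
print(response.choices[0].message.content)
```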
@@ -179,6 +185,7 @@ class Client:
         max_new_tokens: int = 20,
         best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
+        frequency_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
         stop_sequences: Optional[List[str]] = None,
@@ -207,6 +214,10 @@ class Client:
             repetition_penalty (`float`):
                 The parameter for repetition penalty. 1.0 means no penalty. See [this
                 paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+            frequency_penalty (`float`):
+                The parameter for frequency penalty. 0.0 means no penalty.
+                Penalize new tokens based on their existing frequency in the text so far,
+                decreasing the model's likelihood to repeat the same line verbatim.
             return_full_text (`bool`):
                 Whether to prepend the prompt to the generated text
             seed (`int`):
@@ -245,6 +256,7 @@ class Client:
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
+            frequency_penalty=frequency_penalty,
             return_full_text=return_full_text,
             seed=seed,
             stop=stop_sequences if stop_sequences is not None else [],
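`generate()` forwards the same pair of penalties; a short sketch with a placeholder endpoint:

```python
# Sketch: generate() with both penalties (placeholder endpoint).
from text_generation import Client

client = Client("http://127.0.0.1:8080")
out = client.generate(
    "Once upon a time,",
    max_new_tokens=32,
    repetition_penalty=1.2,  # >1.0 damps tokens that already appeared
    frequency_penalty=0.8,   # scales with how often a token has appeared
)
print(out.generated_text)
```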
@@ -278,6 +290,7 @@ class Client:
         do_sample: bool = False,
         max_new_tokens: int = 20,
         repetition_penalty: Optional[float] = None,
+        frequency_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
         stop_sequences: Optional[List[str]] = None,
@@ -303,6 +316,10 @@ class Client:
             repetition_penalty (`float`):
                 The parameter for repetition penalty. 1.0 means no penalty. See [this
                 paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+            frequency_penalty (`float`):
+                The parameter for frequency penalty. 0.0 means no penalty.
+                Penalize new tokens based on their existing frequency in the text so far,
+                decreasing the model's likelihood to repeat the same line verbatim.
             return_full_text (`bool`):
                 Whether to prepend the prompt to the generated text
             seed (`int`):
@@ -340,6 +357,7 @@ class Client:
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
+            frequency_penalty=frequency_penalty,
             return_full_text=return_full_text,
             seed=seed,
             stop=stop_sequences if stop_sequences is not None else [],
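`generate_stream()` takes the same arguments and yields one token per event; this sketch reassembles the text while skipping special tokens, using the `special` flag from the `Token` type above (endpoint is a placeholder):

```python
# Sketch: streaming with frequency_penalty (placeholder endpoint).
from text_generation import Client

client = Client("http://127.0.0.1:8080")
text = ""
for response in client.generate_stream(
    "Write one sentence about the sea.",
    max_new_tokens=32,
    frequency_penalty=0.5,
):
    # Special tokens are skipped when reassembling the text.
    if not response.token.special:
        text += response.token.text
print(text)
```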
@@ -435,6 +453,7 @@ class AsyncClient:
     async def chat(
         self,
         messages: List[Message],
+        repetition_penalty: Optional[float] = None,
         frequency_penalty: Optional[float] = None,
         logit_bias: Optional[List[float]] = None,
         logprobs: Optional[bool] = None,
@@ -455,9 +474,13 @@ class AsyncClient:
         Args:
             messages (`List[Message]`):
                 List of messages
-            frequency_penalty (`float`):
-                The parameter for frequency penalty. 0.0 means no penalty. See [this
+            repetition_penalty (`float`):
+                The parameter for repetition penalty. 1.0 means no penalty. See [this
                 paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+            frequency_penalty (`float`):
+                The parameter for frequency penalty. 0.0 means no penalty.
+                Penalize new tokens based on their existing frequency in the text so far,
+                decreasing the model's likelihood to repeat the same line verbatim.
             logit_bias (`List[float]`):
                 Adjust the likelihood of specified tokens
             logprobs (`bool`):
@@ -489,6 +512,7 @@ class AsyncClient:
         request = ChatRequest(
             model="tgi",
             messages=messages,
+            repetition_penalty=repetition_penalty,
             frequency_penalty=frequency_penalty,
             logit_bias=logit_bias,
             logprobs=logprobs,
@@ -546,6 +570,7 @@ class AsyncClient:
         max_new_tokens: int = 20,
         best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
+        frequency_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
         stop_sequences: Optional[List[str]] = None,
@@ -574,6 +599,10 @@ class AsyncClient:
             repetition_penalty (`float`):
                 The parameter for repetition penalty. 1.0 means no penalty. See [this
                 paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+            frequency_penalty (`float`):
+                The parameter for frequency penalty. 0.0 means no penalty.
+                Penalize new tokens based on their existing frequency in the text so far,
+                decreasing the model's likelihood to repeat the same line verbatim.
             return_full_text (`bool`):
                 Whether to prepend the prompt to the generated text
             seed (`int`):
@@ -614,6 +643,7 @@ class AsyncClient:
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
+            frequency_penalty=frequency_penalty,
             return_full_text=return_full_text,
             seed=seed,
             stop=stop_sequences if stop_sequences is not None else [],
@@ -644,6 +674,7 @@ class AsyncClient:
         do_sample: bool = False,
         max_new_tokens: int = 20,
         repetition_penalty: Optional[float] = None,
+        frequency_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
         stop_sequences: Optional[List[str]] = None,
@@ -669,6 +700,10 @@ class AsyncClient:
             repetition_penalty (`float`):
                 The parameter for repetition penalty. 1.0 means no penalty. See [this
                 paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+            frequency_penalty (`float`):
+                The parameter for frequency penalty. 0.0 means no penalty.
+                Penalize new tokens based on their existing frequency in the text so far,
+                decreasing the model's likelihood to repeat the same line verbatim.
             return_full_text (`bool`):
                 Whether to prepend the prompt to the generated text
             seed (`int`):
@@ -706,6 +741,7 @@ class AsyncClient:
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
+            frequency_penalty=frequency_penalty,
             return_full_text=return_full_text,
             seed=seed,
             stop=stop_sequences if stop_sequences is not None else [],
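`AsyncClient` mirrors the synchronous surface; a minimal asyncio sketch with a placeholder endpoint:

```python
# Sketch: the same call through AsyncClient (placeholder endpoint).
import asyncio

from text_generation import AsyncClient


async def main() -> None:
    client = AsyncClient("http://127.0.0.1:8080")
    response = await client.generate(
        "What is asynchronous I/O?",
        max_new_tokens=32,
        frequency_penalty=0.3,
    )
    print(response.generated_text)


asyncio.run(main())
```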
@@ -109,7 +109,12 @@ class ChatRequest(BaseModel):
     model: str
     # List of messages in the conversation
     messages: List[Message]
-    # Penalty for frequency of new tokens
+    # The parameter for repetition penalty. 1.0 means no penalty.
+    # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+    repetition_penalty: Optional[float] = None
+    # The parameter for frequency penalty. 0.0 means no penalty.
+    # Penalize new tokens based on their existing frequency in the text so far,
+    # decreasing the model's likelihood to repeat the same line verbatim.
     frequency_penalty: Optional[float] = None
     # Bias values for token selection
     logit_bias: Optional[List[float]] = None
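At the types layer, `ChatRequest` now carries both penalties. `Client.chat()` builds the request internally; constructing one directly is only a sketch of the new fields (and assumes pydantic v2's `model_dump`):

```python
# Sketch: building a ChatRequest directly; Client.chat() normally does this.
from text_generation.types import ChatRequest, Message

request = ChatRequest(
    model="tgi",
    messages=[Message(role="user", content="Hello!")],
    repetition_penalty=1.1,
    frequency_penalty=0.25,
)
# Unset optional fields are dropped from the serialized payload.
print(request.model_dump(exclude_none=True))
```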
@@ -145,6 +150,10 @@ class Parameters(BaseModel):
     # The parameter for repetition penalty. 1.0 means no penalty.
     # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
     repetition_penalty: Optional[float] = None
+    # The parameter for frequency penalty. 0.0 means no penalty.
+    # Penalize new tokens based on their existing frequency in the text so far,
+    # decreasing the model's likelihood to repeat the same line verbatim.
+    frequency_penalty: Optional[float] = None
     # Whether to prepend the prompt to the generated text
     return_full_text: bool = False
     # Stop generating tokens if a member of `stop_sequences` is generated
@@ -201,6 +210,12 @@ class Parameters(BaseModel):
             raise ValidationError("`repetition_penalty` must be strictly positive")
         return v
 
+    @field_validator("frequency_penalty")
+    def valid_frequency_penalty(cls, v):
+        if v is not None and v <= 0:
+            raise ValidationError("`frequency_penalty` must be strictly positive")
+        return v
+
     @field_validator("seed")
     def valid_seed(cls, v):
         if v is not None and v < 0:
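Since these are pydantic field validators, they run when `Parameters` is instantiated, so out-of-range values fail at construction time; a small sketch:

```python
# Sketch: the validator runs at construction time, so bad values fail fast.
from text_generation.types import Parameters

ok = Parameters(frequency_penalty=0.5)  # strictly positive: accepted

try:
    Parameters(frequency_penalty=-1.0)  # <= 0: rejected by the validator
except Exception as err:
    print(f"rejected: {err}")
```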