
234 lines
7.2 KiB

from enum import Enum
from pydantic import BaseModel, validator
from typing import Optional, List
from text_generation.errors import ValidationError
class Parameters(BaseModel):
    # Generation parameters attached to a `Request`.
    # Optional fields default to None (unset). The validators below reject
    # out-of-range values by raising `ValidationError`.

    # Activate logits sampling
    do_sample: bool = False
    # Maximum number of generated tokens
    max_new_tokens: int = 20
    # The parameter for repetition penalty. 1.0 means no penalty.
    # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
    repetition_penalty: Optional[float] = None
    # Whether to prepend the prompt to the generated text
    return_full_text: bool = False
    # Stop generating tokens if a member of `stop_sequences` is generated
    stop: List[str] = []
    # Random sampling seed
    seed: Optional[int] = None
    # The value used to modulate the logits distribution.
    temperature: Optional[float] = None
    # The number of highest probability vocabulary tokens to keep for top-k-filtering.
    top_k: Optional[int] = None
    # If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
    # higher are kept for generation.
    top_p: Optional[float] = None
    # Truncate input tokens to the given size
    truncate: Optional[int] = None
    # Typical Decoding mass
    # See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
    typical_p: Optional[float] = None
    # Generate best_of sequences and return the one with the highest token logprobs
    best_of: Optional[int] = None
    # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
    watermark: bool = False
    # Get generation details
    details: bool = False
    # Get decoder input token logprobs and ids
    decoder_input_details: bool = False

    @validator("best_of")
    def valid_best_of(cls, field_value, values):
        # Checks `best_of` on its own and against the fields declared before it.
        # `values` only holds earlier fields that passed validation; a field
        # that failed is absent, so `.get` is used instead of indexing to avoid
        # masking the real error with a KeyError.
        if field_value is None:
            return field_value
        if field_value <= 0:
            raise ValidationError("`best_of` must be strictly positive")
        if field_value > 1:
            if values.get("seed") is not None:
                raise ValidationError("`seed` must not be set when `best_of` is > 1")
            # Sampling counts as enabled when `do_sample` is set or when any
            # sampling-related parameter has been given a value.
            sampling = values.get("do_sample", False) or any(
                values.get(name) is not None
                for name in ("temperature", "top_k", "top_p", "typical_p")
            )
            if not sampling:
                raise ValidationError("you must use sampling when `best_of` is > 1")
        return field_value

    @validator("repetition_penalty")
    def valid_repetition_penalty(cls, v):
        # Rejects zero and negative penalties; None (unset) passes through.
        if v is not None and v <= 0:
            raise ValidationError("`repetition_penalty` must be strictly positive")
        return v

    @validator("seed")
    def valid_seed(cls, v):
        # Rejects negative seeds; zero is accepted.
        if v is not None and v < 0:
            raise ValidationError("`seed` must be positive")
        return v

    @validator("temperature")
    def valid_temp(cls, v):
        # Rejects zero and negative temperatures.
        if v is not None and v <= 0:
            raise ValidationError("`temperature` must be strictly positive")
        return v

    @validator("top_k")
    def valid_top_k(cls, v):
        # Rejects zero and negative `top_k`.
        if v is not None and v <= 0:
            raise ValidationError("`top_k` must be strictly positive")
        return v

    @validator("top_p")
    def valid_top_p(cls, v):
        # `top_p` must lie in the open interval (0, 1).
        if v is not None and (v <= 0 or v >= 1.0):
            raise ValidationError("`top_p` must be > 0.0 and < 1.0")
        return v

    @validator("truncate")
    def valid_truncate(cls, v):
        # Rejects zero and negative truncation sizes.
        if v is not None and v <= 0:
            raise ValidationError("`truncate` must be strictly positive")
        return v

    @validator("typical_p")
    def valid_typical_p(cls, v):
        # `typical_p` must lie in the open interval (0, 1).
        if v is not None and (v <= 0 or v >= 1.0):
            raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
        return v
class Request(BaseModel):
    # Payload of a generation call: the prompt, optional generation
    # parameters, and whether output tokens are streamed.

    # Prompt
    inputs: str
    # Generation parameters
    parameters: Optional[Parameters] = None
    # Whether to stream output tokens
    stream: bool = False

    @validator("inputs")
    def valid_input(cls, v):
        # An empty prompt is rejected.
        if not v:
            raise ValidationError("`inputs` cannot be empty")
        return v

    @validator("stream")
    def valid_best_of_stream(cls, field_value, values):
        # Streaming cannot be combined with `best_of` > 1.
        # `parameters` is absent from `values` if it failed its own validation,
        # so `.get` is used to avoid raising a KeyError on top of that error.
        parameters = values.get("parameters")
        if (
            parameters is not None
            and parameters.best_of is not None
            and parameters.best_of > 1
            and field_value
        ):
            raise ValidationError(
                "`best_of` != 1 is not supported when `stream` == True"
            )
        return field_value
# Decoder input tokens
class InputToken(BaseModel):
    # Token ID from the model tokenizer
    id: int
    # Token text
    text: str
    # Logprob
    # Optional since the logprob of the first token cannot be computed
    logprob: Optional[float] = None
# Generated tokens
class Token(BaseModel):
    # Token ID from the model tokenizer
    id: int
    # Token text
    text: str
    # Logprob
    logprob: float
    # Is the token a special token
    # Can be used to ignore tokens when concatenating
    special: bool
# Generation finish reason
class FinishReason(str, Enum):
    # Mixing in `str` makes each member compare equal to its wire value.

    # number of generated tokens == `max_new_tokens`
    Length = "length"
    # the model generated its end of sequence token
    EndOfSequenceToken = "eos_token"
    # the model generated a text included in `stop_sequences`
    StopSequence = "stop_sequence"
# Additional sequences when using the `best_of` parameter
class BestOfSequence(BaseModel):
    # Generated text
    generated_text: str
    # Generation finish reason
    finish_reason: FinishReason
    # Number of generated tokens
    generated_tokens: int
    # Sampling seed if sampling was activated
    seed: Optional[int] = None
    # Decoder input tokens, empty if decoder_input_details is False
    prefill: List[InputToken]
    # Generated tokens
    tokens: List[Token]
# `generate` details
class Details(BaseModel):
    # Generation finish reason
    finish_reason: FinishReason
    # Number of generated tokens
    generated_tokens: int
    # Sampling seed if sampling was activated
    seed: Optional[int] = None
    # Decoder input tokens, empty if decoder_input_details is False
    prefill: List[InputToken]
    # Generated tokens
    tokens: List[Token]
    # Additional sequences when using the `best_of` parameter
    best_of_sequences: Optional[List[BestOfSequence]] = None
# `generate` return value
class Response(BaseModel):
    # Generated text
    generated_text: str
    # Generation details
    details: Details
# `generate_stream` details
class StreamDetails(BaseModel):
    # Generation finish reason
    finish_reason: FinishReason
    # Number of generated tokens
    generated_tokens: int
    # Sampling seed if sampling was activated
    seed: Optional[int] = None
# `generate_stream` return value
class StreamResponse(BaseModel):
    # Generated token
    token: Token
    # Complete generated text
    # Only available when the generation is finished
    generated_text: Optional[str] = None
    # Generation details
    # Only available when the generation is finished
    details: Optional[StreamDetails] = None
# Inference API currently deployed model
class DeployedModel(BaseModel):
    # Identifier of the deployed model
    # NOTE(review): presumably the Hub repository id — confirm against the API
    model_id: str
    # NOTE(review): presumably the revision (commit) hash of the deployed model — confirm
    sha: str