from enum import Enum

from pydantic import BaseModel, field_validator
from typing import Optional, List, Union, Any

from text_generation.errors import ValidationError


# enum for grammar type
class GrammarType(str, Enum):
    Json = "json"
    Regex = "regex"


# Grammar type and value
class Grammar(BaseModel):
    # Grammar type
    type: GrammarType
    # Grammar value
    value: Union[str, dict]
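

# Illustrative sketch (added in editing; not part of the original API): a
# `Grammar` constrains generation either to a regular expression or to a JSON
# schema. Both values below are made-up examples.
def _example_grammars() -> List[Grammar]:
    return [
        # Constrain output to an ISO-8601 date
        Grammar(type=GrammarType.Regex, value=r"\d{4}-\d{2}-\d{2}"),
        # Constrain output to a JSON object with a string `name` field
        Grammar(
            type=GrammarType.Json,
            value={"type": "object", "properties": {"name": {"type": "string"}}},
        ),
    ]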


class ToolCall(BaseModel):
    # Id of the tool call
    id: int
    # Type of the tool call
    type: str
    # Function details of the tool call
    function: dict


class Message(BaseModel):
    # Role of the message sender
    role: str
    # Content of the message
    content: Optional[str] = None
    # Optional name of the message sender
    name: Optional[str] = None
    # Tool calls associated with the chat completion
    tool_calls: Optional[Any] = None


class Tool(BaseModel):
    # Type of the tool
    type: str
    # Function details of the tool
    function: dict
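

# Illustrative sketch (added in editing; not part of the original API): a
# minimal conversation and a tool definition in the OpenAI-style function
# format this module mirrors. The weather schema is a made-up example.
def _example_messages() -> List[Message]:
    return [
        Message(role="system", content="You are a helpful assistant."),
        Message(role="user", content="What is the weather in Paris?"),
    ]


def _example_tool() -> Tool:
    return Tool(
        type="function",
        function={
            "name": "get_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    )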


class ChatCompletionComplete(BaseModel):
    # Index of the chat completion
    index: int
    # Message associated with the chat completion
    message: Message
    # Log probabilities for the chat completion
    logprobs: Optional[Any]
    # Reason for completion
    finish_reason: str
    # Usage details of the chat completion
    usage: Optional[Any] = None


class CompletionComplete(BaseModel):
    # Index of the completion
    index: int
    # Generated text
    text: str
    # Log probabilities for the completion
    logprobs: Optional[Any]
    # Reason for completion
    finish_reason: str


class Function(BaseModel):
    name: Optional[str]
    arguments: str


class ChoiceDeltaToolCall(BaseModel):
    index: int
    id: str
    type: str
    function: Function


class ChoiceDelta(BaseModel):
    role: str
    content: Optional[str] = None
    tool_calls: Optional[ChoiceDeltaToolCall]


class Choice(BaseModel):
    index: int
    delta: ChoiceDelta
    logprobs: Optional[dict] = None
    finish_reason: Optional[str] = None


class ChatCompletionChunk(BaseModel):
    id: str
    object: str
    created: int
    model: str
    system_fingerprint: str
    choices: List[Choice]


class ChatComplete(BaseModel):
    # Chat completion details
    id: str
    object: str
    created: int
    model: str
    system_fingerprint: str
    choices: List[ChatCompletionComplete]
    usage: Any


class Completion(BaseModel):
    # Completion details
    id: str
    object: str
    created: int
    model: str
    system_fingerprint: str
    choices: List[CompletionComplete]
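

# Illustrative sketch (added in editing; not part of the original API):
# streamed chat responses arrive as `ChatCompletionChunk`s whose `ChoiceDelta`s
# carry content fragments; joining the fragments reconstructs the full
# message. `chunks` is a hypothetical input.
def _example_join_chunks(chunks: List[ChatCompletionChunk]) -> str:
    return "".join(
        choice.delta.content or ""
        for chunk in chunks
        for choice in chunk.choices
    )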


class ChatRequest(BaseModel):
    # Model identifier
    model: str
    # List of messages in the conversation
    messages: List[Message]
    # The parameter for repetition penalty. 1.0 means no penalty.
    # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
    repetition_penalty: Optional[float] = None
    # The parameter for frequency penalty. Positive values penalize new tokens
    # based on their existing frequency in the text so far, decreasing the
    # model's likelihood to repeat the same line verbatim.
    frequency_penalty: Optional[float] = None
    # Bias values for token selection
    logit_bias: Optional[List[float]] = None
    # Whether to return log probabilities
    logprobs: Optional[bool] = None
    # Number of most likely tokens to return at each position
    top_logprobs: Optional[int] = None
    # Maximum number of tokens to generate
    max_tokens: Optional[int] = None
    # Number of chat completion choices to generate
    n: Optional[int] = None
    # Penalty for presence of new tokens
    presence_penalty: Optional[float] = None
    # Flag to indicate streaming response
    stream: bool = False
    # Random sampling seed
    seed: Optional[int] = None
    # Sampling temperature
    temperature: Optional[float] = None
    # Top-p value for nucleus sampling
    top_p: Optional[float] = None
    # List of tools to be used
    tools: Optional[List[Tool]] = None
    # A prompt to be added before the tools
    tool_prompt: Optional[str] = None
    # Choice of tool to be used
    tool_choice: Optional[str] = None
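

# Illustrative sketch (added in editing; not part of the original API): a
# minimal, non-streaming chat request. The model identifier is a placeholder,
# not a recommendation.
def _example_chat_request() -> ChatRequest:
    return ChatRequest(
        model="tgi",  # placeholder identifier
        messages=[Message(role="user", content="Say hello.")],
        max_tokens=32,
        temperature=0.7,
        stream=False,
    )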


class Parameters(BaseModel):
    # Activate logits sampling
    do_sample: bool = False
    # Maximum number of generated tokens
    max_new_tokens: int = 20
    # The parameter for repetition penalty. 1.0 means no penalty.
    # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
    repetition_penalty: Optional[float] = None
    # The parameter for frequency penalty. Positive values penalize new tokens
    # based on their existing frequency in the text so far, decreasing the
    # model's likelihood to repeat the same line verbatim.
    frequency_penalty: Optional[float] = None
    # Whether to prepend the prompt to the generated text
    return_full_text: bool = False
    # Stop generating tokens if a member of `stop_sequences` is generated
    stop: List[str] = []
    # Random sampling seed
    seed: Optional[int] = None
    # The value used to modulate the logits distribution.
    temperature: Optional[float] = None
    # The number of highest probability vocabulary tokens to keep for top-k-filtering.
    top_k: Optional[int] = None
    # If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
    # higher are kept for generation.
    top_p: Optional[float] = None
    # Truncate input tokens to the given size
    truncate: Optional[int] = None
    # Typical Decoding mass
    # See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
    typical_p: Optional[float] = None
    # Generate `best_of` sequences and return the one with the highest token logprobs
    best_of: Optional[int] = None
    # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
    watermark: bool = False
    # Get generation details
    details: bool = False
    # Get decoder input token logprobs and ids
    decoder_input_details: bool = False
    # Return the N most likely tokens at each step
    top_n_tokens: Optional[int] = None
    # Grammar to use for generation
    grammar: Optional[Grammar] = None
@field_validator("best_of")
|
|
def valid_best_of(cls, field_value, values):
|
|
if field_value is not None:
|
|
if field_value <= 0:
|
|
raise ValidationError("`best_of` must be strictly positive")
|
|
if field_value > 1 and values.data["seed"] is not None:
|
|
raise ValidationError("`seed` must not be set when `best_of` is > 1")
|
|
sampling = (
|
|
values.data["do_sample"]
|
|
| (values.data["temperature"] is not None)
|
|
| (values.data["top_k"] is not None)
|
|
| (values.data["top_p"] is not None)
|
|
| (values.data["typical_p"] is not None)
|
|
)
|
|
if field_value > 1 and not sampling:
|
|
raise ValidationError("you must use sampling when `best_of` is > 1")
|
|
|
|
return field_value
|
|
|
|
@field_validator("repetition_penalty")
|
|
def valid_repetition_penalty(cls, v):
|
|
if v is not None and v <= 0:
|
|
raise ValidationError("`repetition_penalty` must be strictly positive")
|
|
return v
|
|
|
|
@field_validator("frequency_penalty")
|
|
def valid_frequency_penalty(cls, v):
|
|
if v is not None and v <= 0:
|
|
raise ValidationError("`frequency_penalty` must be strictly positive")
|
|
return v
|
|
|
|
@field_validator("seed")
|
|
def valid_seed(cls, v):
|
|
if v is not None and v < 0:
|
|
raise ValidationError("`seed` must be positive")
|
|
return v
|
|
|
|
@field_validator("temperature")
|
|
def valid_temp(cls, v):
|
|
if v is not None and v <= 0:
|
|
raise ValidationError("`temperature` must be strictly positive")
|
|
return v
|
|
|
|
@field_validator("top_k")
|
|
def valid_top_k(cls, v):
|
|
if v is not None and v <= 0:
|
|
raise ValidationError("`top_k` must be strictly positive")
|
|
return v
|
|
|
|
@field_validator("top_p")
|
|
def valid_top_p(cls, v):
|
|
if v is not None and (v <= 0 or v >= 1.0):
|
|
raise ValidationError("`top_p` must be > 0.0 and < 1.0")
|
|
return v
|
|
|
|
@field_validator("truncate")
|
|
def valid_truncate(cls, v):
|
|
if v is not None and v <= 0:
|
|
raise ValidationError("`truncate` must be strictly positive")
|
|
return v
|
|
|
|
@field_validator("typical_p")
|
|
def valid_typical_p(cls, v):
|
|
if v is not None and (v <= 0 or v >= 1.0):
|
|
raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
|
|
return v
|
|
|
|
@field_validator("top_n_tokens")
|
|
def valid_top_n_tokens(cls, v):
|
|
if v is not None and v <= 0:
|
|
raise ValidationError("`top_n_tokens` must be strictly positive")
|
|
return v
|
|
|
|
@field_validator("grammar")
|
|
def valid_grammar(cls, v):
|
|
if v is not None:
|
|
if v.type == GrammarType.Regex and not v.value:
|
|
raise ValidationError("`value` cannot be empty for `regex` grammar")
|
|
if v.type == GrammarType.Json and not v.value:
|
|
raise ValidationError("`value` cannot be empty for `json` grammar")
|
|
return v
|
|
|
|
|
|
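

# Illustrative sketch (added in editing; not part of the original API):
# `Parameters` validates its fields at construction time, so invalid
# combinations fail fast. The values below are arbitrary examples.
def _example_parameters() -> Parameters:
    # Sampling must be active when `best_of` > 1, hence `do_sample=True`.
    return Parameters(
        max_new_tokens=64,
        do_sample=True,
        temperature=0.8,
        top_k=50,
        best_of=2,
        details=True,
    )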


class Request(BaseModel):
    # Prompt
    inputs: str
    # Generation parameters
    parameters: Optional[Parameters] = None
    # Whether to stream output tokens
    stream: bool = False

    @field_validator("inputs")
    def valid_input(cls, v):
        if not v:
            raise ValidationError("`inputs` cannot be empty")
        return v

    @field_validator("stream")
    def valid_best_of_stream(cls, field_value, values):
        parameters = values.data["parameters"]
        if (
            parameters is not None
            and parameters.best_of is not None
            and parameters.best_of > 1
            and field_value
        ):
            raise ValidationError(
                "`best_of` != 1 is not supported when `stream` == True"
            )
        return field_value
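

# Illustrative sketch (added in editing; not part of the original API): a
# complete streaming generation request. Note that `best_of` > 1 cannot be
# combined with `stream=True`; the `stream` validator above rejects it.
def _example_request() -> Request:
    return Request(
        inputs="Once upon a time",
        parameters=Parameters(max_new_tokens=32, do_sample=True, temperature=0.7),
        stream=True,
    )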


# Decoder input tokens
class InputToken(BaseModel):
    # Token ID from the model tokenizer
    id: int
    # Token text
    text: str
    # Logprob
    # Optional since the logprob of the first token cannot be computed
    logprob: Optional[float] = None


# Generated tokens
class Token(BaseModel):
    # Token ID from the model tokenizer
    id: int
    # Token text
    text: str
    # Logprob
    logprob: Optional[float] = None
    # Is the token a special token
    # Can be used to ignore tokens when concatenating
    special: bool


# Generation finish reason
class FinishReason(str, Enum):
    # number of generated tokens == `max_new_tokens`
    Length = "length"
    # the model generated its end of sequence token
    EndOfSequenceToken = "eos_token"
    # the model generated a text included in `stop_sequences`
    StopSequence = "stop_sequence"


# Additional sequences when using the `best_of` parameter
class BestOfSequence(BaseModel):
    # Generated text
    generated_text: str
    # Generation finish reason
    finish_reason: FinishReason
    # Number of generated tokens
    generated_tokens: int
    # Sampling seed if sampling was activated
    seed: Optional[int] = None
    # Decoder input tokens, empty if decoder_input_details is False
    prefill: List[InputToken]
    # Generated tokens
    tokens: List[Token]
    # Most likely tokens
    top_tokens: Optional[List[List[Token]]] = None


# `generate` details
class Details(BaseModel):
    # Generation finish reason
    finish_reason: FinishReason
    # Number of generated tokens
    generated_tokens: int
    # Sampling seed if sampling was activated
    seed: Optional[int] = None
    # Decoder input tokens, empty if decoder_input_details is False
    prefill: List[InputToken]
    # Generated tokens
    tokens: List[Token]
    # Most likely tokens
    top_tokens: Optional[List[List[Token]]] = None
    # Additional sequences when using the `best_of` parameter
    best_of_sequences: Optional[List[BestOfSequence]] = None


# `generate` return value
class Response(BaseModel):
    # Generated text
    generated_text: str
    # Generation details
    details: Details
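

# Illustrative sketch (added in editing; not part of the original API):
# pulling the finish reason and token count out of a `generate` response.
def _example_summary(resp: Response) -> str:
    return (
        f"{resp.details.generated_tokens} tokens, "
        f"finished: {resp.details.finish_reason.value}"
    )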


# `generate_stream` details
class StreamDetails(BaseModel):
    # Generation finish reason
    finish_reason: FinishReason
    # Number of generated tokens
    generated_tokens: int
    # Sampling seed if sampling was activated
    seed: Optional[int] = None


# `generate_stream` return value
class StreamResponse(BaseModel):
    # Generated token
    token: Token
    # Most likely tokens
    top_tokens: Optional[List[Token]] = None
    # Complete generated text
    # Only available when the generation is finished
    generated_text: Optional[str] = None
    # Generation details
    # Only available when the generation is finished
    details: Optional[StreamDetails] = None
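

# Illustrative sketch (added in editing; not part of the original API):
# assembling text from a stream of `StreamResponse` chunks, skipping special
# tokens as the `Token.special` comment suggests. `responses` is a
# hypothetical input.
def _example_collect_stream(responses: List[StreamResponse]) -> str:
    return "".join(r.token.text for r in responses if not r.token.special)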


# Inference API currently deployed model
class DeployedModel(BaseModel):
    model_id: str
    sha: str