Update the docs to include newer models. (#1492)
parent 50a20a83d7 · commit ebecc06161

(One file's diff is suppressed because one or more lines are too long; the remaining changes are shown below.)
router/src/lib.rs:

@@ -188,18 +188,20 @@ fn default_parameters() -> GenerateParameters {
     }
 }
 
-#[derive(Clone, Deserialize, Serialize)]
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletion {
     pub id: String,
     pub object: String,
+    #[schema(example = "1706270835")]
     pub created: u64,
+    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
     pub model: String,
     pub system_fingerprint: String,
     pub choices: Vec<ChatCompletionComplete>,
     pub usage: Usage,
 }
 
-#[derive(Clone, Deserialize, Serialize)]
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionComplete {
     pub index: u32,
     pub message: Message,
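The `ToSchema` derives and `#[schema(example = ...)]` attributes added above come from the utoipa crate, which generates the router's OpenAPI document. A minimal sketch of the mechanism, using a shortened stand-in struct (assuming the utoipa and serde_json crates; this is not the actual router code):

use utoipa::{OpenApi, ToSchema};

#[derive(ToSchema)]
struct ChatCompletionSketch {
    #[schema(example = "1706270835")]
    created: u64,
    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
    model: String,
}

#[derive(OpenApi)]
#[openapi(components(schemas(ChatCompletionSketch)))]
struct ApiDoc;

fn main() {
    // The example values appear under components.schemas.ChatCompletionSketch
    // in the printed OpenAPI document.
    println!("{}", ApiDoc::openapi().to_pretty_json().unwrap());
}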
@@ -248,17 +250,19 @@ impl ChatCompletion {
     }
 }
 
-#[derive(Clone, Deserialize, Serialize)]
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionChunk {
     pub id: String,
     pub object: String,
+    #[schema(example = "1706270978")]
     pub created: u64,
+    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
     pub model: String,
     pub system_fingerprint: String,
     pub choices: Vec<ChatCompletionChoice>,
 }
 
-#[derive(Clone, Deserialize, Serialize)]
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionChoice {
     pub index: u32,
     pub delta: ChatCompletionDelta,
@@ -266,9 +270,11 @@ pub(crate) struct ChatCompletionChoice {
     pub finish_reason: Option<String>,
 }
 
-#[derive(Clone, Debug, Deserialize, Serialize)]
+#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionDelta {
+    #[schema(example = "user")]
     pub role: String,
+    #[schema(example = "What is Deep Learning?")]
     pub content: String,
 }
 
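For orientation, the streaming structs above (`ChatCompletionChunk` holding `ChatCompletionChoice` holding `ChatCompletionDelta`) serialize into the OpenAI-style chunk format. A rough sketch with local mirror structs (field names taken from the diff; values illustrative, assuming serde and serde_json):

use serde::Serialize;

#[derive(Serialize)]
struct DeltaSketch {
    role: String,
    content: String,
}

#[derive(Serialize)]
struct ChoiceSketch {
    index: u32,
    delta: DeltaSketch,
    finish_reason: Option<String>,
}

fn main() {
    let choice = ChoiceSketch {
        index: 0,
        delta: DeltaSketch {
            role: "assistant".to_string(),
            content: "Deep".to_string(),
        },
        finish_reason: None,
    };
    // Prints:
    // {"index":0,"delta":{"role":"assistant","content":"Deep"},"finish_reason":null}
    println!("{}", serde_json::to_string(&choice).unwrap());
}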
@@ -311,7 +317,7 @@ fn default_request_messages() -> Vec<Message> {
 #[derive(Clone, Deserialize, ToSchema, Serialize)]
 pub(crate) struct ChatRequest {
     /// UNUSED
-    #[schema(example = "bigscience/blomm-560m")]
+    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
     /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
     pub model: String, /* NOTE: UNUSED */
 
@@ -322,6 +328,7 @@ pub(crate) struct ChatRequest {
     /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
     /// decreasing the model's likelihood to repeat the same line verbatim.
     #[serde(default)]
+    #[schema(example = "1.0")]
     pub frequency_penalty: Option<f32>,
 
     /// UNUSED
@@ -336,28 +343,33 @@ pub(crate) struct ChatRequest {
     /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
     /// output token returned in the content of message.
     #[serde(default)]
+    #[schema(example = "false")]
     pub logprobs: Option<bool>,
 
     /// UNUSED
     /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
     /// an associated log probability. logprobs must be set to true if this parameter is used.
     #[serde(default)]
+    #[schema(example = "5")]
     pub top_logprobs: Option<u32>,
 
     /// The maximum number of tokens that can be generated in the chat completion.
     #[serde(default)]
+    #[schema(example = "32")]
     pub max_tokens: Option<u32>,
 
     /// UNUSED
     /// How many chat completion choices to generate for each input message. Note that you will be charged based on the
     /// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
     #[serde(default)]
+    #[schema(nullable = true, example = "2")]
     pub n: Option<u32>,
 
     /// UNUSED
     /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
     /// increasing the model's likelihood to talk about new topics
     #[serde(default)]
+    #[schema(nullable = true, example = 0.1)]
     pub presence_penalty: Option<f32>,
 
     #[serde(default = "bool::default")]
@@ -371,11 +383,13 @@ pub(crate) struct ChatRequest {
     ///
     /// We generally recommend altering this or `top_p` but not both.
     #[serde(default)]
+    #[schema(nullable = true, example = 1.0)]
     pub temperature: Option<f32>,
 
     /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
     /// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
     #[serde(default)]
+    #[schema(nullable = true, example = 0.95)]
     pub top_p: Option<f32>,
 }
 
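Collecting the example values annotated on `ChatRequest` above gives a plausible request body. A sketch that just assembles those documented examples with serde_json (the message `role` and `content` strings reuse the examples from `ChatCompletionDelta`; none of this is captured from a real request):

fn main() {
    let body = serde_json::json!({
        "model": "mistralai/Mistral-7B-Instruct-v0.2",
        "messages": [{ "role": "user", "content": "What is Deep Learning?" }],
        "frequency_penalty": 1.0,
        "logprobs": false,
        "top_logprobs": 5,
        "max_tokens": 32,
        "n": 2,
        "presence_penalty": 0.1,
        "temperature": 1.0,
        "top_p": 0.95
    });
    // A request body consistent with the schema examples above.
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}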
@@ -458,6 +472,7 @@ pub struct SimpleToken {
 
 #[derive(Serialize, ToSchema)]
 #[serde(rename_all(serialize = "snake_case"))]
+#[schema(example = "Length")]
 pub(crate) enum FinishReason {
     #[schema(rename = "length")]
     Length,
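The `FinishReason` hunk mixes two renames that are easy to conflate: `#[serde(rename_all(serialize = "snake_case"))]` controls the JSON on the wire, while the field-level `#[schema(rename = "length")]` and the new enum-level `#[schema(example = "Length")]` only affect the OpenAPI document. A small sketch of the serde side (assuming serde and serde_json):

use serde::Serialize;

#[derive(Serialize)]
#[serde(rename_all(serialize = "snake_case"))]
enum FinishReasonSketch {
    Length,
}

fn main() {
    let json = serde_json::to_string(&FinishReasonSketch::Length).unwrap();
    // The variant name is snake_cased on the wire.
    assert_eq!(json, "\"length\"");
    println!("{json}");
}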
@@ -518,6 +533,10 @@ pub(crate) struct GenerateResponse {
     pub details: Option<Details>,
 }
 
+#[derive(Serialize, ToSchema)]
+#[serde(transparent)]
+pub(crate) struct TokenizeResponse(Vec<SimpleToken>);
+
 #[derive(Serialize, ToSchema)]
 pub(crate) struct StreamDetails {
     #[schema(example = "length")]
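The new `TokenizeResponse` is a transparent newtype: `#[serde(transparent)]` makes it serialize as the inner `Vec` itself, so the HTTP body stays a bare JSON array while the type still gets its own OpenAPI schema. A minimal sketch (the real `SimpleToken` has more fields; these are illustrative):

use serde::Serialize;

#[derive(Serialize)]
struct SimpleTokenSketch {
    id: u32,
    text: String,
}

#[derive(Serialize)]
#[serde(transparent)]
struct TokenizeResponseSketch(Vec<SimpleTokenSketch>);

fn main() {
    let resp = TokenizeResponseSketch(vec![SimpleTokenSketch {
        id: 1,
        text: "Hello".to_string(),
    }]);
    // Prints [{"id":1,"text":"Hello"}] -- a bare array, not a wrapping object.
    println!("{}", serde_json::to_string(&resp).unwrap());
}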
router/src/server.rs:

@@ -3,10 +3,10 @@ use crate::health::Health;
 use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
-    BestOfSequence, ChatCompletion, ChatCompletionChunk, ChatRequest, CompatGenerateRequest,
-    Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
-    HubModelInfo, HubTokenizerConfig, Infer, Info, PrefillToken, SimpleToken, StreamDetails,
-    StreamResponse, Token, Validation,
+    BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
+    ChatRequest, CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
+    GenerateRequest, GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer, Info, Message,
+    PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -677,7 +677,7 @@ async fn chat_completions(
     post,
     tag = "Text Generation Inference",
     path = "/tokenize",
-    request_body = TokenizeRequest,
+    request_body = GenerateRequest,
     responses(
         (status = 200, description = "Tokenized ids", body = TokenizeResponse),
         (status = 404, description = "No tokenizer found", body = ErrorResponse,
@@ -688,7 +688,7 @@ async fn chat_completions(
 async fn tokenize(
     Extension(infer): Extension<Infer>,
     Json(req): Json<GenerateRequest>,
-) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+) -> Result<Json<TokenizeResponse>, (StatusCode, Json<ErrorResponse>)> {
     let input = req.inputs.clone();
     let encoding = infer.tokenize(req).await?;
     if let Some(encoding) = encoding {
@@ -706,7 +706,7 @@ async fn tokenize(
                 }
             })
             .collect();
-        Ok(Json(tokens).into_response())
+        Ok(Json(TokenizeResponse(tokens)))
     } else {
         Err((
             StatusCode::NOT_FOUND,
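The signature change from `Result<Response, ...>` to `Result<Json<TokenizeResponse>, ...>` works because axum implements `IntoResponse` for `Json<T>` and for `(StatusCode, Json<E>)`; returning the typed body also lets the `#[utoipa::path]` annotation above the handler name `TokenizeResponse` directly. A reduced sketch with stand-in types (assuming the axum and serde crates; not the real handler):

use axum::{http::StatusCode, Json};
use serde::Serialize;

#[derive(Serialize)]
struct TokenizeResponseSketch(Vec<u32>);

#[derive(Serialize)]
struct ErrorResponseSketch {
    error: String,
}

// Both arms implement IntoResponse, so axum accepts the whole Result.
async fn tokenize_sketch(
    found: bool,
) -> Result<Json<TokenizeResponseSketch>, (StatusCode, Json<ErrorResponseSketch>)> {
    if found {
        Ok(Json(TokenizeResponseSketch(vec![1, 2, 3])))
    } else {
        Err((
            StatusCode::NOT_FOUND,
            Json(ErrorResponseSketch {
                error: "No tokenizer found".to_string(),
            }),
        ))
    }
}

fn main() {
    // In the real server this handler is mounted on the axum Router; here we
    // only show that the types line up.
    let _ = tokenize_sketch;
}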
@@ -774,10 +774,18 @@ pub async fn run(
             Info,
             CompatGenerateRequest,
             GenerateRequest,
+            ChatRequest,
+            Message,
+            ChatCompletionChoice,
+            ChatCompletionDelta,
+            ChatCompletionChunk,
+            ChatCompletion,
             GenerateParameters,
             PrefillToken,
             Token,
             GenerateResponse,
+            TokenizeResponse,
+            SimpleToken,
             BestOfSequence,
             Details,
             FinishReason,
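Finally, every schema named by a handler's `request_body` or `responses` has to be registered in `components(schemas(...))`, which is why the chat and tokenize types are appended to this list. A trimmed sketch of how the pieces connect (assuming utoipa; the structs here are local stand-ins for the real ones):

use utoipa::{OpenApi, ToSchema};

#[derive(ToSchema)]
struct GenerateRequestSketch {
    inputs: String,
}

#[derive(ToSchema)]
struct TokenizeResponseSketch(Vec<u32>);

#[utoipa::path(
    post,
    path = "/tokenize",
    request_body = GenerateRequestSketch,
    responses((status = 200, description = "Tokenized ids", body = TokenizeResponseSketch))
)]
#[allow(dead_code)]
fn tokenize() {}

#[derive(OpenApi)]
#[openapi(
    paths(tokenize),
    components(schemas(GenerateRequestSketch, TokenizeResponseSketch))
)]
struct ApiDoc;

fn main() {
    // Without the components(...) entries, the $ref targets emitted in the
    // paths section would dangle in the generated document.
    println!("{}", ApiDoc::openapi().to_pretty_json().unwrap());
}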