Update the docs to include newer models. (#1492)

Nicolas Patry 2024-01-26 16:07:31 +01:00 committed by GitHub
parent 50a20a83d7
commit ebecc06161
3 changed files with 41 additions and 14 deletions

File diff suppressed because one or more lines are too long


@@ -188,18 +188,20 @@ fn default_parameters() -> GenerateParameters {
     }
 }

-#[derive(Clone, Deserialize, Serialize)]
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletion {
     pub id: String,
     pub object: String,
+    #[schema(example = "1706270835")]
     pub created: u64,
+    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
     pub model: String,
     pub system_fingerprint: String,
     pub choices: Vec<ChatCompletionComplete>,
     pub usage: Usage,
 }

-#[derive(Clone, Deserialize, Serialize)]
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionComplete {
     pub index: u32,
     pub message: Message,
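For context: ToSchema is utoipa's derive for generating an OpenAPI schema, and each #[schema(example = ...)] attribute above attaches that example value to the corresponding field in the generated spec. A minimal sketch of the same pattern, assuming the serde and utoipa crates; the struct below is a placeholder, not the real type:

use serde::Serialize;
use utoipa::ToSchema;

// Placeholder struct showing the pattern from the diff: the ToSchema derive generates an
// OpenAPI schema, and #[schema(example = ...)] embeds the example value for each field.
#[derive(Serialize, ToSchema)]
struct DemoChatCompletion {
    #[schema(example = "1706270835")]
    created: u64,
    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
    model: String,
}

fn main() {
    // Nothing to run here: the derive is checked at compile time and the schema is
    // consumed by the OpenAPI document assembled in server.rs.
    let _ = DemoChatCompletion { created: 1706270835, model: "demo".into() };
}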
@@ -248,17 +250,19 @@ impl ChatCompletion {
     }
 }

-#[derive(Clone, Deserialize, Serialize)]
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionChunk {
     pub id: String,
     pub object: String,
+    #[schema(example = "1706270978")]
     pub created: u64,
+    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
     pub model: String,
     pub system_fingerprint: String,
     pub choices: Vec<ChatCompletionChoice>,
 }

-#[derive(Clone, Deserialize, Serialize)]
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionChoice {
     pub index: u32,
     pub delta: ChatCompletionDelta,
@@ -266,9 +270,11 @@ pub(crate) struct ChatCompletionChoice {
     pub finish_reason: Option<String>,
 }

-#[derive(Clone, Debug, Deserialize, Serialize)]
+#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionDelta {
+    #[schema(example = "user")]
     pub role: String,
+    #[schema(example = "What is Deep Learning?")]
     pub content: String,
 }
@@ -311,7 +317,7 @@ fn default_request_messages() -> Vec<Message> {
 #[derive(Clone, Deserialize, ToSchema, Serialize)]
 pub(crate) struct ChatRequest {
     /// UNUSED
-    #[schema(example = "bigscience/blomm-560m")]
+    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
     /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
     pub model: String, /* NOTE: UNUSED */
@@ -322,6 +328,7 @@ pub(crate) struct ChatRequest {
     /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
     /// decreasing the model's likelihood to repeat the same line verbatim.
     #[serde(default)]
+    #[schema(example = "1.0")]
     pub frequency_penalty: Option<f32>,

     /// UNUSED
@@ -336,28 +343,33 @@ pub(crate) struct ChatRequest {
     /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
     /// output token returned in the content of message.
     #[serde(default)]
+    #[schema(example = "false")]
     pub logprobs: Option<bool>,

     /// UNUSED
     /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
     /// an associated log probability. logprobs must be set to true if this parameter is used.
     #[serde(default)]
+    #[schema(example = "5")]
     pub top_logprobs: Option<u32>,

     /// The maximum number of tokens that can be generated in the chat completion.
     #[serde(default)]
+    #[schema(example = "32")]
     pub max_tokens: Option<u32>,

     /// UNUSED
     /// How many chat completion choices to generate for each input message. Note that you will be charged based on the
     /// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
     #[serde(default)]
+    #[schema(nullable = true, example = "2")]
     pub n: Option<u32>,

     /// UNUSED
     /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
     /// increasing the model's likelihood to talk about new topics
     #[serde(default)]
+    #[schema(nullable = true, example = 0.1)]
     pub presence_penalty: Option<f32>,

     #[serde(default = "bool::default")]
@@ -371,11 +383,13 @@ pub(crate) struct ChatRequest {
     ///
     /// We generally recommend altering this or `top_p` but not both.
     #[serde(default)]
+    #[schema(nullable = true, example = 1.0)]
     pub temperature: Option<f32>,

     /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
     /// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
     #[serde(default)]
+    #[schema(nullable = true, example = 0.95)]
     pub top_p: Option<f32>,
 }
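Taken together, the example values added to ChatRequest above describe a request body along the following lines. This is a hedged sketch built with serde_json: the messages shape is assumed from the role/content examples on ChatCompletionDelta, and several of the fields are marked UNUSED in the doc comments, so how the server treats them is unchanged by this commit.

use serde_json::json;

fn main() {
    // Illustrative request body assembled from the schema examples above; an assumption
    // for documentation purposes, not a payload taken from the commit itself.
    let body = json!({
        "model": "mistralai/Mistral-7B-Instruct-v0.2",
        "messages": [{ "role": "user", "content": "What is Deep Learning?" }],
        "frequency_penalty": 1.0,
        "max_tokens": 32,
        "temperature": 1.0,
        "top_p": 0.95
    });
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}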
@@ -458,6 +472,7 @@ pub struct SimpleToken {

 #[derive(Serialize, ToSchema)]
 #[serde(rename_all(serialize = "snake_case"))]
+#[schema(example = "Length")]
 pub(crate) enum FinishReason {
     #[schema(rename = "length")]
     Length,
@@ -518,6 +533,10 @@ pub(crate) struct GenerateResponse {
     pub details: Option<Details>,
 }

+#[derive(Serialize, ToSchema)]
+#[serde(transparent)]
+pub(crate) struct TokenizeResponse(Vec<SimpleToken>);
+
 #[derive(Serialize, ToSchema)]
 pub(crate) struct StreamDetails {
     #[schema(example = "length")]
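The new TokenizeResponse newtype uses serde's transparent attribute, so it serializes exactly like the inner Vec<SimpleToken>: the /tokenize response body stays a bare JSON array while the type still gets its own named schema via ToSchema. A small sketch of that serialization behavior, with illustrative field names (SimpleToken's real fields are not shown in this hunk):

use serde::Serialize;

#[derive(Serialize)]
struct SimpleToken {
    // Field names here are assumptions for the sketch.
    id: u32,
    text: String,
}

#[derive(Serialize)]
#[serde(transparent)]
struct TokenizeResponse(Vec<SimpleToken>);

fn main() {
    let resp = TokenizeResponse(vec![SimpleToken { id: 0, text: "Hello".to_string() }]);
    // #[serde(transparent)] makes the wrapper disappear in the output:
    // prints [{"id":0,"text":"Hello"}] rather than an object wrapping the array.
    println!("{}", serde_json::to_string(&resp).unwrap());
}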


@@ -3,10 +3,10 @@ use crate::health::Health;
 use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
-    BestOfSequence, ChatCompletion, ChatCompletionChunk, ChatRequest, CompatGenerateRequest,
-    Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
-    HubModelInfo, HubTokenizerConfig, Infer, Info, PrefillToken, SimpleToken, StreamDetails,
-    StreamResponse, Token, Validation,
+    BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
+    ChatRequest, CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
+    GenerateRequest, GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer, Info, Message,
+    PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -677,7 +677,7 @@ async fn chat_completions(
     post,
     tag = "Text Generation Inference",
     path = "/tokenize",
-    request_body = TokenizeRequest,
+    request_body = GenerateRequest,
     responses(
         (status = 200, description = "Tokenized ids", body = TokenizeResponse),
         (status = 404, description = "No tokenizer found", body = ErrorResponse,
@@ -688,7 +688,7 @@ async fn chat_completions(
 async fn tokenize(
     Extension(infer): Extension<Infer>,
     Json(req): Json<GenerateRequest>,
-) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+) -> Result<Json<TokenizeResponse>, (StatusCode, Json<ErrorResponse>)> {
     let input = req.inputs.clone();
     let encoding = infer.tokenize(req).await?;
     if let Some(encoding) = encoding {
@@ -706,7 +706,7 @@ async fn tokenize(
                 }
             })
             .collect();
-        Ok(Json(tokens).into_response())
+        Ok(Json(TokenizeResponse(tokens)))
     } else {
         Err((
             StatusCode::NOT_FOUND,
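With the handler now returning Json<TokenizeResponse> instead of a plain Response, the success path is fully typed and matches the request_body = GenerateRequest and body = TokenizeResponse annotations above. A rough client-side sketch of calling the endpoint follows; the base URL, port, and the reqwest dependency (with the blocking and json features) are assumptions, while the inputs field comes from GenerateRequest as used in the handler:

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical local server address; adjust to wherever the router is listening.
    let client = reqwest::blocking::Client::new();
    let tokens: serde_json::Value = client
        .post("http://localhost:3000/tokenize")
        .json(&serde_json::json!({ "inputs": "What is Deep Learning?" }))
        .send()?
        .error_for_status()?
        .json()?;
    // Expecting a JSON array of token objects, since TokenizeResponse is transparent
    // over Vec<SimpleToken>.
    println!("{tokens}");
    Ok(())
}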
@@ -774,10 +774,18 @@ pub async fn run(
     Info,
     CompatGenerateRequest,
     GenerateRequest,
+    ChatRequest,
+    Message,
+    ChatCompletionChoice,
+    ChatCompletionDelta,
+    ChatCompletionChunk,
+    ChatCompletion,
     GenerateParameters,
     PrefillToken,
     Token,
     GenerateResponse,
+    TokenizeResponse,
+    SimpleToken,
     BestOfSequence,
     Details,
     FinishReason,
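These additions register the chat and tokenize types in utoipa's components(schemas(...)) list; every type named there must derive ToSchema, which is why the derives were added in lib.rs above. A minimal, self-contained sketch of the same registration pattern, with placeholder type names:

use utoipa::{OpenApi, ToSchema};

// Placeholder schema type; in the real router this is TokenizeResponse and friends.
#[derive(ToSchema)]
#[allow(dead_code)]
struct DemoTokenizeResponse(Vec<u32>);

#[derive(OpenApi)]
#[openapi(components(schemas(DemoTokenizeResponse)))]
struct ApiDoc;

fn main() {
    // Build the OpenAPI document and print it; registered schemas appear under components.schemas.
    println!("{}", ApiDoc::openapi().to_pretty_json().unwrap());
}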