Update the docs to include newer models. (#1492)

This commit is contained in:
Nicolas Patry 2024-01-26 16:07:31 +01:00 committed by GitHub
parent 50a20a83d7
commit ebecc06161
3 changed files with 41 additions and 14 deletions

File diff suppressed because one or more lines are too long

View File

@@ -188,18 +188,20 @@ fn default_parameters() -> GenerateParameters {
}
}
-#[derive(Clone, Deserialize, Serialize)]
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletion {
pub id: String,
pub object: String,
#[schema(example = "1706270835")]
pub created: u64,
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
pub model: String,
pub system_fingerprint: String,
pub choices: Vec<ChatCompletionComplete>,
pub usage: Usage,
}
-#[derive(Clone, Deserialize, Serialize)]
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionComplete {
pub index: u32,
pub message: Message,
@@ -248,17 +250,19 @@ impl ChatCompletion {
}
}
-#[derive(Clone, Deserialize, Serialize)]
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionChunk {
pub id: String,
pub object: String,
#[schema(example = "1706270978")]
pub created: u64,
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
pub model: String,
pub system_fingerprint: String,
pub choices: Vec<ChatCompletionChoice>,
}
-#[derive(Clone, Deserialize, Serialize)]
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionChoice {
pub index: u32,
pub delta: ChatCompletionDelta,
@@ -266,9 +270,11 @@ pub(crate) struct ChatCompletionChoice {
pub finish_reason: Option<String>,
}
-#[derive(Clone, Debug, Deserialize, Serialize)]
+#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionDelta {
#[schema(example = "user")]
pub role: String,
#[schema(example = "What is Deep Learning?")]
pub content: String,
}
@@ -311,7 +317,7 @@ fn default_request_messages() -> Vec<Message> {
#[derive(Clone, Deserialize, ToSchema, Serialize)]
pub(crate) struct ChatRequest {
/// UNUSED
#[schema(example = "bigscience/blomm-560m")]
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
/// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
pub model: String, /* NOTE: UNUSED */
@@ -322,6 +328,7 @@ pub(crate) struct ChatRequest {
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
/// decreasing the model's likelihood to repeat the same line verbatim.
#[serde(default)]
#[schema(example = "1.0")]
pub frequency_penalty: Option<f32>,
/// UNUSED
@@ -336,28 +343,33 @@ pub(crate) struct ChatRequest {
/// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
/// output token returned in the content of message.
#[serde(default)]
#[schema(example = "false")]
pub logprobs: Option<bool>,
/// UNUSED
/// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
/// an associated log probability. logprobs must be set to true if this parameter is used.
#[serde(default)]
#[schema(example = "5")]
pub top_logprobs: Option<u32>,
/// The maximum number of tokens that can be generated in the chat completion.
#[serde(default)]
#[schema(example = "32")]
pub max_tokens: Option<u32>,
/// UNUSED
/// How many chat completion choices to generate for each input message. Note that you will be charged based on the
/// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
#[serde(default)]
#[schema(nullable = true, example = "2")]
pub n: Option<u32>,
/// UNUSED
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
/// increasing the model's likelihood to talk about new topics
#[serde(default)]
+#[schema(nullable = true, example = 0.1)]
pub presence_penalty: Option<f32>,
#[serde(default = "bool::default")]
@@ -371,11 +383,13 @@ pub(crate) struct ChatRequest {
///
/// We generally recommend altering this or `top_p` but not both.
#[serde(default)]
+#[schema(nullable = true, example = 1.0)]
pub temperature: Option<f32>,
/// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
/// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
#[serde(default)]
+#[schema(nullable = true, example = 0.95)]
pub top_p: Option<f32>,
}
@@ -458,6 +472,7 @@ pub struct SimpleToken {
#[derive(Serialize, ToSchema)]
#[serde(rename_all(serialize = "snake_case"))]
#[schema(example = "Length")]
pub(crate) enum FinishReason {
#[schema(rename = "length")]
Length,
@@ -518,6 +533,10 @@ pub(crate) struct GenerateResponse {
pub details: Option<Details>,
}
+#[derive(Serialize, ToSchema)]
+#[serde(transparent)]
+pub(crate) struct TokenizeResponse(Vec<SimpleToken>);
#[derive(Serialize, ToSchema)]
pub(crate) struct StreamDetails {
#[schema(example = "length")]

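The ToSchema derives and #[schema(example = ...)] attributes added above come from the utoipa crate, which the router uses to build its OpenAPI document. Below is a minimal standalone sketch, not part of this commit, of how such annotations surface as per-field examples in the generated spec; the struct name, fields, and main function are illustrative, and it assumes the serde and utoipa crates with derive support.

use serde::{Deserialize, Serialize};
use utoipa::{OpenApi, ToSchema};

// Illustrative struct mirroring the annotations added in this commit.
#[derive(Clone, Deserialize, Serialize, ToSchema)]
struct ChatCompletionExample {
    #[schema(example = "1706270835")]
    created: u64,
    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
    model: String,
}

// Register the schema, as the router does in its components(schemas(...)) list.
#[derive(OpenApi)]
#[openapi(components(schemas(ChatCompletionExample)))]
struct ApiDoc;

fn main() {
    // The field examples appear under components.schemas.ChatCompletionExample
    // in the serialized OpenAPI JSON.
    println!("{}", ApiDoc::openapi().to_pretty_json().unwrap());
}

The new TokenizeResponse(Vec<SimpleToken>) newtype plays the same role for /tokenize: #[serde(transparent)] keeps the HTTP response a bare JSON array of tokens, while the named wrapper gives utoipa a concrete schema to reference from the route's responses(...) declaration.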
View File

@@ -3,10 +3,10 @@ use crate::health::Health;
use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError;
use crate::{
-BestOfSequence, ChatCompletion, ChatCompletionChunk, ChatRequest, CompatGenerateRequest,
-Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
-HubModelInfo, HubTokenizerConfig, Infer, Info, PrefillToken, SimpleToken, StreamDetails,
-StreamResponse, Token, Validation,
+BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
+ChatRequest, CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
+GenerateRequest, GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer, Info, Message,
+PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
@@ -677,7 +677,7 @@ async fn chat_completions(
post,
tag = "Text Generation Inference",
path = "/tokenize",
-request_body = TokenizeRequest,
+request_body = GenerateRequest,
responses(
(status = 200, description = "Tokenized ids", body = TokenizeResponse),
(status = 404, description = "No tokenizer found", body = ErrorResponse,
@@ -688,7 +688,7 @@ async fn chat_completions(
async fn tokenize(
Extension(infer): Extension<Infer>,
Json(req): Json<GenerateRequest>,
-) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+) -> Result<Json<TokenizeResponse>, (StatusCode, Json<ErrorResponse>)> {
let input = req.inputs.clone();
let encoding = infer.tokenize(req).await?;
if let Some(encoding) = encoding {
@@ -706,7 +706,7 @@ async fn tokenize(
}
})
.collect();
-Ok(Json(tokens).into_response())
+Ok(Json(TokenizeResponse(tokens)))
} else {
Err((
StatusCode::NOT_FOUND,
@@ -774,10 +774,18 @@ pub async fn run(
Info,
CompatGenerateRequest,
GenerateRequest,
ChatRequest,
Message,
ChatCompletionChoice,
ChatCompletionDelta,
ChatCompletionChunk,
ChatCompletion,
GenerateParameters,
PrefillToken,
Token,
GenerateResponse,
TokenizeResponse,
SimpleToken,
BestOfSequence,
Details,
FinishReason,