feat: add /v1/models endpoint (#2433)

* feat: add /v1/models endpoint

* feat: add /v1/models endpoint

* fix: remove unused type import

* fix: revert route typo

* fix: update docs with new endpoint

* fix: add to redocly ignore and lint
drbh 2024-08-29 10:32:38 -04:00 committed by GitHub
parent e415b690a6
commit d5202c46f7
4 changed files with 117 additions and 1 deletion
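
For context, the new route mirrors the OpenAI model-listing format. Below is a minimal sketch of the response shape implied by the handler added in this commit; <model_id> is a placeholder for whatever model the router is serving, and created is hard-coded to 0 for now:

GET /v1/models

{
    "object": "list",
    "data": [
        {
            "id": "<model_id>",
            "object": "model",
            "created": 0,
            "owned_by": "<model_id>"
        }
    ]
}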

@@ -77,3 +77,4 @@ docs/openapi.json:
- '#/paths/~1tokenize/post'
- '#/paths/~1v1~1chat~1completions/post'
- '#/paths/~1v1~1completions/post'
- '#/paths/~1v1~1models/get'

@@ -556,6 +556,37 @@
}
}
},
"/v1/models": {
"get": {
"tags": [
"Text Generation Inference"
],
"summary": "Get model info",
"operationId": "openai_get_model_info",
"responses": {
"200": {
"description": "Served model info",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ModelInfo"
}
}
}
},
"404": {
"description": "Model not found",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
}
}
}
}
}
}
}
},
"components": {
@@ -1747,6 +1778,35 @@
}
]
},
"ModelInfo": {
"type": "object",
"required": [
"id",
"object",
"created",
"owned_by"
],
"properties": {
"created": {
"type": "integer",
"format": "int64",
"example": 1686935002,
"minimum": 0
},
"id": {
"type": "string",
"example": "gpt2"
},
"object": {
"type": "string",
"example": "model"
},
"owned_by": {
"type": "string",
"example": "openai"
}
}
},
"OutputMessage": { "OutputMessage": {
"oneOf": [ "oneOf": [
{ {

@@ -1261,6 +1261,34 @@ pub(crate) struct ErrorResponse {
pub error_type: String,
}
#[derive(Serialize, Deserialize, ToSchema)]
pub(crate) struct ModelInfo {
#[schema(example = "gpt2")]
pub id: String,
#[schema(example = "model")]
pub object: String,
#[schema(example = 1686935002)]
pub created: u64,
#[schema(example = "openai")]
pub owned_by: String,
}
#[derive(Serialize, Deserialize, ToSchema)]
pub(crate) struct ModelsInfo {
#[schema(example = "list")]
pub object: String,
pub data: Vec<ModelInfo>,
}
impl Default for ModelsInfo {
fn default() -> Self {
ModelsInfo {
object: "list".to_string(),
data: Vec::new(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
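
For illustration only (not part of this commit), a minimal sketch of a unit test that pins down the wire format of the new types, assuming serde_json is available in the router's test context and reusing the example values from the ToSchema annotations above:

#[test]
fn models_info_serializes_like_openai() {
    // `object` comes from the Default impl ("list"); the other fields use the schema example values.
    let models = ModelsInfo {
        data: vec![ModelInfo {
            id: "gpt2".to_string(),
            object: "model".to_string(),
            created: 1686935002,
            owned_by: "openai".to_string(),
        }],
        ..Default::default()
    };
    let value = serde_json::to_value(&models).unwrap();
    assert_eq!(value["object"], "list");
    assert_eq!(value["data"][0]["id"], "gpt2");
    assert_eq!(value["data"][0]["owned_by"], "openai");
}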

@@ -24,6 +24,7 @@ use crate::{
VertexResponse,
};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
use crate::{ModelInfo, ModelsInfo};
use async_stream::__private::AsyncStream;
use axum::extract::Extension;
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
@@ -116,6 +117,29 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
Json(info.0)
}
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/v1/models",
responses(
(status = 200, description = "Served model info", body = ModelInfo),
(status = 404, description = "Model not found", body = ErrorResponse),
)
)]
#[instrument(skip(info))]
/// Get model info
async fn openai_get_model_info(info: Extension<Info>) -> Json<ModelsInfo> {
Json(ModelsInfo {
data: vec![ModelInfo {
id: info.0.model_id.clone(),
object: "model".to_string(),
created: 0, // TODO: determine how to get this
owned_by: info.0.model_id.clone(),
}],
..Default::default()
})
}
#[utoipa::path(
post,
tag = "Text Generation Inference",
@@ -1505,6 +1529,7 @@ chat_completions,
completions,
tokenize,
metrics,
openai_get_model_info,
),
components(
schemas(
@@ -1557,6 +1582,7 @@ ToolCall,
Function,
FunctionDefinition,
ToolChoice,
ModelInfo,
)
),
tags(
@@ -2250,7 +2276,8 @@ async fn start(
.route("/info", get(get_model_info))
.route("/health", get(health))
.route("/ping", get(health))
.route("/metrics", get(metrics))
.route("/v1/models", get(openai_get_model_info));
// Conditional AWS Sagemaker route
let aws_sagemaker_route = if messages_api_enabled {