feat(router): add endpoint info to /info route (#228)

This commit is contained in:
OlivierDehaene 2023-04-25 13:11:18 +02:00 committed by GitHub
parent ebc74d5666
commit 8b182eb986
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 44 additions and 18 deletions

View File

@ -21,6 +21,7 @@ pub struct HubModelInfo {
#[derive(Clone, Debug, Serialize, ToSchema)] #[derive(Clone, Debug, Serialize, ToSchema)]
pub struct Info { pub struct Info {
/// Model info
#[schema(example = "bigscience/blomm-560m")] #[schema(example = "bigscience/blomm-560m")]
pub model_id: String, pub model_id: String,
#[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")] #[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
@ -31,6 +32,26 @@ pub struct Info {
pub model_device_type: String, pub model_device_type: String,
#[schema(nullable = true, example = "text-generation")] #[schema(nullable = true, example = "text-generation")]
pub model_pipeline_tag: Option<String>, pub model_pipeline_tag: Option<String>,
/// Router Parameters
#[schema(example = "128")]
pub max_concurrent_requests: usize,
#[schema(example = "2")]
pub max_best_of: usize,
#[schema(example = "4")]
pub max_stop_sequences: usize,
#[schema(example = "1024")]
pub max_input_length: usize,
#[schema(example = "2048")]
pub max_total_tokens: usize,
#[schema(example = "1.2")]
pub waiting_served_ratio: f32,
#[schema(example = "32000")]
pub max_batch_total_tokens: u32,
#[schema(example = "20")]
pub max_waiting_tokens: usize,
#[schema(example = "2")]
pub validation_workers: usize,
/// Router Info
#[schema(example = "0.5.0")] #[schema(example = "0.5.0")]
pub version: &'static str, pub version: &'static str,
#[schema(nullable = true, example = "null")] #[schema(nullable = true, example = "null")]

View File

@ -78,22 +78,8 @@ async fn compat_generate(
responses((status = 200, description = "Served model info", body = Info)) responses((status = 200, description = "Served model info", body = Info))
)] )]
#[instrument] #[instrument]
async fn get_model_info( async fn get_model_info(info: Extension<Info>) -> Json<Info> {
model_info: Extension<HubModelInfo>, Json(info.0)
shard_info: Extension<ShardInfo>,
) -> Json<Info> {
let model_info = model_info.0;
let shard_info = shard_info.0;
let info = Info {
version: env!("CARGO_PKG_VERSION"),
sha: option_env!("VERGEN_GIT_SHA"),
model_id: model_info.model_id,
model_sha: model_info.sha,
model_dtype: shard_info.dtype,
model_device_type: shard_info.device_type,
model_pipeline_tag: model_info.pipeline_tag,
};
Json(info)
} }
/// Health check method /// Health check method
@ -632,6 +618,26 @@ pub async fn run(
.allow_headers([http::header::CONTENT_TYPE]) .allow_headers([http::header::CONTENT_TYPE])
.allow_origin(allow_origin); .allow_origin(allow_origin);
// Endpoint info
let info = Info {
model_id: model_info.model_id,
model_sha: model_info.sha,
model_dtype: shard_info.dtype,
model_device_type: shard_info.device_type,
model_pipeline_tag: model_info.pipeline_tag,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_input_length,
max_total_tokens,
waiting_served_ratio,
max_batch_total_tokens,
max_waiting_tokens,
validation_workers,
version: env!("CARGO_PKG_VERSION"),
sha: option_env!("VERGEN_GIT_SHA"),
};
// Create router // Create router
let app = Router::new() let app = Router::new()
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi())) .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
@ -650,8 +656,7 @@ pub async fn run(
.route("/ping", get(health)) .route("/ping", get(health))
// Prometheus metrics route // Prometheus metrics route
.route("/metrics", get(metrics)) .route("/metrics", get(metrics))
.layer(Extension(model_info)) .layer(Extension(info))
.layer(Extension(shard_info))
.layer(Extension(compat_return_full_text)) .layer(Extension(compat_return_full_text))
.layer(Extension(infer)) .layer(Extension(infer))
.layer(Extension(prom_handle)) .layer(Extension(prom_handle))