diff --git a/router/src/lib.rs b/router/src/lib.rs index 2f93ec0a..7a1707d9 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -21,6 +21,7 @@ pub struct HubModelInfo { #[derive(Clone, Debug, Serialize, ToSchema)] pub struct Info { + /// Model info #[schema(example = "bigscience/blomm-560m")] pub model_id: String, #[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")] @@ -31,6 +32,26 @@ pub struct Info { pub model_device_type: String, #[schema(nullable = true, example = "text-generation")] pub model_pipeline_tag: Option, + /// Router Parameters + #[schema(example = "128")] + pub max_concurrent_requests: usize, + #[schema(example = "2")] + pub max_best_of: usize, + #[schema(example = "4")] + pub max_stop_sequences: usize, + #[schema(example = "1024")] + pub max_input_length: usize, + #[schema(example = "2048")] + pub max_total_tokens: usize, + #[schema(example = "1.2")] + pub waiting_served_ratio: f32, + #[schema(example = "32000")] + pub max_batch_total_tokens: u32, + #[schema(example = "20")] + pub max_waiting_tokens: usize, + #[schema(example = "2")] + pub validation_workers: usize, + /// Router Info #[schema(example = "0.5.0")] pub version: &'static str, #[schema(nullable = true, example = "null")] diff --git a/router/src/server.rs b/router/src/server.rs index d1f7ae12..9540ba18 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -78,22 +78,8 @@ async fn compat_generate( responses((status = 200, description = "Served model info", body = Info)) )] #[instrument] -async fn get_model_info( - model_info: Extension, - shard_info: Extension, -) -> Json { - let model_info = model_info.0; - let shard_info = shard_info.0; - let info = Info { - version: env!("CARGO_PKG_VERSION"), - sha: option_env!("VERGEN_GIT_SHA"), - model_id: model_info.model_id, - model_sha: model_info.sha, - model_dtype: shard_info.dtype, - model_device_type: shard_info.device_type, - model_pipeline_tag: model_info.pipeline_tag, - }; - Json(info) +async fn get_model_info(info: Extension) -> Json { + Json(info.0) } /// Health check method @@ -632,6 +618,26 @@ pub async fn run( .allow_headers([http::header::CONTENT_TYPE]) .allow_origin(allow_origin); + // Endpoint info + let info = Info { + model_id: model_info.model_id, + model_sha: model_info.sha, + model_dtype: shard_info.dtype, + model_device_type: shard_info.device_type, + model_pipeline_tag: model_info.pipeline_tag, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_input_length, + max_total_tokens, + waiting_served_ratio, + max_batch_total_tokens, + max_waiting_tokens, + validation_workers, + version: env!("CARGO_PKG_VERSION"), + sha: option_env!("VERGEN_GIT_SHA"), + }; + // Create router let app = Router::new() .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi())) @@ -650,8 +656,7 @@ pub async fn run( .route("/ping", get(health)) // Prometheus metrics route .route("/metrics", get(metrics)) - .layer(Extension(model_info)) - .layer(Extension(shard_info)) + .layer(Extension(info)) .layer(Extension(compat_return_full_text)) .layer(Extension(infer)) .layer(Extension(prom_handle))