feat(router): add endpoint info to /info route (#228)
This commit is contained in:
parent
ebc74d5666
commit
8b182eb986
|
@ -21,6 +21,7 @@ pub struct HubModelInfo {
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||||
pub struct Info {
|
pub struct Info {
|
||||||
|
/// Model info
|
||||||
#[schema(example = "bigscience/blomm-560m")]
|
#[schema(example = "bigscience/blomm-560m")]
|
||||||
pub model_id: String,
|
pub model_id: String,
|
||||||
#[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
|
#[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
|
||||||
|
@ -31,6 +32,26 @@ pub struct Info {
|
||||||
pub model_device_type: String,
|
pub model_device_type: String,
|
||||||
#[schema(nullable = true, example = "text-generation")]
|
#[schema(nullable = true, example = "text-generation")]
|
||||||
pub model_pipeline_tag: Option<String>,
|
pub model_pipeline_tag: Option<String>,
|
||||||
|
/// Router Parameters
|
||||||
|
#[schema(example = "128")]
|
||||||
|
pub max_concurrent_requests: usize,
|
||||||
|
#[schema(example = "2")]
|
||||||
|
pub max_best_of: usize,
|
||||||
|
#[schema(example = "4")]
|
||||||
|
pub max_stop_sequences: usize,
|
||||||
|
#[schema(example = "1024")]
|
||||||
|
pub max_input_length: usize,
|
||||||
|
#[schema(example = "2048")]
|
||||||
|
pub max_total_tokens: usize,
|
||||||
|
#[schema(example = "1.2")]
|
||||||
|
pub waiting_served_ratio: f32,
|
||||||
|
#[schema(example = "32000")]
|
||||||
|
pub max_batch_total_tokens: u32,
|
||||||
|
#[schema(example = "20")]
|
||||||
|
pub max_waiting_tokens: usize,
|
||||||
|
#[schema(example = "2")]
|
||||||
|
pub validation_workers: usize,
|
||||||
|
/// Router Info
|
||||||
#[schema(example = "0.5.0")]
|
#[schema(example = "0.5.0")]
|
||||||
pub version: &'static str,
|
pub version: &'static str,
|
||||||
#[schema(nullable = true, example = "null")]
|
#[schema(nullable = true, example = "null")]
|
||||||
|
|
|
@ -78,22 +78,8 @@ async fn compat_generate(
|
||||||
responses((status = 200, description = "Served model info", body = Info))
|
responses((status = 200, description = "Served model info", body = Info))
|
||||||
)]
|
)]
|
||||||
#[instrument]
|
#[instrument]
|
||||||
async fn get_model_info(
|
async fn get_model_info(info: Extension<Info>) -> Json<Info> {
|
||||||
model_info: Extension<HubModelInfo>,
|
Json(info.0)
|
||||||
shard_info: Extension<ShardInfo>,
|
|
||||||
) -> Json<Info> {
|
|
||||||
let model_info = model_info.0;
|
|
||||||
let shard_info = shard_info.0;
|
|
||||||
let info = Info {
|
|
||||||
version: env!("CARGO_PKG_VERSION"),
|
|
||||||
sha: option_env!("VERGEN_GIT_SHA"),
|
|
||||||
model_id: model_info.model_id,
|
|
||||||
model_sha: model_info.sha,
|
|
||||||
model_dtype: shard_info.dtype,
|
|
||||||
model_device_type: shard_info.device_type,
|
|
||||||
model_pipeline_tag: model_info.pipeline_tag,
|
|
||||||
};
|
|
||||||
Json(info)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Health check method
|
/// Health check method
|
||||||
|
@ -632,6 +618,26 @@ pub async fn run(
|
||||||
.allow_headers([http::header::CONTENT_TYPE])
|
.allow_headers([http::header::CONTENT_TYPE])
|
||||||
.allow_origin(allow_origin);
|
.allow_origin(allow_origin);
|
||||||
|
|
||||||
|
// Endpoint info
|
||||||
|
let info = Info {
|
||||||
|
model_id: model_info.model_id,
|
||||||
|
model_sha: model_info.sha,
|
||||||
|
model_dtype: shard_info.dtype,
|
||||||
|
model_device_type: shard_info.device_type,
|
||||||
|
model_pipeline_tag: model_info.pipeline_tag,
|
||||||
|
max_concurrent_requests,
|
||||||
|
max_best_of,
|
||||||
|
max_stop_sequences,
|
||||||
|
max_input_length,
|
||||||
|
max_total_tokens,
|
||||||
|
waiting_served_ratio,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
max_waiting_tokens,
|
||||||
|
validation_workers,
|
||||||
|
version: env!("CARGO_PKG_VERSION"),
|
||||||
|
sha: option_env!("VERGEN_GIT_SHA"),
|
||||||
|
};
|
||||||
|
|
||||||
// Create router
|
// Create router
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
|
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
|
||||||
|
@ -650,8 +656,7 @@ pub async fn run(
|
||||||
.route("/ping", get(health))
|
.route("/ping", get(health))
|
||||||
// Prometheus metrics route
|
// Prometheus metrics route
|
||||||
.route("/metrics", get(metrics))
|
.route("/metrics", get(metrics))
|
||||||
.layer(Extension(model_info))
|
.layer(Extension(info))
|
||||||
.layer(Extension(shard_info))
|
|
||||||
.layer(Extension(compat_return_full_text))
|
.layer(Extension(compat_return_full_text))
|
||||||
.layer(Extension(infer))
|
.layer(Extension(infer))
|
||||||
.layer(Extension(prom_handle))
|
.layer(Extension(prom_handle))
|
||||||
|
|
Loading…
Reference in New Issue