feat(router): add endpoint info to /info route (#228)
This commit is contained in:
parent
ebc74d5666
commit
8b182eb986
|
@ -21,6 +21,7 @@ pub struct HubModelInfo {
|
|||
|
||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||
pub struct Info {
|
||||
/// Model info
|
||||
#[schema(example = "bigscience/blomm-560m")]
|
||||
pub model_id: String,
|
||||
#[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
|
||||
|
@ -31,6 +32,26 @@ pub struct Info {
|
|||
pub model_device_type: String,
|
||||
#[schema(nullable = true, example = "text-generation")]
|
||||
pub model_pipeline_tag: Option<String>,
|
||||
/// Router Parameters
|
||||
#[schema(example = "128")]
|
||||
pub max_concurrent_requests: usize,
|
||||
#[schema(example = "2")]
|
||||
pub max_best_of: usize,
|
||||
#[schema(example = "4")]
|
||||
pub max_stop_sequences: usize,
|
||||
#[schema(example = "1024")]
|
||||
pub max_input_length: usize,
|
||||
#[schema(example = "2048")]
|
||||
pub max_total_tokens: usize,
|
||||
#[schema(example = "1.2")]
|
||||
pub waiting_served_ratio: f32,
|
||||
#[schema(example = "32000")]
|
||||
pub max_batch_total_tokens: u32,
|
||||
#[schema(example = "20")]
|
||||
pub max_waiting_tokens: usize,
|
||||
#[schema(example = "2")]
|
||||
pub validation_workers: usize,
|
||||
/// Router Info
|
||||
#[schema(example = "0.5.0")]
|
||||
pub version: &'static str,
|
||||
#[schema(nullable = true, example = "null")]
|
||||
|
|
|
@ -78,22 +78,8 @@ async fn compat_generate(
|
|||
responses((status = 200, description = "Served model info", body = Info))
|
||||
)]
|
||||
#[instrument]
|
||||
async fn get_model_info(
|
||||
model_info: Extension<HubModelInfo>,
|
||||
shard_info: Extension<ShardInfo>,
|
||||
) -> Json<Info> {
|
||||
let model_info = model_info.0;
|
||||
let shard_info = shard_info.0;
|
||||
let info = Info {
|
||||
version: env!("CARGO_PKG_VERSION"),
|
||||
sha: option_env!("VERGEN_GIT_SHA"),
|
||||
model_id: model_info.model_id,
|
||||
model_sha: model_info.sha,
|
||||
model_dtype: shard_info.dtype,
|
||||
model_device_type: shard_info.device_type,
|
||||
model_pipeline_tag: model_info.pipeline_tag,
|
||||
};
|
||||
Json(info)
|
||||
async fn get_model_info(info: Extension<Info>) -> Json<Info> {
|
||||
Json(info.0)
|
||||
}
|
||||
|
||||
/// Health check method
|
||||
|
@ -632,6 +618,26 @@ pub async fn run(
|
|||
.allow_headers([http::header::CONTENT_TYPE])
|
||||
.allow_origin(allow_origin);
|
||||
|
||||
// Endpoint info
|
||||
let info = Info {
|
||||
model_id: model_info.model_id,
|
||||
model_sha: model_info.sha,
|
||||
model_dtype: shard_info.dtype,
|
||||
model_device_type: shard_info.device_type,
|
||||
model_pipeline_tag: model_info.pipeline_tag,
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
waiting_served_ratio,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
validation_workers,
|
||||
version: env!("CARGO_PKG_VERSION"),
|
||||
sha: option_env!("VERGEN_GIT_SHA"),
|
||||
};
|
||||
|
||||
// Create router
|
||||
let app = Router::new()
|
||||
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
|
||||
|
@ -650,8 +656,7 @@ pub async fn run(
|
|||
.route("/ping", get(health))
|
||||
// Prometheus metrics route
|
||||
.route("/metrics", get(metrics))
|
||||
.layer(Extension(model_info))
|
||||
.layer(Extension(shard_info))
|
||||
.layer(Extension(info))
|
||||
.layer(Extension(compat_return_full_text))
|
||||
.layer(Extension(infer))
|
||||
.layer(Extension(prom_handle))
|
||||
|
|
Loading…
Reference in New Issue