From 2475aede619c0c6d2ba8440303432d505c77f6d3 Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Tue, 18 Apr 2023 16:16:06 +0200
Subject: [PATCH] feat(router): add info route (#196)

close #125
---
 Cargo.lock           |  23 +++++
 docs/openapi.json    | 208 ++++++++++++++++++++++++++++++++++++++-----
 launcher/src/main.rs |   6 ++
 router/Cargo.toml    |   5 +-
 router/build.rs      |   7 ++
 router/src/lib.rs    |  23 +++++
 router/src/main.rs   |  66 +++++++++-----
 router/src/server.rs | 188 +++++++++++++++++++++++---------------
 8 files changed, 409 insertions(+), 117 deletions(-)
 create mode 100644 router/build.rs

diff --git a/Cargo.lock b/Cargo.lock
index bc2529d4..8ffaf631 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2430,6 +2430,7 @@ dependencies = [
  "tracing-subscriber",
  "utoipa",
  "utoipa-swagger-ui",
+ "vergen",
 ]

 [[package]]
@@ -2468,8 +2469,10 @@ version = "0.3.20"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "cd0cbfecb4d19b5ea75bb31ad904eb5b9fa13f21079c3b92017ebdf4999a5890"
 dependencies = [
+ "itoa",
  "serde",
  "time-core",
+ "time-macros",
 ]

 [[package]]
@@ -2478,6 +2481,15 @@ version = "0.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2e153e1f1acaef8acc537e68b44906d2db6436e2b35ac2c6b42640fff91f00fd"

+[[package]]
+name = "time-macros"
+version = "0.2.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fd80a657e71da814b8e5d60d3374fc6d35045062245d80224748ae522dd76f36"
+dependencies = [
+ "time-core",
+]
+
 [[package]]
 name = "tinyvec"
 version = "1.6.0"
@@ -2966,6 +2978,17 @@ version = "0.2.15"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"

+[[package]]
+name = "vergen"
+version = "8.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c1b86a8af1dedf089b1c78338678e4c7492b6045649042d94faf19690499d236"
+dependencies = [
+ "anyhow",
+ "rustversion",
+ "time",
+]
+
 [[package]]
 name = "version_check"
 version = "0.9.4"
diff --git a/docs/openapi.json b/docs/openapi.json
index 377c388a..1be99f07 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -4,8 +4,7 @@
     "title": "Text Generation Inference",
     "description": "Text Generation Webserver",
     "contact": {
-      "name": "Olivier Dehaene",
-      "email": "olivier@huggingface.co"
+      "name": "Olivier Dehaene"
     },
     "license": {
       "name": "Apache 2.0",
@@ -14,6 +13,83 @@
     "version": "0.5.0"
   },
   "paths": {
+    "/": {
+      "post": {
+        "tags": [
+          "Text Generation Inference"
+        ],
+        "summary": "Generate tokens if `stream == false` or a stream of tokens if `stream == true`",
+        "description": "Generate tokens if `stream == false` or a stream of tokens if `stream == true`",
+        "operationId": "compat_generate",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/CompatGenerateRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "description": "See /generate or /generate_stream"
+          },
+          "422": {
+            "description": "Input validation error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Input validation error"
+                }
+              }
+            }
+          },
+          "424": {
+            "description": "Generation Error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Request failed during generation"
+                }
+              }
+            }
+          },
+          "429": {
+            "description": "Model is overloaded",
+            "content": {
+              "application/json": {
"schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Model is overloaded" + } + } + } + }, + "500": { + "description": "Incomplete generation", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Incomplete generation" + } + } + } + } + } + } + }, "/generate": { "post": { "tags": [ @@ -95,8 +171,7 @@ } } } - }, - "deprecated": false + } } }, "/generate_stream": { @@ -180,8 +255,29 @@ } } } - }, - "deprecated": false + } + } + }, + "/info": { + "get": { + "tags": [ + "Text Generation Inference" + ], + "summary": "Text Generation Inference endpoint info", + "description": "Text Generation Inference endpoint info", + "operationId": "get_model_info", + "responses": { + "200": { + "description": "Served model info", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Info" + } + } + } + } + } } }, "/metrics": { @@ -203,8 +299,7 @@ } } } - }, - "deprecated": false + } } } }, @@ -230,7 +325,8 @@ "generated_tokens": { "type": "integer", "format": "int32", - "example": 1 + "example": 1, + "minimum": 0.0 }, "prefill": { "type": "array", @@ -242,7 +338,8 @@ "type": "integer", "format": "int64", "example": 42, - "nullable": true + "nullable": true, + "minimum": 0.0 }, "tokens": { "type": "array", @@ -252,6 +349,24 @@ } } }, + "CompatGenerateRequest": { + "type": "object", + "required": [ + "inputs" + ], + "properties": { + "inputs": { + "type": "string", + "example": "My name is Olivier and I" + }, + "parameters": { + "$ref": "#/components/schemas/GenerateParameters" + }, + "stream": { + "type": "boolean" + } + } + }, "Details": { "type": "object", "required": [ @@ -265,7 +380,8 @@ "type": "array", "items": { "$ref": "#/components/schemas/BestOfSequence" - } + }, + "nullable": true }, "finish_reason": { "$ref": "#/components/schemas/FinishReason" @@ -273,7 +389,8 @@ "generated_tokens": { "type": "integer", "format": "int32", - "example": 1 + "example": 1, + "minimum": 0.0 }, "prefill": { "type": "array", @@ -285,7 +402,8 @@ "type": "integer", "format": "int64", "example": 42, - "nullable": true + "nullable": true, + "minimum": 0.0 }, "tokens": { "type": "array", @@ -326,6 +444,7 @@ "default": "null", "example": 1, "nullable": true, + "minimum": 0.0, "exclusiveMinimum": 0.0 }, "details": { @@ -341,6 +460,7 @@ "type": "integer", "format": "int32", "default": "20", + "minimum": 0.0, "exclusiveMaximum": 512.0, "exclusiveMinimum": 0.0 }, @@ -364,6 +484,7 @@ "default": "null", "example": "null", "nullable": true, + "minimum": 0.0, "exclusiveMinimum": 0.0 }, "stop": { @@ -405,7 +526,8 @@ "type": "integer", "default": "null", "example": "null", - "nullable": true + "nullable": true, + "minimum": 0.0 }, "typical_p": { "type": "number", @@ -445,7 +567,12 @@ ], "properties": { "details": { - "$ref": "#/components/schemas/Details" + "allOf": [ + { + "$ref": "#/components/schemas/Details" + } + ], + "nullable": true }, "generated_text": { "type": "string", @@ -453,6 +580,38 @@ } } }, + "Info": { + "type": "object", + "required": [ + "model_id", + "version" + ], + "properties": { + "model_id": { + "type": "string", + "example": "bigscience/blomm-560m" + }, + "model_pipeline_tag": { + "type": "string", + "example": "text-generation", + "nullable": true + }, + "model_sha": { + "type": "string", + "example": "e985a63cdc139290c5f700ff1929f0b5942cced2", + "nullable": true + }, + "sha": { + "type": "string", + "example": "null", + "nullable": true + }, + "version": { + 
"type": "string", + "example": "0.5.0" + } + } + }, "PrefillToken": { "type": "object", "required": [ @@ -464,7 +623,8 @@ "id": { "type": "integer", "format": "int32", - "example": 0 + "example": 0, + "minimum": 0.0 }, "logprob": { "type": "number", @@ -491,13 +651,15 @@ "generated_tokens": { "type": "integer", "format": "int32", - "example": 1 + "example": 1, + "minimum": 0.0 }, "seed": { "type": "integer", "format": "int64", "example": 42, - "nullable": true + "nullable": true, + "minimum": 0.0 } } }, @@ -508,7 +670,12 @@ ], "properties": { "details": { - "$ref": "#/components/schemas/StreamDetails" + "allOf": [ + { + "$ref": "#/components/schemas/StreamDetails" + } + ], + "nullable": true }, "generated_text": { "type": "string", @@ -533,7 +700,8 @@ "id": { "type": "integer", "format": "int32", - "example": 0 + "example": 0, + "minimum": 0.0 }, "logprob": { "type": "number", diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 0cf43f16..8dc6a798 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -392,6 +392,12 @@ fn main() -> ExitCode { model_id, ]; + // Model optional revision + if let Some(ref revision) = revision { + argv.push("--revision".to_string()); + argv.push(revision.to_string()) + } + if json_output { argv.push("--json-output".to_string()); } diff --git a/router/Cargo.toml b/router/Cargo.toml index e2a6f5e9..aa8e9df2 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -4,6 +4,7 @@ version = "0.5.0" edition = "2021" authors = ["Olivier Dehaene"] description = "Text Generation Webserver" +build = "build.rs" [lib] path = "src/lib.rs" @@ -26,7 +27,7 @@ nohash-hasher = "0.2.0" opentelemetry = { version = "0.18.0", features = ["rt-tokio"] } opentelemetry-otlp = "0.11.0" rand = "0.8.5" -reqwest = { version = "0.11.14", features = [] } +reqwest = { version = "0.11.14", features = [] } serde = "1.0.152" serde_json = "1.0.93" thiserror = "1.0.38" @@ -39,3 +40,5 @@ tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] } utoipa = { version = "3.0.1", features = ["axum_extras"] } utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] } +[build-dependencies] +vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] } diff --git a/router/build.rs b/router/build.rs new file mode 100644 index 00000000..c34f9fa8 --- /dev/null +++ b/router/build.rs @@ -0,0 +1,7 @@ +use std::error::Error; +use vergen::EmitBuilder; + +fn main() -> Result<(), Box> { + EmitBuilder::builder().git_sha(false).emit()?; + Ok(()) +} diff --git a/router/src/lib.rs b/router/src/lib.rs index 91b4417c..7dc115fd 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -10,6 +10,29 @@ use serde::{Deserialize, Serialize}; use utoipa::ToSchema; use validation::Validation; +/// Hub type +#[derive(Clone, Debug, Deserialize)] +pub struct ModelInfo { + #[serde(rename(deserialize = "id"))] + pub model_id: String, + pub sha: Option, + pub pipeline_tag: Option, +} + +#[derive(Clone, Debug, Serialize, ToSchema)] +pub struct Info { + #[schema(example = "bigscience/blomm-560m")] + pub model_id: String, + #[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")] + pub model_sha: Option, + #[schema(nullable = true, example = "text-generation")] + pub model_pipeline_tag: Option, + #[schema(example = "0.5.0")] + pub version: &'static str, + #[schema(nullable = true, example = "null")] + pub sha: Option<&'static str>, +} + #[derive(Clone, Debug, Deserialize, ToSchema)] pub(crate) struct GenerateParameters { #[serde(default)] diff --git 
diff --git a/router/src/main.rs b/router/src/main.rs
index 3ff72cde..5fda57be 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -10,8 +10,8 @@ use opentelemetry_otlp::WithExportConfig;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::Path;
 use text_generation_client::ShardedClient;
-use text_generation_router::server;
-use tokenizers::Tokenizer;
+use text_generation_router::{server, ModelInfo};
+use tokenizers::{FromPretrainedParameters, Tokenizer};
 use tower_http::cors::AllowOrigin;
 use tracing_subscriber::layer::SubscriberExt;
 use tracing_subscriber::util::SubscriberInitExt;
@@ -41,6 +41,8 @@ struct Args {
     master_shard_uds_path: String,
     #[clap(default_value = "bigscience/bloom", long, env)]
     tokenizer_name: String,
+    #[clap(default_value = "main", long, env)]
+    revision: String,
     #[clap(default_value = "2", long, env)]
     validation_workers: usize,
     #[clap(long, env)]
@@ -66,6 +68,7 @@ fn main() -> Result<(), std::io::Error> {
         port,
         master_shard_uds_path,
         tokenizer_name,
+        revision,
         validation_workers,
         json_output,
         otlp_endpoint,
@@ -90,16 +93,19 @@ fn main() -> Result<(), std::io::Error> {
     // Tokenizer instance
     // This will only be used to validate payloads
     let local_path = Path::new(&tokenizer_name);
-    let tokenizer =
-        if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
-        {
-            // Load local tokenizer
-            Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
-        } else {
-            // Download and instantiate tokenizer
-            // We need to download it outside of the Tokio runtime
-            Tokenizer::from_pretrained(tokenizer_name.clone(), None).ok()
+    let local_model = local_path.exists() && local_path.is_dir();
+    let tokenizer = if local_model {
+        // Load local tokenizer
+        Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
+    } else {
+        // Download and instantiate tokenizer
+        // We need to download it outside of the Tokio runtime
+        let params = FromPretrainedParameters {
+            revision: revision.clone(),
+            ..Default::default()
         };
+        Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
+    };

     // Launch Tokio runtime
     tokio::runtime::Builder::new_multi_thread()
@@ -116,25 +122,23 @@ fn main() -> Result<(), std::io::Error> {
         tracing::warn!("Rust input length validation and truncation is disabled");
     }

-    // Get pipeline tag
-    let model_info = reqwest::get(format!(
-        "https://huggingface.co/api/models/{tokenizer_name}"
-    ))
-    .await
-    .expect("Could not connect to hf.co")
-    .text()
-    .await
-    .expect("error when retrieving model info from hf.co");
-    let model_info: serde_json::Value =
-        serde_json::from_str(&model_info).expect("unable to parse model info");
+    // Get model info
+    let model_info = match local_model {
+        true => ModelInfo {
+            model_id: tokenizer_name.clone(),
+            sha: None,
+            pipeline_tag: None,
+        },
+        false => get_model_info(&tokenizer_name, &revision).await,
+    };

     // if pipeline-tag == text-generation we default to return_full_text = true
-    let compat_return_full_text = match model_info.get("pipeline_tag") {
+    let compat_return_full_text = match &model_info.pipeline_tag {
         None => {
             tracing::warn!("no pipeline tag found for model {tokenizer_name}");
             false
         }
-        Some(pipeline_tag) => pipeline_tag.as_str() == Some("text-generation"),
+        Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
     };

     // Instantiate sharded client from the master unix socket
@@ -153,6 +157,7 @@ fn main() -> Result<(), std::io::Error> {

     // Run server
     server::run(
+        model_info,
         compat_return_full_text,
         max_concurrent_requests,
         max_best_of,
@@ -226,3 +231,16 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
         .with(layers)
         .init();
 }
+
+/// get model info from the Hugging Face Hub
+pub async fn get_model_info(model_id: &str, revision: &str) -> ModelInfo {
+    let model_info = reqwest::get(format!(
+        "https://huggingface.co/api/models/{model_id}/revision/{revision}"
+    ))
+    .await
+    .expect("Could not connect to hf.co")
+    .text()
+    .await
+    .expect("error when retrieving model info from hf.co");
+    serde_json::from_str(&model_info).expect("unable to parse model info")
+}
diff --git a/router/src/server.rs b/router/src/server.rs
index 851837d5..ce301399 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -3,8 +3,8 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
-    GenerateParameters, GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails,
-    StreamResponse, Token, Validation,
+    GenerateParameters, GenerateRequest, GenerateResponse, Infer, Info, ModelInfo, PrefillToken,
+    StreamDetails, StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -27,7 +27,24 @@ use tracing::{info_span, instrument, Instrument};
 use utoipa::OpenApi;
 use utoipa_swagger_ui::SwaggerUi;

-/// Compatibility route with api-inference and AzureML
+/// Generate tokens if `stream == false` or a stream of tokens if `stream == true`
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/",
+    request_body = CompatGenerateRequest,
+    responses(
+        (status = 200, description = "See /generate or /generate_stream"),
+        (status = 424, description = "Generation Error", body = ErrorResponse,
+        example = json ! ({"error": "Request failed during generation"})),
+        (status = 429, description = "Model is overloaded", body = ErrorResponse,
+        example = json ! ({"error": "Model is overloaded"})),
+        (status = 422, description = "Input validation error", body = ErrorResponse,
+        example = json ! ({"error": "Input validation error"})),
+        (status = 500, description = "Incomplete generation", body = ErrorResponse,
+        example = json ! ({"error": "Incomplete generation"})),
({"error": "Incomplete generation"})), + ) +)] #[instrument(skip(infer))] async fn compat_generate( default_return_full_text: Extension, @@ -53,6 +70,26 @@ async fn compat_generate( } } +/// Text Generation Inference endpoint info +#[utoipa::path( + get, + tag = "Text Generation Inference", + path = "/info", + responses((status = 200, description = "Served model info", body = Info)) +)] +#[instrument] +async fn get_model_info(model_info: Extension) -> Json { + let model_info = model_info.0; + let info = Info { + version: env!("CARGO_PKG_VERSION"), + sha: option_env!("VERGEN_GIT_SHA"), + model_id: model_info.model_id, + model_sha: model_info.sha, + model_pipeline_tag: model_info.pipeline_tag, + }; + Json(info) +} + /// Health check method #[instrument(skip(infer))] async fn health(infer: Extension) -> Result<(), (StatusCode, Json)> { @@ -87,21 +124,21 @@ async fn health(infer: Extension) -> Result<(), (StatusCode, Json) -> String { prom_handle.render() @@ -459,6 +496,7 @@ async fn metrics(prom_handle: Extension) -> String { /// Serving method #[allow(clippy::too_many_arguments)] pub async fn run( + model_info: ModelInfo, compat_return_full_text: bool, max_concurrent_requests: usize, max_best_of: usize, @@ -476,36 +514,40 @@ pub async fn run( // OpenAPI documentation #[derive(OpenApi)] #[openapi( - paths( - generate, - generate_stream, - metrics, - ), - components( - schemas( - GenerateRequest, - GenerateParameters, - PrefillToken, - Token, - GenerateResponse, - BestOfSequence, - Details, - FinishReason, - StreamResponse, - StreamDetails, - ErrorResponse, - ) - ), - tags( - (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API") - ), - info( - title = "Text Generation Inference", - license( - name = "Apache 2.0", - url = "https://www.apache.org/licenses/LICENSE-2.0" - ) - ) + paths( + get_model_info, + compat_generate, + generate, + generate_stream, + metrics, + ), + components( + schemas( + Info, + CompatGenerateRequest, + GenerateRequest, + GenerateParameters, + PrefillToken, + Token, + GenerateResponse, + BestOfSequence, + Details, + FinishReason, + StreamResponse, + StreamDetails, + ErrorResponse, + ) + ), + tags( + (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API") + ), + info( + title = "Text Generation Inference", + license( + name = "Apache 2.0", + url = "https://www.apache.org/licenses/LICENSE-2.0" + ) + ) )] struct ApiDoc; @@ -584,6 +626,7 @@ pub async fn run( .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi())) // Base routes .route("/", post(compat_generate)) + .route("/info", get(get_model_info)) .route("/generate", post(generate)) .route("/generate_stream", post(generate_stream)) // AWS Sagemaker route @@ -596,6 +639,7 @@ pub async fn run( .route("/ping", get(health)) // Prometheus metrics route .route("/metrics", get(metrics)) + .layer(Extension(model_info)) .layer(Extension(compat_return_full_text)) .layer(Extension(infer)) .layer(Extension(prom_handle))