feat(router): add info route (#196)

close #125
This commit is contained in:
OlivierDehaene 2023-04-18 16:16:06 +02:00 committed by GitHub
parent b927244eb5
commit 2475aede61
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 409 additions and 117 deletions

23
Cargo.lock generated
View File

@ -2430,6 +2430,7 @@ dependencies = [
"tracing-subscriber", "tracing-subscriber",
"utoipa", "utoipa",
"utoipa-swagger-ui", "utoipa-swagger-ui",
"vergen",
] ]
[[package]] [[package]]
@ -2468,8 +2469,10 @@ version = "0.3.20"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd0cbfecb4d19b5ea75bb31ad904eb5b9fa13f21079c3b92017ebdf4999a5890" checksum = "cd0cbfecb4d19b5ea75bb31ad904eb5b9fa13f21079c3b92017ebdf4999a5890"
dependencies = [ dependencies = [
"itoa",
"serde", "serde",
"time-core", "time-core",
"time-macros",
] ]
[[package]] [[package]]
@ -2478,6 +2481,15 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e153e1f1acaef8acc537e68b44906d2db6436e2b35ac2c6b42640fff91f00fd" checksum = "2e153e1f1acaef8acc537e68b44906d2db6436e2b35ac2c6b42640fff91f00fd"
[[package]]
name = "time-macros"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd80a657e71da814b8e5d60d3374fc6d35045062245d80224748ae522dd76f36"
dependencies = [
"time-core",
]
[[package]] [[package]]
name = "tinyvec" name = "tinyvec"
version = "1.6.0" version = "1.6.0"
@ -2966,6 +2978,17 @@ version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]]
name = "vergen"
version = "8.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1b86a8af1dedf089b1c78338678e4c7492b6045649042d94faf19690499d236"
dependencies = [
"anyhow",
"rustversion",
"time",
]
[[package]] [[package]]
name = "version_check" name = "version_check"
version = "0.9.4" version = "0.9.4"

View File

@ -4,8 +4,7 @@
"title": "Text Generation Inference", "title": "Text Generation Inference",
"description": "Text Generation Webserver", "description": "Text Generation Webserver",
"contact": { "contact": {
"name": "Olivier Dehaene", "name": "Olivier Dehaene"
"email": "olivier@huggingface.co"
}, },
"license": { "license": {
"name": "Apache 2.0", "name": "Apache 2.0",
@ -14,6 +13,83 @@
"version": "0.5.0" "version": "0.5.0"
}, },
"paths": { "paths": {
"/": {
"post": {
"tags": [
"Text Generation Inference"
],
"summary": "Generate tokens if `stream == false` or a stream of tokens if `stream == true`",
"description": "Generate tokens if `stream == false` or a stream of tokens if `stream == true`",
"operationId": "compat_generate",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/CompatGenerateRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "See /generate or /generate_stream"
},
"422": {
"description": "Input validation error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Input validation error"
}
}
}
},
"424": {
"description": "Generation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Request failed during generation"
}
}
}
},
"429": {
"description": "Model is overloaded",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Model is overloaded"
}
}
}
},
"500": {
"description": "Incomplete generation",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Incomplete generation"
}
}
}
}
}
}
},
"/generate": { "/generate": {
"post": { "post": {
"tags": [ "tags": [
@ -95,8 +171,7 @@
} }
} }
} }
}, }
"deprecated": false
} }
}, },
"/generate_stream": { "/generate_stream": {
@ -180,8 +255,29 @@
} }
} }
} }
}, }
"deprecated": false }
},
"/info": {
"get": {
"tags": [
"Text Generation Inference"
],
"summary": "Text Generation Inference endpoint info",
"description": "Text Generation Inference endpoint info",
"operationId": "get_model_info",
"responses": {
"200": {
"description": "Served model info",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Info"
}
}
}
}
}
} }
}, },
"/metrics": { "/metrics": {
@ -203,8 +299,7 @@
} }
} }
} }
}, }
"deprecated": false
} }
} }
}, },
@ -230,7 +325,8 @@
"generated_tokens": { "generated_tokens": {
"type": "integer", "type": "integer",
"format": "int32", "format": "int32",
"example": 1 "example": 1,
"minimum": 0.0
}, },
"prefill": { "prefill": {
"type": "array", "type": "array",
@ -242,7 +338,8 @@
"type": "integer", "type": "integer",
"format": "int64", "format": "int64",
"example": 42, "example": 42,
"nullable": true "nullable": true,
"minimum": 0.0
}, },
"tokens": { "tokens": {
"type": "array", "type": "array",
@ -252,6 +349,24 @@
} }
} }
}, },
"CompatGenerateRequest": {
"type": "object",
"required": [
"inputs"
],
"properties": {
"inputs": {
"type": "string",
"example": "My name is Olivier and I"
},
"parameters": {
"$ref": "#/components/schemas/GenerateParameters"
},
"stream": {
"type": "boolean"
}
}
},
"Details": { "Details": {
"type": "object", "type": "object",
"required": [ "required": [
@ -265,7 +380,8 @@
"type": "array", "type": "array",
"items": { "items": {
"$ref": "#/components/schemas/BestOfSequence" "$ref": "#/components/schemas/BestOfSequence"
} },
"nullable": true
}, },
"finish_reason": { "finish_reason": {
"$ref": "#/components/schemas/FinishReason" "$ref": "#/components/schemas/FinishReason"
@ -273,7 +389,8 @@
"generated_tokens": { "generated_tokens": {
"type": "integer", "type": "integer",
"format": "int32", "format": "int32",
"example": 1 "example": 1,
"minimum": 0.0
}, },
"prefill": { "prefill": {
"type": "array", "type": "array",
@ -285,7 +402,8 @@
"type": "integer", "type": "integer",
"format": "int64", "format": "int64",
"example": 42, "example": 42,
"nullable": true "nullable": true,
"minimum": 0.0
}, },
"tokens": { "tokens": {
"type": "array", "type": "array",
@ -326,6 +444,7 @@
"default": "null", "default": "null",
"example": 1, "example": 1,
"nullable": true, "nullable": true,
"minimum": 0.0,
"exclusiveMinimum": 0.0 "exclusiveMinimum": 0.0
}, },
"details": { "details": {
@ -341,6 +460,7 @@
"type": "integer", "type": "integer",
"format": "int32", "format": "int32",
"default": "20", "default": "20",
"minimum": 0.0,
"exclusiveMaximum": 512.0, "exclusiveMaximum": 512.0,
"exclusiveMinimum": 0.0 "exclusiveMinimum": 0.0
}, },
@ -364,6 +484,7 @@
"default": "null", "default": "null",
"example": "null", "example": "null",
"nullable": true, "nullable": true,
"minimum": 0.0,
"exclusiveMinimum": 0.0 "exclusiveMinimum": 0.0
}, },
"stop": { "stop": {
@ -405,7 +526,8 @@
"type": "integer", "type": "integer",
"default": "null", "default": "null",
"example": "null", "example": "null",
"nullable": true "nullable": true,
"minimum": 0.0
}, },
"typical_p": { "typical_p": {
"type": "number", "type": "number",
@ -445,7 +567,12 @@
], ],
"properties": { "properties": {
"details": { "details": {
"$ref": "#/components/schemas/Details" "allOf": [
{
"$ref": "#/components/schemas/Details"
}
],
"nullable": true
}, },
"generated_text": { "generated_text": {
"type": "string", "type": "string",
@ -453,6 +580,38 @@
} }
} }
}, },
"Info": {
"type": "object",
"required": [
"model_id",
"version"
],
"properties": {
"model_id": {
"type": "string",
"example": "bigscience/bloom-560m"
},
"model_pipeline_tag": {
"type": "string",
"example": "text-generation",
"nullable": true
},
"model_sha": {
"type": "string",
"example": "e985a63cdc139290c5f700ff1929f0b5942cced2",
"nullable": true
},
"sha": {
"type": "string",
"example": "null",
"nullable": true
},
"version": {
"type": "string",
"example": "0.5.0"
}
}
},
"PrefillToken": { "PrefillToken": {
"type": "object", "type": "object",
"required": [ "required": [
@ -464,7 +623,8 @@
"id": { "id": {
"type": "integer", "type": "integer",
"format": "int32", "format": "int32",
"example": 0 "example": 0,
"minimum": 0.0
}, },
"logprob": { "logprob": {
"type": "number", "type": "number",
@ -491,13 +651,15 @@
"generated_tokens": { "generated_tokens": {
"type": "integer", "type": "integer",
"format": "int32", "format": "int32",
"example": 1 "example": 1,
"minimum": 0.0
}, },
"seed": { "seed": {
"type": "integer", "type": "integer",
"format": "int64", "format": "int64",
"example": 42, "example": 42,
"nullable": true "nullable": true,
"minimum": 0.0
} }
} }
}, },
@ -508,7 +670,12 @@
], ],
"properties": { "properties": {
"details": { "details": {
"$ref": "#/components/schemas/StreamDetails" "allOf": [
{
"$ref": "#/components/schemas/StreamDetails"
}
],
"nullable": true
}, },
"generated_text": { "generated_text": {
"type": "string", "type": "string",
@ -533,7 +700,8 @@
"id": { "id": {
"type": "integer", "type": "integer",
"format": "int32", "format": "int32",
"example": 0 "example": 0,
"minimum": 0.0
}, },
"logprob": { "logprob": {
"type": "number", "type": "number",

View File

@ -392,6 +392,12 @@ fn main() -> ExitCode {
model_id, model_id,
]; ];
// Model optional revision
if let Some(ref revision) = revision {
argv.push("--revision".to_string());
argv.push(revision.to_string())
}
if json_output { if json_output {
argv.push("--json-output".to_string()); argv.push("--json-output".to_string());
} }

View File

@ -4,6 +4,7 @@ version = "0.5.0"
edition = "2021" edition = "2021"
authors = ["Olivier Dehaene"] authors = ["Olivier Dehaene"]
description = "Text Generation Webserver" description = "Text Generation Webserver"
build = "build.rs"
[lib] [lib]
path = "src/lib.rs" path = "src/lib.rs"
@ -26,7 +27,7 @@ nohash-hasher = "0.2.0"
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] } opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.11.0" opentelemetry-otlp = "0.11.0"
rand = "0.8.5" rand = "0.8.5"
reqwest = { version = "0.11.14", features = [] } reqwest = { version = "0.11.14", features = [] }
serde = "1.0.152" serde = "1.0.152"
serde_json = "1.0.93" serde_json = "1.0.93"
thiserror = "1.0.38" thiserror = "1.0.38"
@ -39,3 +40,5 @@ tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
utoipa = { version = "3.0.1", features = ["axum_extras"] } utoipa = { version = "3.0.1", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] } utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
[build-dependencies]
vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }

7
router/build.rs Normal file
View File

@ -0,0 +1,7 @@
use std::error::Error;
use vergen::EmitBuilder;
/// Build script for the router crate: exposes build-time metadata.
///
/// Emits the current git commit sha as the `VERGEN_GIT_SHA` environment
/// variable at compile time, which the router reads via `option_env!` to
/// report the build revision on the `/info` route.
fn main() -> Result<(), Box<dyn Error>> {
    // `git_sha(false)` requests the full (non-shortened) commit sha — the
    // bool is vergen's `short` flag. Emission fails gracefully into an Err
    // propagated to cargo if git metadata is unavailable.
    EmitBuilder::builder().git_sha(false).emit()?;
    Ok(())
}

View File

@ -10,6 +10,29 @@ use serde::{Deserialize, Serialize};
use utoipa::ToSchema; use utoipa::ToSchema;
use validation::Validation; use validation::Validation;
/// Hub type
/// Subset of the model card returned by the Hugging Face Hub
/// `api/models/{id}` endpoint; only the fields the router needs.
#[derive(Clone, Debug, Deserialize)]
pub struct ModelInfo {
    // The Hub payload names this field `id`; rename it on deserialization.
    #[serde(rename(deserialize = "id"))]
    pub model_id: String,
    // Commit sha of the model repository on the Hub, when known.
    pub sha: Option<String>,
    // e.g. "text-generation"; used to pick the `return_full_text` default.
    pub pipeline_tag: Option<String>,
}
/// Response payload of the `/info` route: served-model metadata plus
/// build-time information about the router binary itself.
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct Info {
    /// Model id as served (Hub repository name or local path).
    // Fixed example typo: `blomm-560m` -> `bloom-560m`.
    #[schema(example = "bigscience/bloom-560m")]
    pub model_id: String,
    /// Commit sha of the model repository on the Hub, if known.
    #[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
    pub model_sha: Option<String>,
    /// Hub pipeline tag (e.g. "text-generation"), if known.
    #[schema(nullable = true, example = "text-generation")]
    pub model_pipeline_tag: Option<String>,
    /// Router crate version, baked in at compile time via `CARGO_PKG_VERSION`.
    #[schema(example = "0.5.0")]
    pub version: &'static str,
    /// Git sha of the router build (from vergen), when available.
    #[schema(nullable = true, example = "null")]
    pub sha: Option<&'static str>,
}
#[derive(Clone, Debug, Deserialize, ToSchema)] #[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct GenerateParameters { pub(crate) struct GenerateParameters {
#[serde(default)] #[serde(default)]

View File

@ -10,8 +10,8 @@ use opentelemetry_otlp::WithExportConfig;
use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path; use std::path::Path;
use text_generation_client::ShardedClient; use text_generation_client::ShardedClient;
use text_generation_router::server; use text_generation_router::{server, ModelInfo};
use tokenizers::Tokenizer; use tokenizers::{FromPretrainedParameters, Tokenizer};
use tower_http::cors::AllowOrigin; use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::util::SubscriberInitExt;
@ -41,6 +41,8 @@ struct Args {
master_shard_uds_path: String, master_shard_uds_path: String,
#[clap(default_value = "bigscience/bloom", long, env)] #[clap(default_value = "bigscience/bloom", long, env)]
tokenizer_name: String, tokenizer_name: String,
#[clap(default_value = "main", long, env)]
revision: String,
#[clap(default_value = "2", long, env)] #[clap(default_value = "2", long, env)]
validation_workers: usize, validation_workers: usize,
#[clap(long, env)] #[clap(long, env)]
@ -66,6 +68,7 @@ fn main() -> Result<(), std::io::Error> {
port, port,
master_shard_uds_path, master_shard_uds_path,
tokenizer_name, tokenizer_name,
revision,
validation_workers, validation_workers,
json_output, json_output,
otlp_endpoint, otlp_endpoint,
@ -90,16 +93,19 @@ fn main() -> Result<(), std::io::Error> {
// Tokenizer instance // Tokenizer instance
// This will only be used to validate payloads // This will only be used to validate payloads
let local_path = Path::new(&tokenizer_name); let local_path = Path::new(&tokenizer_name);
let tokenizer = let local_model = local_path.exists() && local_path.is_dir();
if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists() let tokenizer = if local_model {
{ // Load local tokenizer
// Load local tokenizer Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
Tokenizer::from_file(local_path.join("tokenizer.json")).ok() } else {
} else { // Download and instantiate tokenizer
// Download and instantiate tokenizer // We need to download it outside of the Tokio runtime
// We need to download it outside of the Tokio runtime let params = FromPretrainedParameters {
Tokenizer::from_pretrained(tokenizer_name.clone(), None).ok() revision: revision.clone(),
..Default::default()
}; };
Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
};
// Launch Tokio runtime // Launch Tokio runtime
tokio::runtime::Builder::new_multi_thread() tokio::runtime::Builder::new_multi_thread()
@ -116,25 +122,23 @@ fn main() -> Result<(), std::io::Error> {
tracing::warn!("Rust input length validation and truncation is disabled"); tracing::warn!("Rust input length validation and truncation is disabled");
} }
// Get pipeline tag // Get Model info
let model_info = reqwest::get(format!( let model_info = match local_model {
"https://huggingface.co/api/models/{tokenizer_name}" true => ModelInfo {
)) model_id: tokenizer_name.clone(),
.await sha: None,
.expect("Could not connect to hf.co") pipeline_tag: None,
.text() },
.await false => get_model_info(&tokenizer_name, &revision).await,
.expect("error when retrieving model info from hf.co"); };
let model_info: serde_json::Value =
serde_json::from_str(&model_info).expect("unable to parse model info");
// if pipeline-tag == text-generation we default to return_full_text = true // if pipeline-tag == text-generation we default to return_full_text = true
let compat_return_full_text = match model_info.get("pipeline_tag") { let compat_return_full_text = match &model_info.pipeline_tag {
None => { None => {
tracing::warn!("no pipeline tag found for model {tokenizer_name}"); tracing::warn!("no pipeline tag found for model {tokenizer_name}");
false false
} }
Some(pipeline_tag) => pipeline_tag.as_str() == Some("text-generation"), Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
}; };
// Instantiate sharded client from the master unix socket // Instantiate sharded client from the master unix socket
@ -153,6 +157,7 @@ fn main() -> Result<(), std::io::Error> {
// Run server // Run server
server::run( server::run(
model_info,
compat_return_full_text, compat_return_full_text,
max_concurrent_requests, max_concurrent_requests,
max_best_of, max_best_of,
@ -226,3 +231,16 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
.with(layers) .with(layers)
.init(); .init();
} }
/// Get model info from the Hugging Face Hub for `model_id` at `revision`.
///
/// # Panics
/// Panics if hf.co is unreachable, returns a non-success status (e.g. an
/// unknown model or revision), or the response cannot be parsed. This is
/// called once at router startup, so failing fast is intentional.
pub async fn get_model_info(model_id: &str, revision: &str) -> ModelInfo {
    let response = reqwest::get(format!(
        "https://huggingface.co/api/models/{model_id}/revision/{revision}"
    ))
    .await
    .expect("Could not connect to hf.co")
    // Fail with a clear error on 4xx/5xx instead of a confusing JSON parse
    // panic on the Hub's error body (e.g. 404 for an unknown revision).
    .error_for_status()
    .expect("model info request to hf.co returned an error status");
    let model_info = response
        .text()
        .await
        .expect("error when retrieving model info from hf.co");
    serde_json::from_str(&model_info).expect("unable to parse model info")
}

View File

@ -3,8 +3,8 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError; use crate::validation::ValidationError;
use crate::{ use crate::{
BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason, BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
GenerateParameters, GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails, GenerateParameters, GenerateRequest, GenerateResponse, Infer, Info, ModelInfo, PrefillToken,
StreamResponse, Token, Validation, StreamDetails, StreamResponse, Token, Validation,
}; };
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode}; use axum::http::{HeaderMap, Method, StatusCode};
@ -27,7 +27,24 @@ use tracing::{info_span, instrument, Instrument};
use utoipa::OpenApi; use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi; use utoipa_swagger_ui::SwaggerUi;
/// Compatibility route with api-inference and AzureML /// Generate tokens if `stream == false` or a stream of token if `stream == true`
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/",
request_body = CompatGenerateRequest,
responses(
(status = 200, description = "See /generate or /generate_stream"),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"})),
)
)]
#[instrument(skip(infer))] #[instrument(skip(infer))]
async fn compat_generate( async fn compat_generate(
default_return_full_text: Extension<bool>, default_return_full_text: Extension<bool>,
@ -53,6 +70,26 @@ async fn compat_generate(
} }
} }
/// Text Generation Inference endpoint info
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/info",
responses((status = 200, description = "Served model info", body = Info))
)]
#[instrument]
async fn get_model_info(model_info: Extension<ModelInfo>) -> Json<Info> {
let model_info = model_info.0;
let info = Info {
version: env!("CARGO_PKG_VERSION"),
sha: option_env!("VERGEN_GIT_SHA"),
model_id: model_info.model_id,
model_sha: model_info.sha,
model_pipeline_tag: model_info.pipeline_tag,
};
Json(info)
}
/// Health check method /// Health check method
#[instrument(skip(infer))] #[instrument(skip(infer))]
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> { async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
@ -87,21 +124,21 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
/// Generate tokens /// Generate tokens
#[utoipa::path( #[utoipa::path(
post, post,
tag = "Text Generation Inference", tag = "Text Generation Inference",
path = "/generate", path = "/generate",
request_body = GenerateRequest, request_body = GenerateRequest,
responses( responses(
(status = 200, description = "Generated Text", body = GenerateResponse), (status = 200, description = "Generated Text", body = GenerateResponse),
(status = 424, description = "Generation Error", body = ErrorResponse, (status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})), example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse, (status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"})), example = json ! ({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse, (status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"})), example = json ! ({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse, (status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"})), example = json ! ({"error": "Incomplete generation"})),
) )
)] )]
#[instrument( #[instrument(
skip(infer), skip(infer),
@ -264,26 +301,26 @@ async fn generate(
/// Generate a stream of token using Server-Sent Events /// Generate a stream of token using Server-Sent Events
#[utoipa::path( #[utoipa::path(
post, post,
tag = "Text Generation Inference", tag = "Text Generation Inference",
path = "/generate_stream", path = "/generate_stream",
request_body = GenerateRequest, request_body = GenerateRequest,
responses( responses(
(status = 200, description = "Generated Text", body = StreamResponse, (status = 200, description = "Generated Text", body = StreamResponse,
content_type = "text/event-stream"), content_type = "text/event-stream"),
(status = 424, description = "Generation Error", body = ErrorResponse, (status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"}), example = json ! ({"error": "Request failed during generation"}),
content_type = "text/event-stream"), content_type = "text/event-stream"),
(status = 429, description = "Model is overloaded", body = ErrorResponse, (status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"}), example = json ! ({"error": "Model is overloaded"}),
content_type = "text/event-stream"), content_type = "text/event-stream"),
(status = 422, description = "Input validation error", body = ErrorResponse, (status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"}), example = json ! ({"error": "Input validation error"}),
content_type = "text/event-stream"), content_type = "text/event-stream"),
(status = 500, description = "Incomplete generation", body = ErrorResponse, (status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"}), example = json ! ({"error": "Incomplete generation"}),
content_type = "text/event-stream"), content_type = "text/event-stream"),
) )
)] )]
#[instrument( #[instrument(
skip(infer), skip(infer),
@ -447,10 +484,10 @@ async fn generate_stream(
/// Prometheus metrics scrape endpoint /// Prometheus metrics scrape endpoint
#[utoipa::path( #[utoipa::path(
get, get,
tag = "Text Generation Inference", tag = "Text Generation Inference",
path = "/metrics", path = "/metrics",
responses((status = 200, description = "Prometheus Metrics", body = String)) responses((status = 200, description = "Prometheus Metrics", body = String))
)] )]
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String { async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
prom_handle.render() prom_handle.render()
@ -459,6 +496,7 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
/// Serving method /// Serving method
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub async fn run( pub async fn run(
model_info: ModelInfo,
compat_return_full_text: bool, compat_return_full_text: bool,
max_concurrent_requests: usize, max_concurrent_requests: usize,
max_best_of: usize, max_best_of: usize,
@ -476,36 +514,40 @@ pub async fn run(
// OpenAPI documentation // OpenAPI documentation
#[derive(OpenApi)] #[derive(OpenApi)]
#[openapi( #[openapi(
paths( paths(
generate, get_model_info,
generate_stream, compat_generate,
metrics, generate,
), generate_stream,
components( metrics,
schemas( ),
GenerateRequest, components(
GenerateParameters, schemas(
PrefillToken, Info,
Token, CompatGenerateRequest,
GenerateResponse, GenerateRequest,
BestOfSequence, GenerateParameters,
Details, PrefillToken,
FinishReason, Token,
StreamResponse, GenerateResponse,
StreamDetails, BestOfSequence,
ErrorResponse, Details,
) FinishReason,
), StreamResponse,
tags( StreamDetails,
(name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API") ErrorResponse,
), )
info( ),
title = "Text Generation Inference", tags(
license( (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
name = "Apache 2.0", ),
url = "https://www.apache.org/licenses/LICENSE-2.0" info(
) title = "Text Generation Inference",
) license(
name = "Apache 2.0",
url = "https://www.apache.org/licenses/LICENSE-2.0"
)
)
)] )]
struct ApiDoc; struct ApiDoc;
@ -584,6 +626,7 @@ pub async fn run(
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi())) .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
// Base routes // Base routes
.route("/", post(compat_generate)) .route("/", post(compat_generate))
.route("/info", get(get_model_info))
.route("/generate", post(generate)) .route("/generate", post(generate))
.route("/generate_stream", post(generate_stream)) .route("/generate_stream", post(generate_stream))
// AWS Sagemaker route // AWS Sagemaker route
@ -596,6 +639,7 @@ pub async fn run(
.route("/ping", get(health)) .route("/ping", get(health))
// Prometheus metrics route // Prometheus metrics route
.route("/metrics", get(metrics)) .route("/metrics", get(metrics))
.layer(Extension(model_info))
.layer(Extension(compat_return_full_text)) .layer(Extension(compat_return_full_text))
.layer(Extension(infer)) .layer(Extension(infer))
.layer(Extension(prom_handle)) .layer(Extension(prom_handle))