feat(router): add info route (#196)

close #125
This commit is contained in:
OlivierDehaene 2023-04-18 16:16:06 +02:00 committed by GitHub
parent b927244eb5
commit 2475aede61
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 409 additions and 117 deletions

23
Cargo.lock generated
View File

@ -2430,6 +2430,7 @@ dependencies = [
"tracing-subscriber",
"utoipa",
"utoipa-swagger-ui",
"vergen",
]
[[package]]
@ -2468,8 +2469,10 @@ version = "0.3.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd0cbfecb4d19b5ea75bb31ad904eb5b9fa13f21079c3b92017ebdf4999a5890"
dependencies = [
"itoa",
"serde",
"time-core",
"time-macros",
]
[[package]]
@ -2478,6 +2481,15 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e153e1f1acaef8acc537e68b44906d2db6436e2b35ac2c6b42640fff91f00fd"
[[package]]
name = "time-macros"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd80a657e71da814b8e5d60d3374fc6d35045062245d80224748ae522dd76f36"
dependencies = [
"time-core",
]
[[package]]
name = "tinyvec"
version = "1.6.0"
@ -2966,6 +2978,17 @@ version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]]
name = "vergen"
version = "8.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1b86a8af1dedf089b1c78338678e4c7492b6045649042d94faf19690499d236"
dependencies = [
"anyhow",
"rustversion",
"time",
]
[[package]]
name = "version_check"
version = "0.9.4"

View File

@ -4,8 +4,7 @@
"title": "Text Generation Inference",
"description": "Text Generation Webserver",
"contact": {
"name": "Olivier Dehaene",
"email": "olivier@huggingface.co"
"name": "Olivier Dehaene"
},
"license": {
"name": "Apache 2.0",
@ -14,6 +13,83 @@
"version": "0.5.0"
},
"paths": {
"/": {
"post": {
"tags": [
"Text Generation Inference"
],
"summary": "Generate tokens if `stream == false` or a stream of token if `stream == true`",
"description": "Generate tokens if `stream == false` or a stream of token if `stream == true`",
"operationId": "compat_generate",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/CompatGenerateRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "See /generate or /generate_stream"
},
"422": {
"description": "Input validation error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Input validation error"
}
}
}
},
"424": {
"description": "Generation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Request failed during generation"
}
}
}
},
"429": {
"description": "Model is overloaded",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Model is overloaded"
}
}
}
},
"500": {
"description": "Incomplete generation",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Incomplete generation"
}
}
}
}
}
}
},
"/generate": {
"post": {
"tags": [
@ -95,8 +171,7 @@
}
}
}
},
"deprecated": false
}
}
},
"/generate_stream": {
@ -180,8 +255,29 @@
}
}
}
},
"deprecated": false
}
}
},
"/info": {
"get": {
"tags": [
"Text Generation Inference"
],
"summary": "Text Generation Inference endpoint info",
"description": "Text Generation Inference endpoint info",
"operationId": "get_model_info",
"responses": {
"200": {
"description": "Served model info",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Info"
}
}
}
}
}
}
},
"/metrics": {
@ -203,8 +299,7 @@
}
}
}
},
"deprecated": false
}
}
}
},
@ -230,7 +325,8 @@
"generated_tokens": {
"type": "integer",
"format": "int32",
"example": 1
"example": 1,
"minimum": 0.0
},
"prefill": {
"type": "array",
@ -242,7 +338,8 @@
"type": "integer",
"format": "int64",
"example": 42,
"nullable": true
"nullable": true,
"minimum": 0.0
},
"tokens": {
"type": "array",
@ -252,6 +349,24 @@
}
}
},
"CompatGenerateRequest": {
"type": "object",
"required": [
"inputs"
],
"properties": {
"inputs": {
"type": "string",
"example": "My name is Olivier and I"
},
"parameters": {
"$ref": "#/components/schemas/GenerateParameters"
},
"stream": {
"type": "boolean"
}
}
},
"Details": {
"type": "object",
"required": [
@ -265,7 +380,8 @@
"type": "array",
"items": {
"$ref": "#/components/schemas/BestOfSequence"
}
},
"nullable": true
},
"finish_reason": {
"$ref": "#/components/schemas/FinishReason"
@ -273,7 +389,8 @@
"generated_tokens": {
"type": "integer",
"format": "int32",
"example": 1
"example": 1,
"minimum": 0.0
},
"prefill": {
"type": "array",
@ -285,7 +402,8 @@
"type": "integer",
"format": "int64",
"example": 42,
"nullable": true
"nullable": true,
"minimum": 0.0
},
"tokens": {
"type": "array",
@ -326,6 +444,7 @@
"default": "null",
"example": 1,
"nullable": true,
"minimum": 0.0,
"exclusiveMinimum": 0.0
},
"details": {
@ -341,6 +460,7 @@
"type": "integer",
"format": "int32",
"default": "20",
"minimum": 0.0,
"exclusiveMaximum": 512.0,
"exclusiveMinimum": 0.0
},
@ -364,6 +484,7 @@
"default": "null",
"example": "null",
"nullable": true,
"minimum": 0.0,
"exclusiveMinimum": 0.0
},
"stop": {
@ -405,7 +526,8 @@
"type": "integer",
"default": "null",
"example": "null",
"nullable": true
"nullable": true,
"minimum": 0.0
},
"typical_p": {
"type": "number",
@ -445,7 +567,12 @@
],
"properties": {
"details": {
"$ref": "#/components/schemas/Details"
"allOf": [
{
"$ref": "#/components/schemas/Details"
}
],
"nullable": true
},
"generated_text": {
"type": "string",
@ -453,6 +580,38 @@
}
}
},
"Info": {
"type": "object",
"required": [
"model_id",
"version"
],
"properties": {
"model_id": {
"type": "string",
"example": "bigscience/bloom-560m"
},
"model_pipeline_tag": {
"type": "string",
"example": "text-generation",
"nullable": true
},
"model_sha": {
"type": "string",
"example": "e985a63cdc139290c5f700ff1929f0b5942cced2",
"nullable": true
},
"sha": {
"type": "string",
"example": "null",
"nullable": true
},
"version": {
"type": "string",
"example": "0.5.0"
}
}
},
"PrefillToken": {
"type": "object",
"required": [
@ -464,7 +623,8 @@
"id": {
"type": "integer",
"format": "int32",
"example": 0
"example": 0,
"minimum": 0.0
},
"logprob": {
"type": "number",
@ -491,13 +651,15 @@
"generated_tokens": {
"type": "integer",
"format": "int32",
"example": 1
"example": 1,
"minimum": 0.0
},
"seed": {
"type": "integer",
"format": "int64",
"example": 42,
"nullable": true
"nullable": true,
"minimum": 0.0
}
}
},
@ -508,7 +670,12 @@
],
"properties": {
"details": {
"$ref": "#/components/schemas/StreamDetails"
"allOf": [
{
"$ref": "#/components/schemas/StreamDetails"
}
],
"nullable": true
},
"generated_text": {
"type": "string",
@ -533,7 +700,8 @@
"id": {
"type": "integer",
"format": "int32",
"example": 0
"example": 0,
"minimum": 0.0
},
"logprob": {
"type": "number",

View File

@ -392,6 +392,12 @@ fn main() -> ExitCode {
model_id,
];
// Model optional revision
if let Some(ref revision) = revision {
argv.push("--revision".to_string());
argv.push(revision.to_string())
}
if json_output {
argv.push("--json-output".to_string());
}

View File

@ -4,6 +4,7 @@ version = "0.5.0"
edition = "2021"
authors = ["Olivier Dehaene"]
description = "Text Generation Webserver"
build = "build.rs"
[lib]
path = "src/lib.rs"
@ -26,7 +27,7 @@ nohash-hasher = "0.2.0"
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.11.0"
rand = "0.8.5"
reqwest = { version = "0.11.14", features = [] }
reqwest = { version = "0.11.14", features = [] }
serde = "1.0.152"
serde_json = "1.0.93"
thiserror = "1.0.38"
@ -39,3 +40,5 @@ tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
utoipa = { version = "3.0.1", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
[build-dependencies]
vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }

7
router/build.rs Normal file
View File

@ -0,0 +1,7 @@
use std::error::Error;
use vergen::EmitBuilder;
/// Build script: emits git metadata as compile-time environment variables
/// (e.g. `VERGEN_GIT_SHA`), which the router reads via `option_env!` to
/// report the build commit on the `/info` route.
fn main() -> Result<(), Box<dyn Error>> {
    // git_sha(false): request the git SHA without the "shorten" option.
    // NOTE(review): emit() presumably fails or skips vars when building
    // outside a git checkout — TODO confirm against vergen 8 docs.
    EmitBuilder::builder().git_sha(false).emit()?;
    Ok(())
}

View File

@ -10,6 +10,29 @@ use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use validation::Validation;
/// Hub type
/// Hub type: the subset of the Hugging Face Hub model-info response
/// (`https://huggingface.co/api/models/{model_id}/revision/{revision}`)
/// that the router consumes.
#[derive(Clone, Debug, Deserialize)]
pub struct ModelInfo {
    // The Hub payload names this field `id`; exposed locally as `model_id`.
    #[serde(rename(deserialize = "id"))]
    pub model_id: String,
    // Commit SHA of the model revision, when the Hub provides one.
    pub sha: Option<String>,
    // Hub pipeline tag (e.g. "text-generation"), when set on the repo.
    pub pipeline_tag: Option<String>,
}
/// Payload served by the `/info` route: identity of the served model
/// (from the Hub or local path) plus router build metadata.
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct Info {
    /// Model repository id (fixed typo: "bloom", not "blomm").
    #[schema(example = "bigscience/bloom-560m")]
    pub model_id: String,
    /// Commit SHA of the served model revision, when known.
    #[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
    pub model_sha: Option<String>,
    /// Hub pipeline tag of the model, when set.
    #[schema(nullable = true, example = "text-generation")]
    pub model_pipeline_tag: Option<String>,
    /// Router crate version, baked in at compile time (CARGO_PKG_VERSION).
    #[schema(example = "0.5.0")]
    pub version: &'static str,
    /// Router git SHA emitted by build.rs via vergen; None when the
    /// VERGEN_GIT_SHA env var was not set at compile time.
    #[schema(nullable = true, example = "null")]
    pub sha: Option<&'static str>,
}
#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct GenerateParameters {
#[serde(default)]

View File

@ -10,8 +10,8 @@ use opentelemetry_otlp::WithExportConfig;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path;
use text_generation_client::ShardedClient;
use text_generation_router::server;
use tokenizers::Tokenizer;
use text_generation_router::{server, ModelInfo};
use tokenizers::{FromPretrainedParameters, Tokenizer};
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
@ -41,6 +41,8 @@ struct Args {
master_shard_uds_path: String,
#[clap(default_value = "bigscience/bloom", long, env)]
tokenizer_name: String,
#[clap(default_value = "main", long, env)]
revision: String,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
@ -66,6 +68,7 @@ fn main() -> Result<(), std::io::Error> {
port,
master_shard_uds_path,
tokenizer_name,
revision,
validation_workers,
json_output,
otlp_endpoint,
@ -90,16 +93,19 @@ fn main() -> Result<(), std::io::Error> {
// Tokenizer instance
// This will only be used to validate payloads
let local_path = Path::new(&tokenizer_name);
let tokenizer =
if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
{
// Load local tokenizer
Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
} else {
// Download and instantiate tokenizer
// We need to download it outside of the Tokio runtime
Tokenizer::from_pretrained(tokenizer_name.clone(), None).ok()
let local_model = local_path.exists() && local_path.is_dir();
let tokenizer = if local_model {
// Load local tokenizer
Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
} else {
// Download and instantiate tokenizer
// We need to download it outside of the Tokio runtime
let params = FromPretrainedParameters {
revision: revision.clone(),
..Default::default()
};
Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
};
// Launch Tokio runtime
tokio::runtime::Builder::new_multi_thread()
@ -116,25 +122,23 @@ fn main() -> Result<(), std::io::Error> {
tracing::warn!("Rust input length validation and truncation is disabled");
}
// Get pipeline tag
let model_info = reqwest::get(format!(
"https://huggingface.co/api/models/{tokenizer_name}"
))
.await
.expect("Could not connect to hf.co")
.text()
.await
.expect("error when retrieving model info from hf.co");
let model_info: serde_json::Value =
serde_json::from_str(&model_info).expect("unable to parse model info");
// Get Model info
let model_info = match local_model {
true => ModelInfo {
model_id: tokenizer_name.clone(),
sha: None,
pipeline_tag: None,
},
false => get_model_info(&tokenizer_name, &revision).await,
};
// if pipeline-tag == text-generation we default to return_full_text = true
let compat_return_full_text = match model_info.get("pipeline_tag") {
let compat_return_full_text = match &model_info.pipeline_tag {
None => {
tracing::warn!("no pipeline tag found for model {tokenizer_name}");
false
}
Some(pipeline_tag) => pipeline_tag.as_str() == Some("text-generation"),
Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
};
// Instantiate sharded client from the master unix socket
@ -153,6 +157,7 @@ fn main() -> Result<(), std::io::Error> {
// Run server
server::run(
model_info,
compat_return_full_text,
max_concurrent_requests,
max_best_of,
@ -226,3 +231,16 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
.with(layers)
.init();
}
/// Fetch model metadata from the Hugging Face Hub REST API
/// (`https://huggingface.co/api/models/{model_id}/revision/{revision}`).
///
/// # Panics
/// Panics if the request cannot be sent, the response body cannot be
/// read, or the body does not deserialize into `ModelInfo` — startup-time
/// failures are treated as fatal.
pub async fn get_model_info(model_id: &str, revision: &str) -> ModelInfo {
    // Fetch the body as text first, then parse it separately so the two
    // failure modes (network vs. malformed payload) panic with distinct messages.
    let model_info = reqwest::get(format!(
        "https://huggingface.co/api/models/{model_id}/revision/{revision}"
    ))
    .await
    .expect("Could not connect to hf.co")
    .text()
    .await
    .expect("error when retrieving model info from hf.co");
    serde_json::from_str(&model_info).expect("unable to parse model info")
}

View File

@ -3,8 +3,8 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError;
use crate::{
BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
GenerateParameters, GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails,
StreamResponse, Token, Validation,
GenerateParameters, GenerateRequest, GenerateResponse, Infer, Info, ModelInfo, PrefillToken,
StreamDetails, StreamResponse, Token, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
@ -27,7 +27,24 @@ use tracing::{info_span, instrument, Instrument};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
/// Compatibility route with api-inference and AzureML
/// Generate tokens if `stream == false` or a stream of token if `stream == true`
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/",
request_body = CompatGenerateRequest,
responses(
(status = 200, description = "See /generate or /generate_stream"),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"})),
)
)]
#[instrument(skip(infer))]
async fn compat_generate(
default_return_full_text: Extension<bool>,
@ -53,6 +70,26 @@ async fn compat_generate(
}
}
/// Text Generation Inference endpoint info
///
/// Handler for `GET /info`: merges the Hub-provided model metadata
/// (injected as an axum `Extension` by `run`) with router build metadata.
#[utoipa::path(
    get,
    tag = "Text Generation Inference",
    path = "/info",
    responses((status = 200, description = "Served model info", body = Info))
)]
#[instrument]
async fn get_model_info(model_info: Extension<ModelInfo>) -> Json<Info> {
    // Unwrap the Extension to move the inner ModelInfo out.
    let model_info = model_info.0;
    let info = Info {
        // Crate version, baked in by Cargo at compile time.
        version: env!("CARGO_PKG_VERSION"),
        // Git SHA emitted by build.rs (vergen); None when VERGEN_GIT_SHA
        // was not set at compile time.
        sha: option_env!("VERGEN_GIT_SHA"),
        model_id: model_info.model_id,
        model_sha: model_info.sha,
        model_pipeline_tag: model_info.pipeline_tag,
    };
    Json(info)
}
/// Health check method
#[instrument(skip(infer))]
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
@ -87,21 +124,21 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
/// Generate tokens
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/generate",
request_body = GenerateRequest,
responses(
(status = 200, description = "Generated Text", body = GenerateResponse),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"})),
)
post,
tag = "Text Generation Inference",
path = "/generate",
request_body = GenerateRequest,
responses(
(status = 200, description = "Generated Text", body = GenerateResponse),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"})),
)
)]
#[instrument(
skip(infer),
@ -264,26 +301,26 @@ async fn generate(
/// Generate a stream of token using Server-Sent Events
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/generate_stream",
request_body = GenerateRequest,
responses(
(status = 200, description = "Generated Text", body = StreamResponse,
content_type = "text/event-stream"),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"}),
content_type = "text/event-stream"),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"}),
content_type = "text/event-stream"),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"}),
content_type = "text/event-stream"),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"}),
content_type = "text/event-stream"),
)
post,
tag = "Text Generation Inference",
path = "/generate_stream",
request_body = GenerateRequest,
responses(
(status = 200, description = "Generated Text", body = StreamResponse,
content_type = "text/event-stream"),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"}),
content_type = "text/event-stream"),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"}),
content_type = "text/event-stream"),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"}),
content_type = "text/event-stream"),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"}),
content_type = "text/event-stream"),
)
)]
#[instrument(
skip(infer),
@ -447,10 +484,10 @@ async fn generate_stream(
/// Prometheus metrics scrape endpoint
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/metrics",
responses((status = 200, description = "Prometheus Metrics", body = String))
get,
tag = "Text Generation Inference",
path = "/metrics",
responses((status = 200, description = "Prometheus Metrics", body = String))
)]
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
prom_handle.render()
@ -459,6 +496,7 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
/// Serving method
#[allow(clippy::too_many_arguments)]
pub async fn run(
model_info: ModelInfo,
compat_return_full_text: bool,
max_concurrent_requests: usize,
max_best_of: usize,
@ -476,36 +514,40 @@ pub async fn run(
// OpenAPI documentation
#[derive(OpenApi)]
#[openapi(
paths(
generate,
generate_stream,
metrics,
),
components(
schemas(
GenerateRequest,
GenerateParameters,
PrefillToken,
Token,
GenerateResponse,
BestOfSequence,
Details,
FinishReason,
StreamResponse,
StreamDetails,
ErrorResponse,
)
),
tags(
(name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
),
info(
title = "Text Generation Inference",
license(
name = "Apache 2.0",
url = "https://www.apache.org/licenses/LICENSE-2.0"
)
)
paths(
get_model_info,
compat_generate,
generate,
generate_stream,
metrics,
),
components(
schemas(
Info,
CompatGenerateRequest,
GenerateRequest,
GenerateParameters,
PrefillToken,
Token,
GenerateResponse,
BestOfSequence,
Details,
FinishReason,
StreamResponse,
StreamDetails,
ErrorResponse,
)
),
tags(
(name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
),
info(
title = "Text Generation Inference",
license(
name = "Apache 2.0",
url = "https://www.apache.org/licenses/LICENSE-2.0"
)
)
)]
struct ApiDoc;
@ -584,6 +626,7 @@ pub async fn run(
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
// Base routes
.route("/", post(compat_generate))
.route("/info", get(get_model_info))
.route("/generate", post(generate))
.route("/generate_stream", post(generate_stream))
// AWS Sagemaker route
@ -596,6 +639,7 @@ pub async fn run(
.route("/ping", get(health))
// Prometheus metrics route
.route("/metrics", get(metrics))
.layer(Extension(model_info))
.layer(Extension(compat_return_full_text))
.layer(Extension(infer))
.layer(Extension(prom_handle))