parent
b927244eb5
commit
2475aede61
|
@ -2430,6 +2430,7 @@ dependencies = [
|
|||
"tracing-subscriber",
|
||||
"utoipa",
|
||||
"utoipa-swagger-ui",
|
||||
"vergen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -2468,8 +2469,10 @@ version = "0.3.20"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cd0cbfecb4d19b5ea75bb31ad904eb5b9fa13f21079c3b92017ebdf4999a5890"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"serde",
|
||||
"time-core",
|
||||
"time-macros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -2478,6 +2481,15 @@ version = "0.1.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2e153e1f1acaef8acc537e68b44906d2db6436e2b35ac2c6b42640fff91f00fd"
|
||||
|
||||
[[package]]
|
||||
name = "time-macros"
|
||||
version = "0.2.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fd80a657e71da814b8e5d60d3374fc6d35045062245d80224748ae522dd76f36"
|
||||
dependencies = [
|
||||
"time-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinyvec"
|
||||
version = "1.6.0"
|
||||
|
@ -2966,6 +2978,17 @@ version = "0.2.15"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
|
||||
|
||||
[[package]]
|
||||
name = "vergen"
|
||||
version = "8.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c1b86a8af1dedf089b1c78338678e4c7492b6045649042d94faf19690499d236"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"rustversion",
|
||||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "version_check"
|
||||
version = "0.9.4"
|
||||
|
|
|
@ -4,8 +4,7 @@
|
|||
"title": "Text Generation Inference",
|
||||
"description": "Text Generation Webserver",
|
||||
"contact": {
|
||||
"name": "Olivier Dehaene",
|
||||
"email": "olivier@huggingface.co"
|
||||
"name": "Olivier Dehaene"
|
||||
},
|
||||
"license": {
|
||||
"name": "Apache 2.0",
|
||||
|
@ -14,6 +13,83 @@
|
|||
"version": "0.5.0"
|
||||
},
|
||||
"paths": {
|
||||
"/": {
|
||||
"post": {
|
||||
"tags": [
|
||||
"Text Generation Inference"
|
||||
],
|
||||
"summary": "Generate tokens if `stream == false` or a stream of token if `stream == true`",
|
||||
"description": "Generate tokens if `stream == false` or a stream of token if `stream == true`",
|
||||
"operationId": "compat_generate",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/CompatGenerateRequest"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "See /generate or /generate_stream"
|
||||
},
|
||||
"422": {
|
||||
"description": "Input validation error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ErrorResponse"
|
||||
},
|
||||
"example": {
|
||||
"error": "Input validation error"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"424": {
|
||||
"description": "Generation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ErrorResponse"
|
||||
},
|
||||
"example": {
|
||||
"error": "Request failed during generation"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"429": {
|
||||
"description": "Model is overloaded",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ErrorResponse"
|
||||
},
|
||||
"example": {
|
||||
"error": "Model is overloaded"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"500": {
|
||||
"description": "Incomplete generation",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ErrorResponse"
|
||||
},
|
||||
"example": {
|
||||
"error": "Incomplete generation"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/generate": {
|
||||
"post": {
|
||||
"tags": [
|
||||
|
@ -95,8 +171,7 @@
|
|||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"deprecated": false
|
||||
}
|
||||
}
|
||||
},
|
||||
"/generate_stream": {
|
||||
|
@ -180,8 +255,29 @@
|
|||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"deprecated": false
|
||||
}
|
||||
}
|
||||
},
|
||||
"/info": {
|
||||
"get": {
|
||||
"tags": [
|
||||
"Text Generation Inference"
|
||||
],
|
||||
"summary": "Text Generation Inference endpoint info",
|
||||
"description": "Text Generation Inference endpoint info",
|
||||
"operationId": "get_model_info",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Served model info",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Info"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/metrics": {
|
||||
|
@ -203,8 +299,7 @@
|
|||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"deprecated": false
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
|
@ -230,7 +325,8 @@
|
|||
"generated_tokens": {
|
||||
"type": "integer",
|
||||
"format": "int32",
|
||||
"example": 1
|
||||
"example": 1,
|
||||
"minimum": 0.0
|
||||
},
|
||||
"prefill": {
|
||||
"type": "array",
|
||||
|
@ -242,7 +338,8 @@
|
|||
"type": "integer",
|
||||
"format": "int64",
|
||||
"example": 42,
|
||||
"nullable": true
|
||||
"nullable": true,
|
||||
"minimum": 0.0
|
||||
},
|
||||
"tokens": {
|
||||
"type": "array",
|
||||
|
@ -252,6 +349,24 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"CompatGenerateRequest": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"inputs"
|
||||
],
|
||||
"properties": {
|
||||
"inputs": {
|
||||
"type": "string",
|
||||
"example": "My name is Olivier and I"
|
||||
},
|
||||
"parameters": {
|
||||
"$ref": "#/components/schemas/GenerateParameters"
|
||||
},
|
||||
"stream": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
},
|
||||
"Details": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
|
@ -265,7 +380,8 @@
|
|||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/BestOfSequence"
|
||||
}
|
||||
},
|
||||
"nullable": true
|
||||
},
|
||||
"finish_reason": {
|
||||
"$ref": "#/components/schemas/FinishReason"
|
||||
|
@ -273,7 +389,8 @@
|
|||
"generated_tokens": {
|
||||
"type": "integer",
|
||||
"format": "int32",
|
||||
"example": 1
|
||||
"example": 1,
|
||||
"minimum": 0.0
|
||||
},
|
||||
"prefill": {
|
||||
"type": "array",
|
||||
|
@ -285,7 +402,8 @@
|
|||
"type": "integer",
|
||||
"format": "int64",
|
||||
"example": 42,
|
||||
"nullable": true
|
||||
"nullable": true,
|
||||
"minimum": 0.0
|
||||
},
|
||||
"tokens": {
|
||||
"type": "array",
|
||||
|
@ -326,6 +444,7 @@
|
|||
"default": "null",
|
||||
"example": 1,
|
||||
"nullable": true,
|
||||
"minimum": 0.0,
|
||||
"exclusiveMinimum": 0.0
|
||||
},
|
||||
"details": {
|
||||
|
@ -341,6 +460,7 @@
|
|||
"type": "integer",
|
||||
"format": "int32",
|
||||
"default": "20",
|
||||
"minimum": 0.0,
|
||||
"exclusiveMaximum": 512.0,
|
||||
"exclusiveMinimum": 0.0
|
||||
},
|
||||
|
@ -364,6 +484,7 @@
|
|||
"default": "null",
|
||||
"example": "null",
|
||||
"nullable": true,
|
||||
"minimum": 0.0,
|
||||
"exclusiveMinimum": 0.0
|
||||
},
|
||||
"stop": {
|
||||
|
@ -405,7 +526,8 @@
|
|||
"type": "integer",
|
||||
"default": "null",
|
||||
"example": "null",
|
||||
"nullable": true
|
||||
"nullable": true,
|
||||
"minimum": 0.0
|
||||
},
|
||||
"typical_p": {
|
||||
"type": "number",
|
||||
|
@ -445,7 +567,12 @@
|
|||
],
|
||||
"properties": {
|
||||
"details": {
|
||||
"$ref": "#/components/schemas/Details"
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/Details"
|
||||
}
|
||||
],
|
||||
"nullable": true
|
||||
},
|
||||
"generated_text": {
|
||||
"type": "string",
|
||||
|
@ -453,6 +580,38 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"Info": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"model_id",
|
||||
"version"
|
||||
],
|
||||
"properties": {
|
||||
"model_id": {
|
||||
"type": "string",
|
||||
"example": "bigscience/bloom-560m"
|
||||
},
|
||||
"model_pipeline_tag": {
|
||||
"type": "string",
|
||||
"example": "text-generation",
|
||||
"nullable": true
|
||||
},
|
||||
"model_sha": {
|
||||
"type": "string",
|
||||
"example": "e985a63cdc139290c5f700ff1929f0b5942cced2",
|
||||
"nullable": true
|
||||
},
|
||||
"sha": {
|
||||
"type": "string",
|
||||
"example": "null",
|
||||
"nullable": true
|
||||
},
|
||||
"version": {
|
||||
"type": "string",
|
||||
"example": "0.5.0"
|
||||
}
|
||||
}
|
||||
},
|
||||
"PrefillToken": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
|
@ -464,7 +623,8 @@
|
|||
"id": {
|
||||
"type": "integer",
|
||||
"format": "int32",
|
||||
"example": 0
|
||||
"example": 0,
|
||||
"minimum": 0.0
|
||||
},
|
||||
"logprob": {
|
||||
"type": "number",
|
||||
|
@ -491,13 +651,15 @@
|
|||
"generated_tokens": {
|
||||
"type": "integer",
|
||||
"format": "int32",
|
||||
"example": 1
|
||||
"example": 1,
|
||||
"minimum": 0.0
|
||||
},
|
||||
"seed": {
|
||||
"type": "integer",
|
||||
"format": "int64",
|
||||
"example": 42,
|
||||
"nullable": true
|
||||
"nullable": true,
|
||||
"minimum": 0.0
|
||||
}
|
||||
}
|
||||
},
|
||||
|
@ -508,7 +670,12 @@
|
|||
],
|
||||
"properties": {
|
||||
"details": {
|
||||
"$ref": "#/components/schemas/StreamDetails"
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/StreamDetails"
|
||||
}
|
||||
],
|
||||
"nullable": true
|
||||
},
|
||||
"generated_text": {
|
||||
"type": "string",
|
||||
|
@ -533,7 +700,8 @@
|
|||
"id": {
|
||||
"type": "integer",
|
||||
"format": "int32",
|
||||
"example": 0
|
||||
"example": 0,
|
||||
"minimum": 0.0
|
||||
},
|
||||
"logprob": {
|
||||
"type": "number",
|
||||
|
|
|
@ -392,6 +392,12 @@ fn main() -> ExitCode {
|
|||
model_id,
|
||||
];
|
||||
|
||||
// Model optional revision
|
||||
if let Some(ref revision) = revision {
|
||||
argv.push("--revision".to_string());
|
||||
argv.push(revision.to_string())
|
||||
}
|
||||
|
||||
if json_output {
|
||||
argv.push("--json-output".to_string());
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ version = "0.5.0"
|
|||
edition = "2021"
|
||||
authors = ["Olivier Dehaene"]
|
||||
description = "Text Generation Webserver"
|
||||
build = "build.rs"
|
||||
|
||||
[lib]
|
||||
path = "src/lib.rs"
|
||||
|
@ -26,7 +27,7 @@ nohash-hasher = "0.2.0"
|
|||
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
|
||||
opentelemetry-otlp = "0.11.0"
|
||||
rand = "0.8.5"
|
||||
reqwest = { version = "0.11.14", features = [] }
|
||||
reqwest = { version = "0.11.14", features = [] }
|
||||
serde = "1.0.152"
|
||||
serde_json = "1.0.93"
|
||||
thiserror = "1.0.38"
|
||||
|
@ -39,3 +40,5 @@ tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
|
|||
utoipa = { version = "3.0.1", features = ["axum_extras"] }
|
||||
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
|
||||
|
||||
[build-dependencies]
|
||||
vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
use std::error::Error;
|
||||
use vergen::EmitBuilder;
|
||||
|
||||
fn main() -> Result<(), Box<dyn Error>> {
|
||||
EmitBuilder::builder().git_sha(false).emit()?;
|
||||
Ok(())
|
||||
}
|
|
@ -10,6 +10,29 @@ use serde::{Deserialize, Serialize};
|
|||
use utoipa::ToSchema;
|
||||
use validation::Validation;
|
||||
|
||||
/// Hub type
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct ModelInfo {
|
||||
#[serde(rename(deserialize = "id"))]
|
||||
pub model_id: String,
|
||||
pub sha: Option<String>,
|
||||
pub pipeline_tag: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||
pub struct Info {
|
||||
#[schema(example = "bigscience/blomm-560m")]
|
||||
pub model_id: String,
|
||||
#[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
|
||||
pub model_sha: Option<String>,
|
||||
#[schema(nullable = true, example = "text-generation")]
|
||||
pub model_pipeline_tag: Option<String>,
|
||||
#[schema(example = "0.5.0")]
|
||||
pub version: &'static str,
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub sha: Option<&'static str>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
||||
pub(crate) struct GenerateParameters {
|
||||
#[serde(default)]
|
||||
|
|
|
@ -10,8 +10,8 @@ use opentelemetry_otlp::WithExportConfig;
|
|||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::path::Path;
|
||||
use text_generation_client::ShardedClient;
|
||||
use text_generation_router::server;
|
||||
use tokenizers::Tokenizer;
|
||||
use text_generation_router::{server, ModelInfo};
|
||||
use tokenizers::{FromPretrainedParameters, Tokenizer};
|
||||
use tower_http::cors::AllowOrigin;
|
||||
use tracing_subscriber::layer::SubscriberExt;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
|
@ -41,6 +41,8 @@ struct Args {
|
|||
master_shard_uds_path: String,
|
||||
#[clap(default_value = "bigscience/bloom", long, env)]
|
||||
tokenizer_name: String,
|
||||
#[clap(default_value = "main", long, env)]
|
||||
revision: String,
|
||||
#[clap(default_value = "2", long, env)]
|
||||
validation_workers: usize,
|
||||
#[clap(long, env)]
|
||||
|
@ -66,6 +68,7 @@ fn main() -> Result<(), std::io::Error> {
|
|||
port,
|
||||
master_shard_uds_path,
|
||||
tokenizer_name,
|
||||
revision,
|
||||
validation_workers,
|
||||
json_output,
|
||||
otlp_endpoint,
|
||||
|
@ -90,16 +93,19 @@ fn main() -> Result<(), std::io::Error> {
|
|||
// Tokenizer instance
|
||||
// This will only be used to validate payloads
|
||||
let local_path = Path::new(&tokenizer_name);
|
||||
let tokenizer =
|
||||
if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
|
||||
{
|
||||
// Load local tokenizer
|
||||
Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
|
||||
} else {
|
||||
// Download and instantiate tokenizer
|
||||
// We need to download it outside of the Tokio runtime
|
||||
Tokenizer::from_pretrained(tokenizer_name.clone(), None).ok()
|
||||
let local_model = local_path.exists() && local_path.is_dir();
|
||||
let tokenizer = if local_model {
|
||||
// Load local tokenizer
|
||||
Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
|
||||
} else {
|
||||
// Download and instantiate tokenizer
|
||||
// We need to download it outside of the Tokio runtime
|
||||
let params = FromPretrainedParameters {
|
||||
revision: revision.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
|
||||
};
|
||||
|
||||
// Launch Tokio runtime
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
|
@ -116,25 +122,23 @@ fn main() -> Result<(), std::io::Error> {
|
|||
tracing::warn!("Rust input length validation and truncation is disabled");
|
||||
}
|
||||
|
||||
// Get pipeline tag
|
||||
let model_info = reqwest::get(format!(
|
||||
"https://huggingface.co/api/models/{tokenizer_name}"
|
||||
))
|
||||
.await
|
||||
.expect("Could not connect to hf.co")
|
||||
.text()
|
||||
.await
|
||||
.expect("error when retrieving model info from hf.co");
|
||||
let model_info: serde_json::Value =
|
||||
serde_json::from_str(&model_info).expect("unable to parse model info");
|
||||
// Get Model info
|
||||
let model_info = match local_model {
|
||||
true => ModelInfo {
|
||||
model_id: tokenizer_name.clone(),
|
||||
sha: None,
|
||||
pipeline_tag: None,
|
||||
},
|
||||
false => get_model_info(&tokenizer_name, &revision).await,
|
||||
};
|
||||
|
||||
// if pipeline-tag == text-generation we default to return_full_text = true
|
||||
let compat_return_full_text = match model_info.get("pipeline_tag") {
|
||||
let compat_return_full_text = match &model_info.pipeline_tag {
|
||||
None => {
|
||||
tracing::warn!("no pipeline tag found for model {tokenizer_name}");
|
||||
false
|
||||
}
|
||||
Some(pipeline_tag) => pipeline_tag.as_str() == Some("text-generation"),
|
||||
Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
|
||||
};
|
||||
|
||||
// Instantiate sharded client from the master unix socket
|
||||
|
@ -153,6 +157,7 @@ fn main() -> Result<(), std::io::Error> {
|
|||
|
||||
// Run server
|
||||
server::run(
|
||||
model_info,
|
||||
compat_return_full_text,
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
|
@ -226,3 +231,16 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
|
|||
.with(layers)
|
||||
.init();
|
||||
}
|
||||
|
||||
/// get model info from the Huggingface Hub
|
||||
pub async fn get_model_info(model_id: &str, revision: &str) -> ModelInfo {
|
||||
let model_info = reqwest::get(format!(
|
||||
"https://huggingface.co/api/models/{model_id}/revision/{revision}"
|
||||
))
|
||||
.await
|
||||
.expect("Could not connect to hf.co")
|
||||
.text()
|
||||
.await
|
||||
.expect("error when retrieving model info from hf.co");
|
||||
serde_json::from_str(&model_info).expect("unable to parse model info")
|
||||
}
|
||||
|
|
|
@ -3,8 +3,8 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
|
|||
use crate::validation::ValidationError;
|
||||
use crate::{
|
||||
BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
|
||||
GenerateParameters, GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails,
|
||||
StreamResponse, Token, Validation,
|
||||
GenerateParameters, GenerateRequest, GenerateResponse, Infer, Info, ModelInfo, PrefillToken,
|
||||
StreamDetails, StreamResponse, Token, Validation,
|
||||
};
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, Method, StatusCode};
|
||||
|
@ -27,7 +27,24 @@ use tracing::{info_span, instrument, Instrument};
|
|||
use utoipa::OpenApi;
|
||||
use utoipa_swagger_ui::SwaggerUi;
|
||||
|
||||
/// Compatibility route with api-inference and AzureML
|
||||
/// Generate tokens if `stream == false` or a stream of token if `stream == true`
|
||||
#[utoipa::path(
|
||||
post,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/",
|
||||
request_body = CompatGenerateRequest,
|
||||
responses(
|
||||
(status = 200, description = "See /generate or /generate_stream"),
|
||||
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||
example = json ! ({"error": "Request failed during generation"})),
|
||||
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
||||
example = json ! ({"error": "Model is overloaded"})),
|
||||
(status = 422, description = "Input validation error", body = ErrorResponse,
|
||||
example = json ! ({"error": "Input validation error"})),
|
||||
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
||||
example = json ! ({"error": "Incomplete generation"})),
|
||||
)
|
||||
)]
|
||||
#[instrument(skip(infer))]
|
||||
async fn compat_generate(
|
||||
default_return_full_text: Extension<bool>,
|
||||
|
@ -53,6 +70,26 @@ async fn compat_generate(
|
|||
}
|
||||
}
|
||||
|
||||
/// Text Generation Inference endpoint info
|
||||
#[utoipa::path(
|
||||
get,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/info",
|
||||
responses((status = 200, description = "Served model info", body = Info))
|
||||
)]
|
||||
#[instrument]
|
||||
async fn get_model_info(model_info: Extension<ModelInfo>) -> Json<Info> {
|
||||
let model_info = model_info.0;
|
||||
let info = Info {
|
||||
version: env!("CARGO_PKG_VERSION"),
|
||||
sha: option_env!("VERGEN_GIT_SHA"),
|
||||
model_id: model_info.model_id,
|
||||
model_sha: model_info.sha,
|
||||
model_pipeline_tag: model_info.pipeline_tag,
|
||||
};
|
||||
Json(info)
|
||||
}
|
||||
|
||||
/// Health check method
|
||||
#[instrument(skip(infer))]
|
||||
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
||||
|
@ -87,21 +124,21 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
|
|||
|
||||
/// Generate tokens
|
||||
#[utoipa::path(
|
||||
post,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/generate",
|
||||
request_body = GenerateRequest,
|
||||
responses(
|
||||
(status = 200, description = "Generated Text", body = GenerateResponse),
|
||||
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||
example = json ! ({"error": "Request failed during generation"})),
|
||||
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
||||
example = json ! ({"error": "Model is overloaded"})),
|
||||
(status = 422, description = "Input validation error", body = ErrorResponse,
|
||||
example = json ! ({"error": "Input validation error"})),
|
||||
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
||||
example = json ! ({"error": "Incomplete generation"})),
|
||||
)
|
||||
post,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/generate",
|
||||
request_body = GenerateRequest,
|
||||
responses(
|
||||
(status = 200, description = "Generated Text", body = GenerateResponse),
|
||||
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||
example = json ! ({"error": "Request failed during generation"})),
|
||||
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
||||
example = json ! ({"error": "Model is overloaded"})),
|
||||
(status = 422, description = "Input validation error", body = ErrorResponse,
|
||||
example = json ! ({"error": "Input validation error"})),
|
||||
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
||||
example = json ! ({"error": "Incomplete generation"})),
|
||||
)
|
||||
)]
|
||||
#[instrument(
|
||||
skip(infer),
|
||||
|
@ -264,26 +301,26 @@ async fn generate(
|
|||
|
||||
/// Generate a stream of token using Server-Sent Events
|
||||
#[utoipa::path(
|
||||
post,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/generate_stream",
|
||||
request_body = GenerateRequest,
|
||||
responses(
|
||||
(status = 200, description = "Generated Text", body = StreamResponse,
|
||||
content_type = "text/event-stream"),
|
||||
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||
example = json ! ({"error": "Request failed during generation"}),
|
||||
content_type = "text/event-stream"),
|
||||
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
||||
example = json ! ({"error": "Model is overloaded"}),
|
||||
content_type = "text/event-stream"),
|
||||
(status = 422, description = "Input validation error", body = ErrorResponse,
|
||||
example = json ! ({"error": "Input validation error"}),
|
||||
content_type = "text/event-stream"),
|
||||
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
||||
example = json ! ({"error": "Incomplete generation"}),
|
||||
content_type = "text/event-stream"),
|
||||
)
|
||||
post,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/generate_stream",
|
||||
request_body = GenerateRequest,
|
||||
responses(
|
||||
(status = 200, description = "Generated Text", body = StreamResponse,
|
||||
content_type = "text/event-stream"),
|
||||
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||
example = json ! ({"error": "Request failed during generation"}),
|
||||
content_type = "text/event-stream"),
|
||||
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
||||
example = json ! ({"error": "Model is overloaded"}),
|
||||
content_type = "text/event-stream"),
|
||||
(status = 422, description = "Input validation error", body = ErrorResponse,
|
||||
example = json ! ({"error": "Input validation error"}),
|
||||
content_type = "text/event-stream"),
|
||||
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
||||
example = json ! ({"error": "Incomplete generation"}),
|
||||
content_type = "text/event-stream"),
|
||||
)
|
||||
)]
|
||||
#[instrument(
|
||||
skip(infer),
|
||||
|
@ -447,10 +484,10 @@ async fn generate_stream(
|
|||
|
||||
/// Prometheus metrics scrape endpoint
|
||||
#[utoipa::path(
|
||||
get,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/metrics",
|
||||
responses((status = 200, description = "Prometheus Metrics", body = String))
|
||||
get,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/metrics",
|
||||
responses((status = 200, description = "Prometheus Metrics", body = String))
|
||||
)]
|
||||
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
|
||||
prom_handle.render()
|
||||
|
@ -459,6 +496,7 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
|
|||
/// Serving method
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn run(
|
||||
model_info: ModelInfo,
|
||||
compat_return_full_text: bool,
|
||||
max_concurrent_requests: usize,
|
||||
max_best_of: usize,
|
||||
|
@ -476,36 +514,40 @@ pub async fn run(
|
|||
// OpenAPI documentation
|
||||
#[derive(OpenApi)]
|
||||
#[openapi(
|
||||
paths(
|
||||
generate,
|
||||
generate_stream,
|
||||
metrics,
|
||||
),
|
||||
components(
|
||||
schemas(
|
||||
GenerateRequest,
|
||||
GenerateParameters,
|
||||
PrefillToken,
|
||||
Token,
|
||||
GenerateResponse,
|
||||
BestOfSequence,
|
||||
Details,
|
||||
FinishReason,
|
||||
StreamResponse,
|
||||
StreamDetails,
|
||||
ErrorResponse,
|
||||
)
|
||||
),
|
||||
tags(
|
||||
(name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
|
||||
),
|
||||
info(
|
||||
title = "Text Generation Inference",
|
||||
license(
|
||||
name = "Apache 2.0",
|
||||
url = "https://www.apache.org/licenses/LICENSE-2.0"
|
||||
)
|
||||
)
|
||||
paths(
|
||||
get_model_info,
|
||||
compat_generate,
|
||||
generate,
|
||||
generate_stream,
|
||||
metrics,
|
||||
),
|
||||
components(
|
||||
schemas(
|
||||
Info,
|
||||
CompatGenerateRequest,
|
||||
GenerateRequest,
|
||||
GenerateParameters,
|
||||
PrefillToken,
|
||||
Token,
|
||||
GenerateResponse,
|
||||
BestOfSequence,
|
||||
Details,
|
||||
FinishReason,
|
||||
StreamResponse,
|
||||
StreamDetails,
|
||||
ErrorResponse,
|
||||
)
|
||||
),
|
||||
tags(
|
||||
(name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
|
||||
),
|
||||
info(
|
||||
title = "Text Generation Inference",
|
||||
license(
|
||||
name = "Apache 2.0",
|
||||
url = "https://www.apache.org/licenses/LICENSE-2.0"
|
||||
)
|
||||
)
|
||||
)]
|
||||
struct ApiDoc;
|
||||
|
||||
|
@ -584,6 +626,7 @@ pub async fn run(
|
|||
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
|
||||
// Base routes
|
||||
.route("/", post(compat_generate))
|
||||
.route("/info", get(get_model_info))
|
||||
.route("/generate", post(generate))
|
||||
.route("/generate_stream", post(generate_stream))
|
||||
// AWS Sagemaker route
|
||||
|
@ -596,6 +639,7 @@ pub async fn run(
|
|||
.route("/ping", get(health))
|
||||
// Prometheus metrics route
|
||||
.route("/metrics", get(metrics))
|
||||
.layer(Extension(model_info))
|
||||
.layer(Extension(compat_return_full_text))
|
||||
.layer(Extension(infer))
|
||||
.layer(Extension(prom_handle))
|
||||
|
|
Loading…
Reference in New Issue