parent
b927244eb5
commit
2475aede61
|
@ -2430,6 +2430,7 @@ dependencies = [
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
"utoipa",
|
"utoipa",
|
||||||
"utoipa-swagger-ui",
|
"utoipa-swagger-ui",
|
||||||
|
"vergen",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -2468,8 +2469,10 @@ version = "0.3.20"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "cd0cbfecb4d19b5ea75bb31ad904eb5b9fa13f21079c3b92017ebdf4999a5890"
|
checksum = "cd0cbfecb4d19b5ea75bb31ad904eb5b9fa13f21079c3b92017ebdf4999a5890"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"itoa",
|
||||||
"serde",
|
"serde",
|
||||||
"time-core",
|
"time-core",
|
||||||
|
"time-macros",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -2478,6 +2481,15 @@ version = "0.1.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "2e153e1f1acaef8acc537e68b44906d2db6436e2b35ac2c6b42640fff91f00fd"
|
checksum = "2e153e1f1acaef8acc537e68b44906d2db6436e2b35ac2c6b42640fff91f00fd"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "time-macros"
|
||||||
|
version = "0.2.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "fd80a657e71da814b8e5d60d3374fc6d35045062245d80224748ae522dd76f36"
|
||||||
|
dependencies = [
|
||||||
|
"time-core",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tinyvec"
|
name = "tinyvec"
|
||||||
version = "1.6.0"
|
version = "1.6.0"
|
||||||
|
@ -2966,6 +2978,17 @@ version = "0.2.15"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
|
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "vergen"
|
||||||
|
version = "8.1.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c1b86a8af1dedf089b1c78338678e4c7492b6045649042d94faf19690499d236"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"rustversion",
|
||||||
|
"time",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "version_check"
|
name = "version_check"
|
||||||
version = "0.9.4"
|
version = "0.9.4"
|
||||||
|
|
|
@ -4,8 +4,7 @@
|
||||||
"title": "Text Generation Inference",
|
"title": "Text Generation Inference",
|
||||||
"description": "Text Generation Webserver",
|
"description": "Text Generation Webserver",
|
||||||
"contact": {
|
"contact": {
|
||||||
"name": "Olivier Dehaene",
|
"name": "Olivier Dehaene"
|
||||||
"email": "olivier@huggingface.co"
|
|
||||||
},
|
},
|
||||||
"license": {
|
"license": {
|
||||||
"name": "Apache 2.0",
|
"name": "Apache 2.0",
|
||||||
|
@ -14,6 +13,83 @@
|
||||||
"version": "0.5.0"
|
"version": "0.5.0"
|
||||||
},
|
},
|
||||||
"paths": {
|
"paths": {
|
||||||
|
"/": {
|
||||||
|
"post": {
|
||||||
|
"tags": [
|
||||||
|
"Text Generation Inference"
|
||||||
|
],
|
||||||
|
"summary": "Generate tokens if `stream == false` or a stream of token if `stream == true`",
|
||||||
|
"description": "Generate tokens if `stream == false` or a stream of token if `stream == true`",
|
||||||
|
"operationId": "compat_generate",
|
||||||
|
"requestBody": {
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/CompatGenerateRequest"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
},
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "See /generate or /generate_stream"
|
||||||
|
},
|
||||||
|
"422": {
|
||||||
|
"description": "Input validation error",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/ErrorResponse"
|
||||||
|
},
|
||||||
|
"example": {
|
||||||
|
"error": "Input validation error"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"424": {
|
||||||
|
"description": "Generation Error",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/ErrorResponse"
|
||||||
|
},
|
||||||
|
"example": {
|
||||||
|
"error": "Request failed during generation"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"429": {
|
||||||
|
"description": "Model is overloaded",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/ErrorResponse"
|
||||||
|
},
|
||||||
|
"example": {
|
||||||
|
"error": "Model is overloaded"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"500": {
|
||||||
|
"description": "Incomplete generation",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/ErrorResponse"
|
||||||
|
},
|
||||||
|
"example": {
|
||||||
|
"error": "Incomplete generation"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"/generate": {
|
"/generate": {
|
||||||
"post": {
|
"post": {
|
||||||
"tags": [
|
"tags": [
|
||||||
|
@ -95,8 +171,7 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
"deprecated": false
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"/generate_stream": {
|
"/generate_stream": {
|
||||||
|
@ -180,8 +255,29 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"deprecated": false
|
"/info": {
|
||||||
|
"get": {
|
||||||
|
"tags": [
|
||||||
|
"Text Generation Inference"
|
||||||
|
],
|
||||||
|
"summary": "Text Generation Inference endpoint info",
|
||||||
|
"description": "Text Generation Inference endpoint info",
|
||||||
|
"operationId": "get_model_info",
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Served model info",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/Info"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"/metrics": {
|
"/metrics": {
|
||||||
|
@ -203,8 +299,7 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
"deprecated": false
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -230,7 +325,8 @@
|
||||||
"generated_tokens": {
|
"generated_tokens": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"format": "int32",
|
"format": "int32",
|
||||||
"example": 1
|
"example": 1,
|
||||||
|
"minimum": 0.0
|
||||||
},
|
},
|
||||||
"prefill": {
|
"prefill": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
|
@ -242,7 +338,8 @@
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"format": "int64",
|
"format": "int64",
|
||||||
"example": 42,
|
"example": 42,
|
||||||
"nullable": true
|
"nullable": true,
|
||||||
|
"minimum": 0.0
|
||||||
},
|
},
|
||||||
"tokens": {
|
"tokens": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
|
@ -252,6 +349,24 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"CompatGenerateRequest": {
|
||||||
|
"type": "object",
|
||||||
|
"required": [
|
||||||
|
"inputs"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"inputs": {
|
||||||
|
"type": "string",
|
||||||
|
"example": "My name is Olivier and I"
|
||||||
|
},
|
||||||
|
"parameters": {
|
||||||
|
"$ref": "#/components/schemas/GenerateParameters"
|
||||||
|
},
|
||||||
|
"stream": {
|
||||||
|
"type": "boolean"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"Details": {
|
"Details": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": [
|
"required": [
|
||||||
|
@ -265,7 +380,8 @@
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {
|
"items": {
|
||||||
"$ref": "#/components/schemas/BestOfSequence"
|
"$ref": "#/components/schemas/BestOfSequence"
|
||||||
}
|
},
|
||||||
|
"nullable": true
|
||||||
},
|
},
|
||||||
"finish_reason": {
|
"finish_reason": {
|
||||||
"$ref": "#/components/schemas/FinishReason"
|
"$ref": "#/components/schemas/FinishReason"
|
||||||
|
@ -273,7 +389,8 @@
|
||||||
"generated_tokens": {
|
"generated_tokens": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"format": "int32",
|
"format": "int32",
|
||||||
"example": 1
|
"example": 1,
|
||||||
|
"minimum": 0.0
|
||||||
},
|
},
|
||||||
"prefill": {
|
"prefill": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
|
@ -285,7 +402,8 @@
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"format": "int64",
|
"format": "int64",
|
||||||
"example": 42,
|
"example": 42,
|
||||||
"nullable": true
|
"nullable": true,
|
||||||
|
"minimum": 0.0
|
||||||
},
|
},
|
||||||
"tokens": {
|
"tokens": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
|
@ -326,6 +444,7 @@
|
||||||
"default": "null",
|
"default": "null",
|
||||||
"example": 1,
|
"example": 1,
|
||||||
"nullable": true,
|
"nullable": true,
|
||||||
|
"minimum": 0.0,
|
||||||
"exclusiveMinimum": 0.0
|
"exclusiveMinimum": 0.0
|
||||||
},
|
},
|
||||||
"details": {
|
"details": {
|
||||||
|
@ -341,6 +460,7 @@
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"format": "int32",
|
"format": "int32",
|
||||||
"default": "20",
|
"default": "20",
|
||||||
|
"minimum": 0.0,
|
||||||
"exclusiveMaximum": 512.0,
|
"exclusiveMaximum": 512.0,
|
||||||
"exclusiveMinimum": 0.0
|
"exclusiveMinimum": 0.0
|
||||||
},
|
},
|
||||||
|
@ -364,6 +484,7 @@
|
||||||
"default": "null",
|
"default": "null",
|
||||||
"example": "null",
|
"example": "null",
|
||||||
"nullable": true,
|
"nullable": true,
|
||||||
|
"minimum": 0.0,
|
||||||
"exclusiveMinimum": 0.0
|
"exclusiveMinimum": 0.0
|
||||||
},
|
},
|
||||||
"stop": {
|
"stop": {
|
||||||
|
@ -405,7 +526,8 @@
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"default": "null",
|
"default": "null",
|
||||||
"example": "null",
|
"example": "null",
|
||||||
"nullable": true
|
"nullable": true,
|
||||||
|
"minimum": 0.0
|
||||||
},
|
},
|
||||||
"typical_p": {
|
"typical_p": {
|
||||||
"type": "number",
|
"type": "number",
|
||||||
|
@ -445,7 +567,12 @@
|
||||||
],
|
],
|
||||||
"properties": {
|
"properties": {
|
||||||
"details": {
|
"details": {
|
||||||
|
"allOf": [
|
||||||
|
{
|
||||||
"$ref": "#/components/schemas/Details"
|
"$ref": "#/components/schemas/Details"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"nullable": true
|
||||||
},
|
},
|
||||||
"generated_text": {
|
"generated_text": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
|
@ -453,6 +580,38 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"Info": {
|
||||||
|
"type": "object",
|
||||||
|
"required": [
|
||||||
|
"model_id",
|
||||||
|
"version"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"model_id": {
|
||||||
|
"type": "string",
|
||||||
|
"example": "bigscience/blomm-560m"
|
||||||
|
},
|
||||||
|
"model_pipeline_tag": {
|
||||||
|
"type": "string",
|
||||||
|
"example": "text-generation",
|
||||||
|
"nullable": true
|
||||||
|
},
|
||||||
|
"model_sha": {
|
||||||
|
"type": "string",
|
||||||
|
"example": "e985a63cdc139290c5f700ff1929f0b5942cced2",
|
||||||
|
"nullable": true
|
||||||
|
},
|
||||||
|
"sha": {
|
||||||
|
"type": "string",
|
||||||
|
"example": "null",
|
||||||
|
"nullable": true
|
||||||
|
},
|
||||||
|
"version": {
|
||||||
|
"type": "string",
|
||||||
|
"example": "0.5.0"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"PrefillToken": {
|
"PrefillToken": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": [
|
"required": [
|
||||||
|
@ -464,7 +623,8 @@
|
||||||
"id": {
|
"id": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"format": "int32",
|
"format": "int32",
|
||||||
"example": 0
|
"example": 0,
|
||||||
|
"minimum": 0.0
|
||||||
},
|
},
|
||||||
"logprob": {
|
"logprob": {
|
||||||
"type": "number",
|
"type": "number",
|
||||||
|
@ -491,13 +651,15 @@
|
||||||
"generated_tokens": {
|
"generated_tokens": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"format": "int32",
|
"format": "int32",
|
||||||
"example": 1
|
"example": 1,
|
||||||
|
"minimum": 0.0
|
||||||
},
|
},
|
||||||
"seed": {
|
"seed": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"format": "int64",
|
"format": "int64",
|
||||||
"example": 42,
|
"example": 42,
|
||||||
"nullable": true
|
"nullable": true,
|
||||||
|
"minimum": 0.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -508,7 +670,12 @@
|
||||||
],
|
],
|
||||||
"properties": {
|
"properties": {
|
||||||
"details": {
|
"details": {
|
||||||
|
"allOf": [
|
||||||
|
{
|
||||||
"$ref": "#/components/schemas/StreamDetails"
|
"$ref": "#/components/schemas/StreamDetails"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"nullable": true
|
||||||
},
|
},
|
||||||
"generated_text": {
|
"generated_text": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
|
@ -533,7 +700,8 @@
|
||||||
"id": {
|
"id": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"format": "int32",
|
"format": "int32",
|
||||||
"example": 0
|
"example": 0,
|
||||||
|
"minimum": 0.0
|
||||||
},
|
},
|
||||||
"logprob": {
|
"logprob": {
|
||||||
"type": "number",
|
"type": "number",
|
||||||
|
|
|
@ -392,6 +392,12 @@ fn main() -> ExitCode {
|
||||||
model_id,
|
model_id,
|
||||||
];
|
];
|
||||||
|
|
||||||
|
// Model optional revision
|
||||||
|
if let Some(ref revision) = revision {
|
||||||
|
argv.push("--revision".to_string());
|
||||||
|
argv.push(revision.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
if json_output {
|
if json_output {
|
||||||
argv.push("--json-output".to_string());
|
argv.push("--json-output".to_string());
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ version = "0.5.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
authors = ["Olivier Dehaene"]
|
authors = ["Olivier Dehaene"]
|
||||||
description = "Text Generation Webserver"
|
description = "Text Generation Webserver"
|
||||||
|
build = "build.rs"
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
path = "src/lib.rs"
|
path = "src/lib.rs"
|
||||||
|
@ -39,3 +40,5 @@ tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
|
||||||
utoipa = { version = "3.0.1", features = ["axum_extras"] }
|
utoipa = { version = "3.0.1", features = ["axum_extras"] }
|
||||||
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
|
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
|
||||||
|
|
||||||
|
[build-dependencies]
|
||||||
|
vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
use std::error::Error;
|
||||||
|
use vergen::EmitBuilder;
|
||||||
|
|
||||||
|
fn main() -> Result<(), Box<dyn Error>> {
|
||||||
|
EmitBuilder::builder().git_sha(false).emit()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
|
@ -10,6 +10,29 @@ use serde::{Deserialize, Serialize};
|
||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
use validation::Validation;
|
use validation::Validation;
|
||||||
|
|
||||||
|
/// Hub type
|
||||||
|
#[derive(Clone, Debug, Deserialize)]
|
||||||
|
pub struct ModelInfo {
|
||||||
|
#[serde(rename(deserialize = "id"))]
|
||||||
|
pub model_id: String,
|
||||||
|
pub sha: Option<String>,
|
||||||
|
pub pipeline_tag: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||||
|
pub struct Info {
|
||||||
|
#[schema(example = "bigscience/blomm-560m")]
|
||||||
|
pub model_id: String,
|
||||||
|
#[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
|
||||||
|
pub model_sha: Option<String>,
|
||||||
|
#[schema(nullable = true, example = "text-generation")]
|
||||||
|
pub model_pipeline_tag: Option<String>,
|
||||||
|
#[schema(example = "0.5.0")]
|
||||||
|
pub version: &'static str,
|
||||||
|
#[schema(nullable = true, example = "null")]
|
||||||
|
pub sha: Option<&'static str>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
||||||
pub(crate) struct GenerateParameters {
|
pub(crate) struct GenerateParameters {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
|
|
@ -10,8 +10,8 @@ use opentelemetry_otlp::WithExportConfig;
|
||||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use text_generation_client::ShardedClient;
|
use text_generation_client::ShardedClient;
|
||||||
use text_generation_router::server;
|
use text_generation_router::{server, ModelInfo};
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::{FromPretrainedParameters, Tokenizer};
|
||||||
use tower_http::cors::AllowOrigin;
|
use tower_http::cors::AllowOrigin;
|
||||||
use tracing_subscriber::layer::SubscriberExt;
|
use tracing_subscriber::layer::SubscriberExt;
|
||||||
use tracing_subscriber::util::SubscriberInitExt;
|
use tracing_subscriber::util::SubscriberInitExt;
|
||||||
|
@ -41,6 +41,8 @@ struct Args {
|
||||||
master_shard_uds_path: String,
|
master_shard_uds_path: String,
|
||||||
#[clap(default_value = "bigscience/bloom", long, env)]
|
#[clap(default_value = "bigscience/bloom", long, env)]
|
||||||
tokenizer_name: String,
|
tokenizer_name: String,
|
||||||
|
#[clap(default_value = "main", long, env)]
|
||||||
|
revision: String,
|
||||||
#[clap(default_value = "2", long, env)]
|
#[clap(default_value = "2", long, env)]
|
||||||
validation_workers: usize,
|
validation_workers: usize,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
|
@ -66,6 +68,7 @@ fn main() -> Result<(), std::io::Error> {
|
||||||
port,
|
port,
|
||||||
master_shard_uds_path,
|
master_shard_uds_path,
|
||||||
tokenizer_name,
|
tokenizer_name,
|
||||||
|
revision,
|
||||||
validation_workers,
|
validation_workers,
|
||||||
json_output,
|
json_output,
|
||||||
otlp_endpoint,
|
otlp_endpoint,
|
||||||
|
@ -90,15 +93,18 @@ fn main() -> Result<(), std::io::Error> {
|
||||||
// Tokenizer instance
|
// Tokenizer instance
|
||||||
// This will only be used to validate payloads
|
// This will only be used to validate payloads
|
||||||
let local_path = Path::new(&tokenizer_name);
|
let local_path = Path::new(&tokenizer_name);
|
||||||
let tokenizer =
|
let local_model = local_path.exists() && local_path.is_dir();
|
||||||
if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
|
let tokenizer = if local_model {
|
||||||
{
|
|
||||||
// Load local tokenizer
|
// Load local tokenizer
|
||||||
Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
|
Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
|
||||||
} else {
|
} else {
|
||||||
// Download and instantiate tokenizer
|
// Download and instantiate tokenizer
|
||||||
// We need to download it outside of the Tokio runtime
|
// We need to download it outside of the Tokio runtime
|
||||||
Tokenizer::from_pretrained(tokenizer_name.clone(), None).ok()
|
let params = FromPretrainedParameters {
|
||||||
|
revision: revision.clone(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
|
||||||
};
|
};
|
||||||
|
|
||||||
// Launch Tokio runtime
|
// Launch Tokio runtime
|
||||||
|
@ -116,25 +122,23 @@ fn main() -> Result<(), std::io::Error> {
|
||||||
tracing::warn!("Rust input length validation and truncation is disabled");
|
tracing::warn!("Rust input length validation and truncation is disabled");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get pipeline tag
|
// Get Model info
|
||||||
let model_info = reqwest::get(format!(
|
let model_info = match local_model {
|
||||||
"https://huggingface.co/api/models/{tokenizer_name}"
|
true => ModelInfo {
|
||||||
))
|
model_id: tokenizer_name.clone(),
|
||||||
.await
|
sha: None,
|
||||||
.expect("Could not connect to hf.co")
|
pipeline_tag: None,
|
||||||
.text()
|
},
|
||||||
.await
|
false => get_model_info(&tokenizer_name, &revision).await,
|
||||||
.expect("error when retrieving model info from hf.co");
|
};
|
||||||
let model_info: serde_json::Value =
|
|
||||||
serde_json::from_str(&model_info).expect("unable to parse model info");
|
|
||||||
|
|
||||||
// if pipeline-tag == text-generation we default to return_full_text = true
|
// if pipeline-tag == text-generation we default to return_full_text = true
|
||||||
let compat_return_full_text = match model_info.get("pipeline_tag") {
|
let compat_return_full_text = match &model_info.pipeline_tag {
|
||||||
None => {
|
None => {
|
||||||
tracing::warn!("no pipeline tag found for model {tokenizer_name}");
|
tracing::warn!("no pipeline tag found for model {tokenizer_name}");
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
Some(pipeline_tag) => pipeline_tag.as_str() == Some("text-generation"),
|
Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
|
||||||
};
|
};
|
||||||
|
|
||||||
// Instantiate sharded client from the master unix socket
|
// Instantiate sharded client from the master unix socket
|
||||||
|
@ -153,6 +157,7 @@ fn main() -> Result<(), std::io::Error> {
|
||||||
|
|
||||||
// Run server
|
// Run server
|
||||||
server::run(
|
server::run(
|
||||||
|
model_info,
|
||||||
compat_return_full_text,
|
compat_return_full_text,
|
||||||
max_concurrent_requests,
|
max_concurrent_requests,
|
||||||
max_best_of,
|
max_best_of,
|
||||||
|
@ -226,3 +231,16 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
|
||||||
.with(layers)
|
.with(layers)
|
||||||
.init();
|
.init();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// get model info from the Huggingface Hub
|
||||||
|
pub async fn get_model_info(model_id: &str, revision: &str) -> ModelInfo {
|
||||||
|
let model_info = reqwest::get(format!(
|
||||||
|
"https://huggingface.co/api/models/{model_id}/revision/{revision}"
|
||||||
|
))
|
||||||
|
.await
|
||||||
|
.expect("Could not connect to hf.co")
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.expect("error when retrieving model info from hf.co");
|
||||||
|
serde_json::from_str(&model_info).expect("unable to parse model info")
|
||||||
|
}
|
||||||
|
|
|
@ -3,8 +3,8 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
|
||||||
use crate::validation::ValidationError;
|
use crate::validation::ValidationError;
|
||||||
use crate::{
|
use crate::{
|
||||||
BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
|
BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
|
||||||
GenerateParameters, GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails,
|
GenerateParameters, GenerateRequest, GenerateResponse, Infer, Info, ModelInfo, PrefillToken,
|
||||||
StreamResponse, Token, Validation,
|
StreamDetails, StreamResponse, Token, Validation,
|
||||||
};
|
};
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::{HeaderMap, Method, StatusCode};
|
use axum::http::{HeaderMap, Method, StatusCode};
|
||||||
|
@ -27,7 +27,24 @@ use tracing::{info_span, instrument, Instrument};
|
||||||
use utoipa::OpenApi;
|
use utoipa::OpenApi;
|
||||||
use utoipa_swagger_ui::SwaggerUi;
|
use utoipa_swagger_ui::SwaggerUi;
|
||||||
|
|
||||||
/// Compatibility route with api-inference and AzureML
|
/// Generate tokens if `stream == false` or a stream of token if `stream == true`
|
||||||
|
#[utoipa::path(
|
||||||
|
post,
|
||||||
|
tag = "Text Generation Inference",
|
||||||
|
path = "/",
|
||||||
|
request_body = CompatGenerateRequest,
|
||||||
|
responses(
|
||||||
|
(status = 200, description = "See /generate or /generate_stream"),
|
||||||
|
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||||
|
example = json ! ({"error": "Request failed during generation"})),
|
||||||
|
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
||||||
|
example = json ! ({"error": "Model is overloaded"})),
|
||||||
|
(status = 422, description = "Input validation error", body = ErrorResponse,
|
||||||
|
example = json ! ({"error": "Input validation error"})),
|
||||||
|
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
||||||
|
example = json ! ({"error": "Incomplete generation"})),
|
||||||
|
)
|
||||||
|
)]
|
||||||
#[instrument(skip(infer))]
|
#[instrument(skip(infer))]
|
||||||
async fn compat_generate(
|
async fn compat_generate(
|
||||||
default_return_full_text: Extension<bool>,
|
default_return_full_text: Extension<bool>,
|
||||||
|
@ -53,6 +70,26 @@ async fn compat_generate(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Text Generation Inference endpoint info
|
||||||
|
#[utoipa::path(
|
||||||
|
get,
|
||||||
|
tag = "Text Generation Inference",
|
||||||
|
path = "/info",
|
||||||
|
responses((status = 200, description = "Served model info", body = Info))
|
||||||
|
)]
|
||||||
|
#[instrument]
|
||||||
|
async fn get_model_info(model_info: Extension<ModelInfo>) -> Json<Info> {
|
||||||
|
let model_info = model_info.0;
|
||||||
|
let info = Info {
|
||||||
|
version: env!("CARGO_PKG_VERSION"),
|
||||||
|
sha: option_env!("VERGEN_GIT_SHA"),
|
||||||
|
model_id: model_info.model_id,
|
||||||
|
model_sha: model_info.sha,
|
||||||
|
model_pipeline_tag: model_info.pipeline_tag,
|
||||||
|
};
|
||||||
|
Json(info)
|
||||||
|
}
|
||||||
|
|
||||||
/// Health check method
|
/// Health check method
|
||||||
#[instrument(skip(infer))]
|
#[instrument(skip(infer))]
|
||||||
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
||||||
|
@ -87,21 +124,21 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
|
||||||
|
|
||||||
/// Generate tokens
|
/// Generate tokens
|
||||||
#[utoipa::path(
|
#[utoipa::path(
|
||||||
post,
|
post,
|
||||||
tag = "Text Generation Inference",
|
tag = "Text Generation Inference",
|
||||||
path = "/generate",
|
path = "/generate",
|
||||||
request_body = GenerateRequest,
|
request_body = GenerateRequest,
|
||||||
responses(
|
responses(
|
||||||
(status = 200, description = "Generated Text", body = GenerateResponse),
|
(status = 200, description = "Generated Text", body = GenerateResponse),
|
||||||
(status = 424, description = "Generation Error", body = ErrorResponse,
|
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||||
example = json ! ({"error": "Request failed during generation"})),
|
example = json ! ({"error": "Request failed during generation"})),
|
||||||
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
||||||
example = json ! ({"error": "Model is overloaded"})),
|
example = json ! ({"error": "Model is overloaded"})),
|
||||||
(status = 422, description = "Input validation error", body = ErrorResponse,
|
(status = 422, description = "Input validation error", body = ErrorResponse,
|
||||||
example = json ! ({"error": "Input validation error"})),
|
example = json ! ({"error": "Input validation error"})),
|
||||||
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
||||||
example = json ! ({"error": "Incomplete generation"})),
|
example = json ! ({"error": "Incomplete generation"})),
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
#[instrument(
|
#[instrument(
|
||||||
skip(infer),
|
skip(infer),
|
||||||
|
@ -264,26 +301,26 @@ async fn generate(
|
||||||
|
|
||||||
/// Generate a stream of token using Server-Sent Events
|
/// Generate a stream of token using Server-Sent Events
|
||||||
#[utoipa::path(
|
#[utoipa::path(
|
||||||
post,
|
post,
|
||||||
tag = "Text Generation Inference",
|
tag = "Text Generation Inference",
|
||||||
path = "/generate_stream",
|
path = "/generate_stream",
|
||||||
request_body = GenerateRequest,
|
request_body = GenerateRequest,
|
||||||
responses(
|
responses(
|
||||||
(status = 200, description = "Generated Text", body = StreamResponse,
|
(status = 200, description = "Generated Text", body = StreamResponse,
|
||||||
content_type = "text/event-stream"),
|
content_type = "text/event-stream"),
|
||||||
(status = 424, description = "Generation Error", body = ErrorResponse,
|
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||||
example = json ! ({"error": "Request failed during generation"}),
|
example = json ! ({"error": "Request failed during generation"}),
|
||||||
content_type = "text/event-stream"),
|
content_type = "text/event-stream"),
|
||||||
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
||||||
example = json ! ({"error": "Model is overloaded"}),
|
example = json ! ({"error": "Model is overloaded"}),
|
||||||
content_type = "text/event-stream"),
|
content_type = "text/event-stream"),
|
||||||
(status = 422, description = "Input validation error", body = ErrorResponse,
|
(status = 422, description = "Input validation error", body = ErrorResponse,
|
||||||
example = json ! ({"error": "Input validation error"}),
|
example = json ! ({"error": "Input validation error"}),
|
||||||
content_type = "text/event-stream"),
|
content_type = "text/event-stream"),
|
||||||
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
||||||
example = json ! ({"error": "Incomplete generation"}),
|
example = json ! ({"error": "Incomplete generation"}),
|
||||||
content_type = "text/event-stream"),
|
content_type = "text/event-stream"),
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
#[instrument(
|
#[instrument(
|
||||||
skip(infer),
|
skip(infer),
|
||||||
|
@ -447,10 +484,10 @@ async fn generate_stream(
|
||||||
|
|
||||||
/// Prometheus metrics scrape endpoint
|
/// Prometheus metrics scrape endpoint
|
||||||
#[utoipa::path(
|
#[utoipa::path(
|
||||||
get,
|
get,
|
||||||
tag = "Text Generation Inference",
|
tag = "Text Generation Inference",
|
||||||
path = "/metrics",
|
path = "/metrics",
|
||||||
responses((status = 200, description = "Prometheus Metrics", body = String))
|
responses((status = 200, description = "Prometheus Metrics", body = String))
|
||||||
)]
|
)]
|
||||||
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
|
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
|
||||||
prom_handle.render()
|
prom_handle.render()
|
||||||
|
@ -459,6 +496,7 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
|
||||||
/// Serving method
|
/// Serving method
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub async fn run(
|
pub async fn run(
|
||||||
|
model_info: ModelInfo,
|
||||||
compat_return_full_text: bool,
|
compat_return_full_text: bool,
|
||||||
max_concurrent_requests: usize,
|
max_concurrent_requests: usize,
|
||||||
max_best_of: usize,
|
max_best_of: usize,
|
||||||
|
@ -477,12 +515,16 @@ pub async fn run(
|
||||||
#[derive(OpenApi)]
|
#[derive(OpenApi)]
|
||||||
#[openapi(
|
#[openapi(
|
||||||
paths(
|
paths(
|
||||||
|
get_model_info,
|
||||||
|
compat_generate,
|
||||||
generate,
|
generate,
|
||||||
generate_stream,
|
generate_stream,
|
||||||
metrics,
|
metrics,
|
||||||
),
|
),
|
||||||
components(
|
components(
|
||||||
schemas(
|
schemas(
|
||||||
|
Info,
|
||||||
|
CompatGenerateRequest,
|
||||||
GenerateRequest,
|
GenerateRequest,
|
||||||
GenerateParameters,
|
GenerateParameters,
|
||||||
PrefillToken,
|
PrefillToken,
|
||||||
|
@ -584,6 +626,7 @@ pub async fn run(
|
||||||
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
|
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
|
||||||
// Base routes
|
// Base routes
|
||||||
.route("/", post(compat_generate))
|
.route("/", post(compat_generate))
|
||||||
|
.route("/info", get(get_model_info))
|
||||||
.route("/generate", post(generate))
|
.route("/generate", post(generate))
|
||||||
.route("/generate_stream", post(generate_stream))
|
.route("/generate_stream", post(generate_stream))
|
||||||
// AWS Sagemaker route
|
// AWS Sagemaker route
|
||||||
|
@ -596,6 +639,7 @@ pub async fn run(
|
||||||
.route("/ping", get(health))
|
.route("/ping", get(health))
|
||||||
// Prometheus metrics route
|
// Prometheus metrics route
|
||||||
.route("/metrics", get(metrics))
|
.route("/metrics", get(metrics))
|
||||||
|
.layer(Extension(model_info))
|
||||||
.layer(Extension(compat_return_full_text))
|
.layer(Extension(compat_return_full_text))
|
||||||
.layer(Extension(infer))
|
.layer(Extension(infer))
|
||||||
.layer(Extension(prom_handle))
|
.layer(Extension(prom_handle))
|
||||||
|
|
Loading…
Reference in New Issue