feat: allow any supported payload on /invocations (#2683)
* feat: allow any supported payload on /invocations
* update OpenAPI
* update doc
This commit is contained in:
parent 27ff1871b5
commit 41c2623735
@@ -98,7 +98,7 @@ curl 127.0.0.1:8080/generate_stream \
 You can also use [TGI's Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) to obtain Open AI Chat Completion API compatible responses.

 ```bash
-curl localhost:3000/v1/chat/completions \
+curl localhost:8080/v1/chat/completions \
     -X POST \
     -d '{
     "model": "tgi",
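The hunk above only fixes the port used in the curl example (3000 → 8080). A Python version of the same Messages API call is sketched below; it assumes a TGI instance listening locally on port 8080, and the `messages` body is illustrative rather than taken from the diff.

```python
import requests

# Python equivalent of the curl call above; the host/port and the example
# "messages" body are illustrative. Only "model": "tgi" and the endpoint
# path come from the documentation being edited.
response = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "tgi",
        "messages": [{"role": "user", "content": "What is deep learning?"}],
        "stream": False,
    },
    timeout=60,
)
print(response.json())
```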
@@ -3,7 +3,7 @@ use std::collections::HashMap;
 use std::path::PathBuf;
 use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
 use text_generation_backends_trtllm::TensorRtLlmBackend;
-use text_generation_router::server;
+use text_generation_router::{server, usage_stats};
 use tokenizers::{FromPretrainedParameters, Tokenizer};

 /// App Configuration
@@ -48,14 +48,14 @@ struct Args {
     otlp_service_name: String,
     #[clap(long, env)]
     cors_allow_origin: Option<Vec<String>>,
-    #[clap(long, env, default_value_t = false)]
-    messages_api_enabled: bool,
     #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
     #[clap(long, env)]
     auth_token: Option<String>,
     #[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")]
     executor_worker: PathBuf,
+    #[clap(default_value = "on", long, env)]
+    usage_stats: usage_stats::UsageStatsLevel,
 }

 #[tokio::main]
@@ -83,10 +83,10 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         otlp_endpoint,
         otlp_service_name,
         cors_allow_origin,
-        messages_api_enabled,
         max_client_batch_size,
         auth_token,
         executor_worker,
+        usage_stats,
     } = args;

     // Launch Tokio runtime
@@ -155,11 +155,9 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         false,
         None,
         None,
-        messages_api_enabled,
         true,
         max_client_batch_size,
-        false,
-        false,
+        usage_stats,
     )
     .await?;
     Ok(())
@@ -63,8 +63,6 @@ struct Args {
     #[clap(long, env)]
     ngrok_edge: Option<String>,
     #[clap(long, env, default_value_t = false)]
-    messages_api_enabled: bool,
-    #[clap(long, env, default_value_t = false)]
     disable_grammar_support: bool,
     #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
@@ -110,7 +108,6 @@ async fn main() -> Result<(), RouterError> {
         ngrok,
         ngrok_authtoken,
         ngrok_edge,
-        messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
         usage_stats,
@@ -190,7 +187,6 @@ async fn main() -> Result<(), RouterError> {
         ngrok,
         ngrok_authtoken,
         ngrok_edge,
-        messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
         usage_stats,
@@ -63,8 +63,6 @@ struct Args {
     #[clap(long, env)]
     ngrok_edge: Option<String>,
     #[clap(long, env, default_value_t = false)]
-    messages_api_enabled: bool,
-    #[clap(long, env, default_value_t = false)]
     disable_grammar_support: bool,
     #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
@@ -110,7 +108,6 @@ async fn main() -> Result<(), RouterError> {
         ngrok,
         ngrok_authtoken,
         ngrok_edge,
-        messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
         usage_stats,
@@ -190,7 +187,6 @@ async fn main() -> Result<(), RouterError> {
         ngrok,
         ngrok_authtoken,
         ngrok_edge,
-        messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
         usage_stats,
@@ -316,6 +316,98 @@
         }
       }
     },
+    "/invocations": {
+      "post": {
+        "tags": [
+          "Text Generation Inference"
+        ],
+        "summary": "Generate tokens from Sagemaker request",
+        "operationId": "sagemaker_compatibility",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/SagemakerRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "description": "Generated Chat Completion",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/SagemakerResponse"
+                }
+              },
+              "text/event-stream": {
+                "schema": {
+                  "$ref": "#/components/schemas/SagemakerStreamResponse"
+                }
+              }
+            }
+          },
+          "422": {
+            "description": "Input validation error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Input validation error",
+                  "error_type": "validation"
+                }
+              }
+            }
+          },
+          "424": {
+            "description": "Generation Error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Request failed during generation",
+                  "error_type": "generation"
+                }
+              }
+            }
+          },
+          "429": {
+            "description": "Model is overloaded",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Model is overloaded",
+                  "error_type": "overloaded"
+                }
+              }
+            }
+          },
+          "500": {
+            "description": "Incomplete generation",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Incomplete generation",
+                  "error_type": "incomplete_generation"
+                }
+              }
+            }
+          }
+        }
+      }
+    },
     "/metrics": {
       "get": {
         "tags": [
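The new `/invocations` path accepts a single JSON body that may be any of the supported request schemas and answers with the matching response schema. A hedged client-side sketch, assuming a TGI container reachable on localhost:8080:

```python
import requests

# Chat-style body: one of the three shapes the SagemakerRequest schema accepts.
payload = {
    "model": "tgi",
    "messages": [{"role": "user", "content": "Say hello"}],
    "max_tokens": 32,  # illustrative parameter, not taken from the diff
}

resp = requests.post("http://localhost:8080/invocations", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json())  # ChatCompletion-shaped per the SagemakerResponse oneOf
```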
@@ -1865,6 +1957,45 @@
           "type": "string"
         }
       },
+      "SagemakerRequest": {
+        "oneOf": [
+          {
+            "$ref": "#/components/schemas/CompatGenerateRequest"
+          },
+          {
+            "$ref": "#/components/schemas/ChatRequest"
+          },
+          {
+            "$ref": "#/components/schemas/CompletionRequest"
+          }
+        ]
+      },
+      "SagemakerResponse": {
+        "oneOf": [
+          {
+            "$ref": "#/components/schemas/GenerateResponse"
+          },
+          {
+            "$ref": "#/components/schemas/ChatCompletion"
+          },
+          {
+            "$ref": "#/components/schemas/CompletionFinal"
+          }
+        ]
+      },
+      "SagemakerStreamResponse": {
+        "oneOf": [
+          {
+            "$ref": "#/components/schemas/StreamResponse"
+          },
+          {
+            "$ref": "#/components/schemas/ChatCompletionChunk"
+          },
+          {
+            "$ref": "#/components/schemas/Chunk"
+          }
+        ]
+      },
       "SimpleToken": {
         "type": "object",
         "required": [
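The three `oneOf` schemas mirror the three request styles the route now understands. Illustrative payloads for each variant are sketched below; the field names beyond the overall shape are assumptions based on the referenced schemas, not part of this diff.

```python
# Illustrative bodies for the three SagemakerRequest variants.

# CompatGenerateRequest: the classic TGI generate payload
generate_body = {
    "inputs": "What is deep learning?",
    "parameters": {"max_new_tokens": 20},
}

# ChatRequest: an OpenAI-style chat completion payload
chat_body = {
    "model": "tgi",
    "messages": [{"role": "user", "content": "What is deep learning?"}],
}

# CompletionRequest: an OpenAI-style text completion payload
completion_body = {
    "model": "tgi",
    "prompt": "What is deep learning?",
    "max_tokens": 20,
}
```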
@@ -141,9 +141,7 @@ TGI can be deployed on various cloud providers for scalable and robust text gene

 ## Amazon SageMaker

-To enable the Messages API in Amazon SageMaker you need to set the environment variable `MESSAGES_API_ENABLED=true`.
+Amazon Sagemaker natively supports the message API:

-This will modify the `/invocations` route to accept Messages dictonaries consisting out of role and content. See the example below on how to deploy Llama with the new Messages API.
-
 ```python
 import json
@@ -161,12 +159,11 @@ except ValueError:
 hub = {
     'HF_MODEL_ID':'HuggingFaceH4/zephyr-7b-beta',
     'SM_NUM_GPUS': json.dumps(1),
-    'MESSAGES_API_ENABLED': True
 }

 # create Hugging Face Model Class
 huggingface_model = HuggingFaceModel(
-    image_uri=get_huggingface_llm_image_uri("huggingface",version="1.4.0"),
+    image_uri=get_huggingface_llm_image_uri("huggingface",version="2.3.2"),
     env=hub,
     role=role,
 )
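With the updated image, the SageMaker `/invocations` route accepts any supported payload without the `MESSAGES_API_ENABLED` flag. A hedged continuation of the example above, assuming `huggingface_model.deploy(...)` returns a predictor; the instance settings are illustrative.

```python
# Assumes the `hub`, `role` and `huggingface_model` objects defined in the
# example above; instance count and type are illustrative choices.
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",
)

# /invocations now accepts a chat-style payload directly.
response = predictor.predict(
    {"messages": [{"role": "user", "content": "What is deep learning?"}]}
)
print(response)
```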
@@ -26,7 +26,6 @@ As of release 2.1.2 this is an example of the data collected:
   "max_top_n_tokens": 5,
   "max_total_tokens": 2048,
   "max_waiting_tokens": 20,
-  "messages_api_enabled": false,
   "model_config": {
     "model_type": "Bloom"
   },
@@ -8,6 +8,7 @@ pub mod validation;
 mod kserve;
 pub mod logging;

+mod sagemaker;
 pub mod usage_stats;
 mod vertex;

@@ -1,748 +0,0 @@
|
||||||
use axum::http::HeaderValue;
|
|
||||||
use clap::Parser;
|
|
||||||
use clap::Subcommand;
|
|
||||||
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
|
|
||||||
use hf_hub::{Cache, Repo, RepoType};
|
|
||||||
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
|
||||||
use opentelemetry::sdk::trace;
|
|
||||||
use opentelemetry::sdk::trace::Sampler;
|
|
||||||
use opentelemetry::sdk::Resource;
|
|
||||||
use opentelemetry::{global, KeyValue};
|
|
||||||
use opentelemetry_otlp::WithExportConfig;
|
|
||||||
use std::fs::File;
|
|
||||||
use std::io::BufReader;
|
|
||||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
|
||||||
use std::path::{Path, PathBuf};
|
|
||||||
use text_generation_router::config::Config;
|
|
||||||
use text_generation_router::usage_stats;
|
|
||||||
use text_generation_router::{
|
|
||||||
server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig,
|
|
||||||
};
|
|
||||||
use thiserror::Error;
|
|
||||||
use tokenizers::{processors::template::TemplateProcessing, Tokenizer};
|
|
||||||
use tower_http::cors::AllowOrigin;
|
|
||||||
use tracing_subscriber::layer::SubscriberExt;
|
|
||||||
use tracing_subscriber::util::SubscriberInitExt;
|
|
||||||
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
|
|
||||||
|
|
||||||
/// App Configuration
|
|
||||||
#[derive(Parser, Debug)]
|
|
||||||
#[clap(author, version, about, long_about = None)]
|
|
||||||
struct Args {
|
|
||||||
#[command(subcommand)]
|
|
||||||
command: Option<Commands>,
|
|
||||||
|
|
||||||
#[clap(default_value = "128", long, env)]
|
|
||||||
max_concurrent_requests: usize,
|
|
||||||
#[clap(default_value = "2", long, env)]
|
|
||||||
max_best_of: usize,
|
|
||||||
#[clap(default_value = "4", long, env)]
|
|
||||||
max_stop_sequences: usize,
|
|
||||||
#[clap(default_value = "5", long, env)]
|
|
||||||
max_top_n_tokens: u32,
|
|
||||||
#[clap(default_value = "1024", long, env)]
|
|
||||||
max_input_tokens: usize,
|
|
||||||
#[clap(default_value = "2048", long, env)]
|
|
||||||
max_total_tokens: usize,
|
|
||||||
#[clap(default_value = "1.2", long, env)]
|
|
||||||
waiting_served_ratio: f32,
|
|
||||||
#[clap(default_value = "4096", long, env)]
|
|
||||||
max_batch_prefill_tokens: u32,
|
|
||||||
#[clap(long, env)]
|
|
||||||
max_batch_total_tokens: Option<u32>,
|
|
||||||
#[clap(default_value = "20", long, env)]
|
|
||||||
max_waiting_tokens: usize,
|
|
||||||
#[clap(long, env)]
|
|
||||||
max_batch_size: Option<usize>,
|
|
||||||
#[clap(default_value = "0.0.0.0", long, env)]
|
|
||||||
hostname: String,
|
|
||||||
#[clap(default_value = "3000", long, short, env)]
|
|
||||||
port: u16,
|
|
||||||
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
|
|
||||||
master_shard_uds_path: String,
|
|
||||||
#[clap(default_value = "bigscience/bloom", long, env)]
|
|
||||||
tokenizer_name: String,
|
|
||||||
#[clap(long, env)]
|
|
||||||
tokenizer_config_path: Option<String>,
|
|
||||||
#[clap(long, env)]
|
|
||||||
revision: Option<String>,
|
|
||||||
#[clap(default_value = "2", long, env)]
|
|
||||||
validation_workers: usize,
|
|
||||||
#[clap(long, env)]
|
|
||||||
json_output: bool,
|
|
||||||
#[clap(long, env)]
|
|
||||||
otlp_endpoint: Option<String>,
|
|
||||||
#[clap(default_value = "text-generation-inference.router", long, env)]
|
|
||||||
otlp_service_name: String,
|
|
||||||
#[clap(long, env)]
|
|
||||||
cors_allow_origin: Option<Vec<String>>,
|
|
||||||
#[clap(long, env)]
|
|
||||||
api_key: Option<String>,
|
|
||||||
#[clap(long, env)]
|
|
||||||
ngrok: bool,
|
|
||||||
#[clap(long, env)]
|
|
||||||
ngrok_authtoken: Option<String>,
|
|
||||||
#[clap(long, env)]
|
|
||||||
ngrok_edge: Option<String>,
|
|
||||||
#[clap(long, env, default_value_t = false)]
|
|
||||||
messages_api_enabled: bool,
|
|
||||||
#[clap(long, env, default_value_t = false)]
|
|
||||||
disable_grammar_support: bool,
|
|
||||||
#[clap(default_value = "4", long, env)]
|
|
||||||
max_client_batch_size: usize,
|
|
||||||
#[clap(long, env, default_value_t)]
|
|
||||||
disable_usage_stats: bool,
|
|
||||||
#[clap(long, env, default_value_t)]
|
|
||||||
disable_crash_reports: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Subcommand)]
|
|
||||||
enum Commands {
|
|
||||||
PrintSchema,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::main]
|
|
||||||
async fn main() -> Result<(), RouterError> {
|
|
||||||
let args = Args::parse();
|
|
||||||
|
|
||||||
// Pattern match configuration
|
|
||||||
let Args {
|
|
||||||
max_concurrent_requests,
|
|
||||||
max_best_of,
|
|
||||||
max_stop_sequences,
|
|
||||||
max_top_n_tokens,
|
|
||||||
max_input_tokens,
|
|
||||||
max_total_tokens,
|
|
||||||
waiting_served_ratio,
|
|
||||||
max_batch_prefill_tokens,
|
|
||||||
max_batch_total_tokens,
|
|
||||||
max_waiting_tokens,
|
|
||||||
max_batch_size,
|
|
||||||
hostname,
|
|
||||||
port,
|
|
||||||
master_shard_uds_path,
|
|
||||||
tokenizer_name,
|
|
||||||
tokenizer_config_path,
|
|
||||||
revision,
|
|
||||||
validation_workers,
|
|
||||||
json_output,
|
|
||||||
otlp_endpoint,
|
|
||||||
otlp_service_name,
|
|
||||||
cors_allow_origin,
|
|
||||||
api_key,
|
|
||||||
ngrok,
|
|
||||||
ngrok_authtoken,
|
|
||||||
ngrok_edge,
|
|
||||||
messages_api_enabled,
|
|
||||||
disable_grammar_support,
|
|
||||||
max_client_batch_size,
|
|
||||||
disable_usage_stats,
|
|
||||||
disable_crash_reports,
|
|
||||||
command,
|
|
||||||
} = args;
|
|
||||||
|
|
||||||
let print_schema_command = match command {
|
|
||||||
Some(Commands::PrintSchema) => true,
|
|
||||||
None => {
|
|
||||||
// only init logging if we are not running the print schema command
|
|
||||||
init_logging(otlp_endpoint, otlp_service_name, json_output);
|
|
||||||
false
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Validate args
|
|
||||||
if max_input_tokens >= max_total_tokens {
|
|
||||||
return Err(RouterError::ArgumentValidation(
|
|
||||||
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
|
||||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
|
||||||
}
|
|
||||||
|
|
||||||
if validation_workers == 0 {
|
|
||||||
return Err(RouterError::ArgumentValidation(
|
|
||||||
"`validation_workers` must be > 0".to_string(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
|
|
||||||
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
|
||||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
|
||||||
}
|
|
||||||
if max_total_tokens as u32 > *max_batch_total_tokens {
|
|
||||||
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CORS allowed origins
|
|
||||||
// map to go inside the option and then map to parse from String to HeaderValue
|
|
||||||
// Finally, convert to AllowOrigin
|
|
||||||
let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
|
|
||||||
AllowOrigin::list(
|
|
||||||
cors_allow_origin
|
|
||||||
.iter()
|
|
||||||
.map(|origin| origin.parse::<HeaderValue>().unwrap()),
|
|
||||||
)
|
|
||||||
});
|
|
||||||
|
|
||||||
// Parse Huggingface hub token
|
|
||||||
let authorization_token = std::env::var("HF_TOKEN")
|
|
||||||
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
|
|
||||||
.ok();
|
|
||||||
|
|
||||||
// Tokenizer instance
|
|
||||||
// This will only be used to validate payloads
|
|
||||||
let local_path = Path::new(&tokenizer_name);
|
|
||||||
|
|
||||||
// Shared API builder initialization
|
|
||||||
let api_builder = || {
|
|
||||||
let mut builder = ApiBuilder::new()
|
|
||||||
.with_progress(false)
|
|
||||||
.with_token(authorization_token);
|
|
||||||
|
|
||||||
if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
|
|
||||||
builder = builder.with_cache_dir(cache_dir.into());
|
|
||||||
}
|
|
||||||
|
|
||||||
builder
|
|
||||||
};
|
|
||||||
|
|
||||||
// Decide if we need to use the API based on the revision and local path
|
|
||||||
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
|
|
||||||
|
|
||||||
// Initialize API if needed
|
|
||||||
#[derive(Clone)]
|
|
||||||
enum Type {
|
|
||||||
Api(Api),
|
|
||||||
Cache(Cache),
|
|
||||||
None,
|
|
||||||
}
|
|
||||||
let api = if use_api {
|
|
||||||
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
|
|
||||||
let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
|
|
||||||
.map_err(|_| ())
|
|
||||||
.map(|cache_dir| Cache::new(cache_dir.into()))
|
|
||||||
.unwrap_or_else(|_| Cache::default());
|
|
||||||
|
|
||||||
tracing::warn!("Offline mode active using cache defaults");
|
|
||||||
Type::Cache(cache)
|
|
||||||
} else {
|
|
||||||
tracing::info!("Using the Hugging Face API");
|
|
||||||
match api_builder().build() {
|
|
||||||
Ok(api) => Type::Api(api),
|
|
||||||
Err(_) => {
|
|
||||||
tracing::warn!("Unable to build the Hugging Face API");
|
|
||||||
Type::None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Type::None
|
|
||||||
};
|
|
||||||
|
|
||||||
// Load tokenizer and model info
|
|
||||||
let (
|
|
||||||
tokenizer_filename,
|
|
||||||
config_filename,
|
|
||||||
tokenizer_config_filename,
|
|
||||||
preprocessor_config_filename,
|
|
||||||
processor_config_filename,
|
|
||||||
model_info,
|
|
||||||
) = match api {
|
|
||||||
Type::None => (
|
|
||||||
Some(local_path.join("tokenizer.json")),
|
|
||||||
Some(local_path.join("config.json")),
|
|
||||||
Some(local_path.join("tokenizer_config.json")),
|
|
||||||
Some(local_path.join("preprocessor_config.json")),
|
|
||||||
Some(local_path.join("processor_config.json")),
|
|
||||||
None,
|
|
||||||
),
|
|
||||||
Type::Api(api) => {
|
|
||||||
let api_repo = api.repo(Repo::with_revision(
|
|
||||||
tokenizer_name.to_string(),
|
|
||||||
RepoType::Model,
|
|
||||||
revision.clone().unwrap_or_else(|| "main".to_string()),
|
|
||||||
));
|
|
||||||
|
|
||||||
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
|
|
||||||
Ok(tokenizer_filename) => Some(tokenizer_filename),
|
|
||||||
Err(_) => get_base_tokenizer(&api, &api_repo).await,
|
|
||||||
};
|
|
||||||
let config_filename = api_repo.get("config.json").await.ok();
|
|
||||||
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
|
|
||||||
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
|
|
||||||
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
|
|
||||||
|
|
||||||
let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
|
|
||||||
Some(model_info)
|
|
||||||
} else {
|
|
||||||
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
|
||||||
None
|
|
||||||
};
|
|
||||||
(
|
|
||||||
tokenizer_filename,
|
|
||||||
config_filename,
|
|
||||||
tokenizer_config_filename,
|
|
||||||
preprocessor_config_filename,
|
|
||||||
processor_config_filename,
|
|
||||||
model_info,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
Type::Cache(cache) => {
|
|
||||||
let repo = cache.repo(Repo::with_revision(
|
|
||||||
tokenizer_name.to_string(),
|
|
||||||
RepoType::Model,
|
|
||||||
revision.clone().unwrap_or_else(|| "main".to_string()),
|
|
||||||
));
|
|
||||||
(
|
|
||||||
repo.get("tokenizer.json"),
|
|
||||||
repo.get("config.json"),
|
|
||||||
repo.get("tokenizer_config.json"),
|
|
||||||
repo.get("preprocessor_config.json"),
|
|
||||||
repo.get("processor_config.json"),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let config: Option<Config> = config_filename.and_then(|filename| {
|
|
||||||
std::fs::read_to_string(filename)
|
|
||||||
.ok()
|
|
||||||
.as_ref()
|
|
||||||
.and_then(|c| {
|
|
||||||
let config: Result<Config, _> = serde_json::from_str(c);
|
|
||||||
if let Err(err) = &config {
|
|
||||||
tracing::warn!("Could not parse config {err:?}");
|
|
||||||
}
|
|
||||||
config.ok()
|
|
||||||
})
|
|
||||||
});
|
|
||||||
let model_info = model_info.unwrap_or_else(|| HubModelInfo {
|
|
||||||
model_id: tokenizer_name.to_string(),
|
|
||||||
sha: None,
|
|
||||||
pipeline_tag: None,
|
|
||||||
});
|
|
||||||
|
|
||||||
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
|
||||||
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
|
|
||||||
{
|
|
||||||
HubTokenizerConfig::from_file(filename)
|
|
||||||
} else {
|
|
||||||
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
|
|
||||||
};
|
|
||||||
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
|
|
||||||
tracing::warn!("Could not find tokenizer config locally and no API specified");
|
|
||||||
HubTokenizerConfig::default()
|
|
||||||
});
|
|
||||||
let tokenizer_class = tokenizer_config.tokenizer_class.clone();
|
|
||||||
|
|
||||||
let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
|
|
||||||
let mut tokenizer = Tokenizer::from_file(filename).ok();
|
|
||||||
if let Some(tokenizer) = &mut tokenizer {
|
|
||||||
if let Some(class) = &tokenizer_config.tokenizer_class {
|
|
||||||
if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{
|
|
||||||
if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
|
|
||||||
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
|
|
||||||
tokenizer.with_post_processor(post_processor);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tokenizer
|
|
||||||
});
|
|
||||||
|
|
||||||
let preprocessor_config =
|
|
||||||
preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
|
|
||||||
let processor_config = processor_config_filename
|
|
||||||
.and_then(HubProcessorConfig::from_file)
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
tracing::info!("Using config {config:?}");
|
|
||||||
if tokenizer.is_none() {
|
|
||||||
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
|
|
||||||
tracing::warn!("Rust input length validation and truncation is disabled");
|
|
||||||
}
|
|
||||||
|
|
||||||
// if pipeline-tag == text-generation we default to return_full_text = true
|
|
||||||
let compat_return_full_text = match &model_info.pipeline_tag {
|
|
||||||
None => {
|
|
||||||
tracing::warn!("no pipeline tag found for model {tokenizer_name}");
|
|
||||||
true
|
|
||||||
}
|
|
||||||
Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
|
|
||||||
};
|
|
||||||
|
|
||||||
// Determine the server port based on the feature and environment variable.
|
|
||||||
let port = if cfg!(feature = "google") {
|
|
||||||
std::env::var("AIP_HTTP_PORT")
|
|
||||||
.map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
|
|
||||||
.unwrap_or(port)
|
|
||||||
} else {
|
|
||||||
port
|
|
||||||
};
|
|
||||||
|
|
||||||
let addr = match hostname.parse() {
|
|
||||||
Ok(ip) => SocketAddr::new(ip, port),
|
|
||||||
Err(_) => {
|
|
||||||
tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
|
|
||||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Only send usage stats when TGI is run in container and the function returns Some
|
|
||||||
let is_container = matches!(usage_stats::is_container(), Ok(true));
|
|
||||||
|
|
||||||
let user_agent = if !disable_usage_stats && is_container {
|
|
||||||
let reduced_args = usage_stats::Args::new(
|
|
||||||
config.clone(),
|
|
||||||
tokenizer_class,
|
|
||||||
max_concurrent_requests,
|
|
||||||
max_best_of,
|
|
||||||
max_stop_sequences,
|
|
||||||
max_top_n_tokens,
|
|
||||||
max_input_tokens,
|
|
||||||
max_total_tokens,
|
|
||||||
waiting_served_ratio,
|
|
||||||
max_batch_prefill_tokens,
|
|
||||||
max_batch_total_tokens,
|
|
||||||
max_waiting_tokens,
|
|
||||||
max_batch_size,
|
|
||||||
revision,
|
|
||||||
validation_workers,
|
|
||||||
messages_api_enabled,
|
|
||||||
disable_grammar_support,
|
|
||||||
max_client_batch_size,
|
|
||||||
disable_usage_stats,
|
|
||||||
disable_crash_reports,
|
|
||||||
);
|
|
||||||
Some(usage_stats::UserAgent::new(reduced_args))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(ref ua) = user_agent {
|
|
||||||
let start_event =
|
|
||||||
usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None);
|
|
||||||
tokio::spawn(async move {
|
|
||||||
start_event.send().await;
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
// Run server
|
|
||||||
let result = server::run(
|
|
||||||
master_shard_uds_path,
|
|
||||||
model_info,
|
|
||||||
compat_return_full_text,
|
|
||||||
max_concurrent_requests,
|
|
||||||
max_best_of,
|
|
||||||
max_stop_sequences,
|
|
||||||
max_top_n_tokens,
|
|
||||||
max_input_tokens,
|
|
||||||
max_total_tokens,
|
|
||||||
waiting_served_ratio,
|
|
||||||
max_batch_prefill_tokens,
|
|
||||||
max_batch_total_tokens,
|
|
||||||
max_waiting_tokens,
|
|
||||||
max_batch_size,
|
|
||||||
tokenizer,
|
|
||||||
config,
|
|
||||||
validation_workers,
|
|
||||||
addr,
|
|
||||||
cors_allow_origin,
|
|
||||||
api_key,
|
|
||||||
ngrok,
|
|
||||||
ngrok_authtoken,
|
|
||||||
ngrok_edge,
|
|
||||||
tokenizer_config,
|
|
||||||
preprocessor_config,
|
|
||||||
processor_config,
|
|
||||||
messages_api_enabled,
|
|
||||||
disable_grammar_support,
|
|
||||||
max_client_batch_size,
|
|
||||||
print_schema_command,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match result {
|
|
||||||
Ok(_) => {
|
|
||||||
if let Some(ref ua) = user_agent {
|
|
||||||
let stop_event = usage_stats::UsageStatsEvent::new(
|
|
||||||
ua.clone(),
|
|
||||||
usage_stats::EventType::Stop,
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
stop_event.send().await;
|
|
||||||
};
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
if let Some(ref ua) = user_agent {
|
|
||||||
if !disable_crash_reports {
|
|
||||||
let error_event = usage_stats::UsageStatsEvent::new(
|
|
||||||
ua.clone(),
|
|
||||||
usage_stats::EventType::Error,
|
|
||||||
Some(e.to_string()),
|
|
||||||
);
|
|
||||||
error_event.send().await;
|
|
||||||
} else {
|
|
||||||
let unknow_error_event = usage_stats::UsageStatsEvent::new(
|
|
||||||
ua.clone(),
|
|
||||||
usage_stats::EventType::Error,
|
|
||||||
Some("unknow_error".to_string()),
|
|
||||||
);
|
|
||||||
unknow_error_event.send().await;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
Err(RouterError::WebServer(e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
|
|
||||||
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
|
|
||||||
/// - otlp_service_name service name to appear in APM
|
|
||||||
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
|
|
||||||
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
|
|
||||||
/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms)
|
|
||||||
fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
|
|
||||||
let mut layers = Vec::new();
|
|
||||||
|
|
||||||
// STDOUT/STDERR layer
|
|
||||||
let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
|
|
||||||
let fmt_layer = tracing_subscriber::fmt::layer()
|
|
||||||
.with_file(true)
|
|
||||||
.with_ansi(ansi)
|
|
||||||
.with_line_number(true);
|
|
||||||
|
|
||||||
let fmt_layer = match json_output {
|
|
||||||
true => fmt_layer.json().flatten_event(true).boxed(),
|
|
||||||
false => fmt_layer.boxed(),
|
|
||||||
};
|
|
||||||
layers.push(fmt_layer);
|
|
||||||
|
|
||||||
// OpenTelemetry tracing layer
|
|
||||||
if let Some(otlp_endpoint) = otlp_endpoint {
|
|
||||||
global::set_text_map_propagator(TraceContextPropagator::new());
|
|
||||||
|
|
||||||
let tracer = opentelemetry_otlp::new_pipeline()
|
|
||||||
.tracing()
|
|
||||||
.with_exporter(
|
|
||||||
opentelemetry_otlp::new_exporter()
|
|
||||||
.tonic()
|
|
||||||
.with_endpoint(otlp_endpoint),
|
|
||||||
)
|
|
||||||
.with_trace_config(
|
|
||||||
trace::config()
|
|
||||||
.with_resource(Resource::new(vec![KeyValue::new(
|
|
||||||
"service.name",
|
|
||||||
otlp_service_name,
|
|
||||||
)]))
|
|
||||||
.with_sampler(Sampler::AlwaysOn),
|
|
||||||
)
|
|
||||||
.install_batch(opentelemetry::runtime::Tokio);
|
|
||||||
|
|
||||||
if let Ok(tracer) = tracer {
|
|
||||||
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
|
|
||||||
init_tracing_opentelemetry::init_propagator().unwrap();
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Filter events with LOG_LEVEL
|
|
||||||
let varname = "LOG_LEVEL";
|
|
||||||
let env_filter = if let Ok(log_level) = std::env::var(varname) {
|
|
||||||
// Override to avoid simple logs to be spammed with tokio level informations
|
|
||||||
let log_level = match &log_level[..] {
|
|
||||||
"warn" => "text_generation_launcher=warn,text_generation_router=warn",
|
|
||||||
"info" => "text_generation_launcher=info,text_generation_router=info",
|
|
||||||
"debug" => "text_generation_launcher=debug,text_generation_router=debug",
|
|
||||||
log_level => log_level,
|
|
||||||
};
|
|
||||||
EnvFilter::builder()
|
|
||||||
.with_default_directive(LevelFilter::INFO.into())
|
|
||||||
.parse_lossy(log_level)
|
|
||||||
} else {
|
|
||||||
EnvFilter::new("info")
|
|
||||||
};
|
|
||||||
|
|
||||||
tracing_subscriber::registry()
|
|
||||||
.with(env_filter)
|
|
||||||
.with(layers)
|
|
||||||
.init();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// get model info from the Huggingface Hub
|
|
||||||
pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
|
|
||||||
let response = api.info_request().send().await.ok()?;
|
|
||||||
|
|
||||||
if response.status().is_success() {
|
|
||||||
let hub_model_info: HubModelInfo =
|
|
||||||
serde_json::from_str(&response.text().await.ok()?).ok()?;
|
|
||||||
if let Some(sha) = &hub_model_info.sha {
|
|
||||||
tracing::info!(
|
|
||||||
"Serving revision {sha} of model {}",
|
|
||||||
hub_model_info.model_id
|
|
||||||
);
|
|
||||||
}
|
|
||||||
Some(hub_model_info)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// get base tokenizer
|
|
||||||
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
|
|
||||||
let config_filename = api_repo.get("config.json").await.ok()?;
|
|
||||||
|
|
||||||
// Open the file in read-only mode with buffer.
|
|
||||||
let file = File::open(config_filename).ok()?;
|
|
||||||
let reader = BufReader::new(file);
|
|
||||||
|
|
||||||
// Read the JSON contents of the file as an instance of `User`.
|
|
||||||
let config: serde_json::Value = serde_json::from_reader(reader).ok()?;
|
|
||||||
|
|
||||||
if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
|
|
||||||
let api_base_repo = api.repo(Repo::with_revision(
|
|
||||||
base_model_id.to_string(),
|
|
||||||
RepoType::Model,
|
|
||||||
"main".to_string(),
|
|
||||||
));
|
|
||||||
|
|
||||||
api_base_repo.get("tokenizer.json").await.ok()
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// get tokenizer_config from the Huggingface Hub
|
|
||||||
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
|
|
||||||
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;
|
|
||||||
|
|
||||||
// Open the file in read-only mode with buffer.
|
|
||||||
let file = File::open(tokenizer_config_filename).ok()?;
|
|
||||||
let reader = BufReader::new(file);
|
|
||||||
|
|
||||||
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
|
||||||
let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
|
|
||||||
.map_err(|e| {
|
|
||||||
tracing::warn!("Unable to parse tokenizer config: {}", e);
|
|
||||||
e
|
|
||||||
})
|
|
||||||
.ok()?;
|
|
||||||
|
|
||||||
Some(tokenizer_config)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a post_processor for the LlamaTokenizer
|
|
||||||
pub fn create_post_processor(
|
|
||||||
tokenizer: &Tokenizer,
|
|
||||||
tokenizer_config: &HubTokenizerConfig,
|
|
||||||
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
|
|
||||||
let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
|
|
||||||
let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);
|
|
||||||
|
|
||||||
let bos_token = tokenizer_config.bos_token.as_ref();
|
|
||||||
let eos_token = tokenizer_config.eos_token.as_ref();
|
|
||||||
|
|
||||||
if add_bos_token && bos_token.is_none() {
|
|
||||||
panic!("add_bos_token = true but bos_token is None");
|
|
||||||
}
|
|
||||||
|
|
||||||
if add_eos_token && eos_token.is_none() {
|
|
||||||
panic!("add_eos_token = true but eos_token is None");
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut single = Vec::new();
|
|
||||||
let mut pair = Vec::new();
|
|
||||||
let mut special_tokens = Vec::new();
|
|
||||||
|
|
||||||
if add_bos_token {
|
|
||||||
if let Some(bos) = bos_token {
|
|
||||||
let bos_token_id = tokenizer
|
|
||||||
.token_to_id(bos.as_str())
|
|
||||||
.expect("Should have found the bos token id");
|
|
||||||
special_tokens.push((bos.as_str(), bos_token_id));
|
|
||||||
single.push(format!("{}:0", bos.as_str()));
|
|
||||||
pair.push(format!("{}:0", bos.as_str()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
single.push("$A:0".to_string());
|
|
||||||
pair.push("$A:0".to_string());
|
|
||||||
|
|
||||||
if add_eos_token {
|
|
||||||
if let Some(eos) = eos_token {
|
|
||||||
let eos_token_id = tokenizer
|
|
||||||
.token_to_id(eos.as_str())
|
|
||||||
.expect("Should have found the eos token id");
|
|
||||||
special_tokens.push((eos.as_str(), eos_token_id));
|
|
||||||
single.push(format!("{}:0", eos.as_str()));
|
|
||||||
pair.push(format!("{}:0", eos.as_str()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if add_bos_token {
|
|
||||||
if let Some(bos) = bos_token {
|
|
||||||
pair.push(format!("{}:1", bos.as_str()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pair.push("$B:1".to_string());
|
|
||||||
|
|
||||||
if add_eos_token {
|
|
||||||
if let Some(eos) = eos_token {
|
|
||||||
pair.push(format!("{}:1", eos.as_str()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let post_processor = TemplateProcessing::builder()
|
|
||||||
.try_single(single)?
|
|
||||||
.try_pair(pair)?
|
|
||||||
.special_tokens(special_tokens)
|
|
||||||
.build()?;
|
|
||||||
|
|
||||||
Ok(post_processor)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
|
||||||
enum RouterError {
|
|
||||||
#[error("Argument validation error: {0}")]
|
|
||||||
ArgumentValidation(String),
|
|
||||||
#[error("WebServer error: {0}")]
|
|
||||||
WebServer(#[from] server::WebServerError),
|
|
||||||
#[error("Tokio runtime failed to start: {0}")]
|
|
||||||
Tokio(#[from] std::io::Error),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use text_generation_router::TokenizerConfigToken;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_create_post_processor() {
|
|
||||||
let tokenizer_config = HubTokenizerConfig {
|
|
||||||
add_bos_token: None,
|
|
||||||
add_eos_token: None,
|
|
||||||
bos_token: Some(TokenizerConfigToken::String("<s>".to_string())),
|
|
||||||
eos_token: Some(TokenizerConfigToken::String("</s>".to_string())),
|
|
||||||
chat_template: None,
|
|
||||||
tokenizer_class: None,
|
|
||||||
completion_template: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let tokenizer =
|
|
||||||
Tokenizer::from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", None).unwrap();
|
|
||||||
let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();
|
|
||||||
|
|
||||||
let expected = TemplateProcessing::builder()
|
|
||||||
.try_single("<s>:0 $A:0")
|
|
||||||
.unwrap()
|
|
||||||
.try_pair("<s>:0 $A:0 <s>:1 $B:1")
|
|
||||||
.unwrap()
|
|
||||||
.special_tokens(vec![("<s>".to_string(), 1)])
|
|
||||||
.build()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(post_processor, expected);
|
|
||||||
}
|
|
||||||
}
|
|
|
@@ -0,0 +1,82 @@
+use crate::infer::Infer;
+use crate::server::{chat_completions, compat_generate, completions, ComputeType};
+use crate::{
+    ChatCompletion, ChatCompletionChunk, ChatRequest, Chunk, CompatGenerateRequest,
+    CompletionFinal, CompletionRequest, ErrorResponse, GenerateResponse, Info, StreamResponse,
+};
+use axum::extract::Extension;
+use axum::http::StatusCode;
+use axum::response::Response;
+use axum::Json;
+use serde::{Deserialize, Serialize};
+use tracing::instrument;
+use utoipa::ToSchema;
+
+#[derive(Clone, Deserialize, ToSchema)]
+#[serde(untagged)]
+pub(crate) enum SagemakerRequest {
+    Generate(CompatGenerateRequest),
+    Chat(ChatRequest),
+    Completion(CompletionRequest),
+}
+
+// Used for OpenAPI specs
+#[allow(dead_code)]
+#[derive(Serialize, ToSchema)]
+#[serde(untagged)]
+pub(crate) enum SagemakerResponse {
+    Generate(GenerateResponse),
+    Chat(ChatCompletion),
+    Completion(CompletionFinal),
+}
+
+// Used for OpenAPI specs
+#[allow(dead_code)]
+#[derive(Serialize, ToSchema)]
+#[serde(untagged)]
+pub(crate) enum SagemakerStreamResponse {
+    Generate(StreamResponse),
+    Chat(ChatCompletionChunk),
+    Completion(Chunk),
+}
+
+/// Generate tokens from Sagemaker request
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/invocations",
+    request_body = SagemakerRequest,
+    responses(
+    (status = 200, description = "Generated Chat Completion",
+    content(
+    ("application/json" = SagemakerResponse),
+    ("text/event-stream" = SagemakerStreamResponse),
+    )),
+    (status = 424, description = "Generation Error", body = ErrorResponse,
+    example = json ! ({"error": "Request failed during generation", "error_type": "generation"})),
+    (status = 429, description = "Model is overloaded", body = ErrorResponse,
+    example = json ! ({"error": "Model is overloaded", "error_type": "overloaded"})),
+    (status = 422, description = "Input validation error", body = ErrorResponse,
+    example = json ! ({"error": "Input validation error", "error_type": "validation"})),
+    (status = 500, description = "Incomplete generation", body = ErrorResponse,
+    example = json ! ({"error": "Incomplete generation", "error_type": "incomplete_generation"})),
+    )
+)]
+#[instrument(skip_all)]
+pub(crate) async fn sagemaker_compatibility(
+    default_return_full_text: Extension<bool>,
+    infer: Extension<Infer>,
+    compute_type: Extension<ComputeType>,
+    info: Extension<Info>,
+    Json(req): Json<SagemakerRequest>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    match req {
+        SagemakerRequest::Generate(req) => {
+            compat_generate(default_return_full_text, infer, compute_type, Json(req)).await
+        }
+        SagemakerRequest::Chat(req) => chat_completions(infer, compute_type, info, Json(req)).await,
+        SagemakerRequest::Completion(req) => {
+            completions(infer, compute_type, info, Json(req)).await
+        }
+    }
+}
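Because the enums are `#[serde(untagged)]`, the handler is selected purely from the shape of the JSON body; serde tries the variants in declaration order. A rough Python analogue of that dispatch, for illustration only — the handler names refer to the Rust functions above, and the key checks are only an approximation of untagged deserialization:

```python
# Rough analogue of the untagged-enum dispatch: the payload's shape, not an
# explicit type tag, decides which handler runs.
def route_invocation(payload: dict) -> str:
    if "inputs" in payload:  # looks like CompatGenerateRequest
        return "compat_generate"
    if "messages" in payload:  # looks like ChatRequest
        return "chat_completions"
    if "prompt" in payload:  # looks like CompletionRequest
        return "completions"
    raise ValueError("payload does not match any supported request shape")


assert route_invocation({"inputs": "hi"}) == "compat_generate"
assert route_invocation({"messages": []}) == "chat_completions"
assert route_invocation({"prompt": "hi"}) == "completions"
```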
@@ -7,6 +7,10 @@ use crate::kserve::{
     kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
     kserve_model_metadata, kserve_model_metadata_ready,
 };
+use crate::sagemaker::{
+    sagemaker_compatibility, SagemakerRequest, SagemakerResponse, SagemakerStreamResponse,
+    __path_sagemaker_compatibility,
+};
 use crate::validation::ValidationError;
 use crate::vertex::vertex_compatibility;
 use crate::ChatTokenizeResponse;
@@ -83,7 +87,7 @@ example = json ! ({"error": "Incomplete generation"})),
 )
 )]
 #[instrument(skip(infer, req))]
-async fn compat_generate(
+pub(crate) async fn compat_generate(
     Extension(default_return_full_text): Extension<bool>,
     infer: Extension<Infer>,
     compute_type: Extension<ComputeType>,
@@ -678,7 +682,7 @@ time_per_token,
 seed,
 )
 )]
-async fn completions(
+pub(crate) async fn completions(
     Extension(infer): Extension<Infer>,
     Extension(compute_type): Extension<ComputeType>,
     Extension(info): Extension<Info>,
@@ -1202,7 +1206,7 @@ time_per_token,
 seed,
 )
 )]
-async fn chat_completions(
+pub(crate) async fn chat_completions(
     Extension(infer): Extension<Infer>,
     Extension(compute_type): Extension<ComputeType>,
     Extension(info): Extension<Info>,
@@ -1513,11 +1517,13 @@ completions,
 tokenize,
 metrics,
 openai_get_model_info,
+sagemaker_compatibility,
 ),
 components(
 schemas(
 Info,
 CompatGenerateRequest,
+SagemakerRequest,
 GenerateRequest,
 GrammarType,
 ChatRequest,
@@ -1540,6 +1546,8 @@ ChatCompletionTopLogprob,
 ChatCompletion,
 CompletionRequest,
 CompletionComplete,
+SagemakerResponse,
+SagemakerStreamResponse,
 Chunk,
 Completion,
 CompletionFinal,
@@ -1607,7 +1615,6 @@ pub async fn run(
     ngrok: bool,
     _ngrok_authtoken: Option<String>,
     _ngrok_edge: Option<String>,
-    messages_api_enabled: bool,
     disable_grammar_support: bool,
     max_client_batch_size: usize,
     usage_stats_level: usage_stats::UsageStatsLevel,
@@ -1836,7 +1843,6 @@ pub async fn run(
         // max_batch_size,
         revision.clone(),
         validation_workers,
-        messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
         usage_stats_level,
@@ -1878,7 +1884,6 @@ pub async fn run(
         ngrok,
         _ngrok_authtoken,
         _ngrok_edge,
-        messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
         model_info,
@@ -1938,7 +1943,6 @@ async fn start(
     ngrok: bool,
     _ngrok_authtoken: Option<String>,
     _ngrok_edge: Option<String>,
-    messages_api_enabled: bool,
     disable_grammar_support: bool,
     max_client_batch_size: usize,
     model_info: HubModelInfo,
@@ -2253,6 +2257,7 @@ async fn start(
         .route("/v1/chat/completions", post(chat_completions))
         .route("/v1/completions", post(completions))
         .route("/vertex", post(vertex_compatibility))
+        .route("/invocations", post(sagemaker_compatibility))
         .route("/tokenize", post(tokenize));

     if let Some(api_key) = api_key {
@@ -2288,13 +2293,6 @@ async fn start(
         .route("/metrics", get(metrics))
         .route("/v1/models", get(openai_get_model_info));

-    // Conditional AWS Sagemaker route
-    let aws_sagemaker_route = if messages_api_enabled {
-        Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED
-    } else {
-        Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
-    };
-
     let compute_type =
         ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));

@@ -2302,8 +2300,7 @@ async fn start(
     let mut app = Router::new()
         .merge(swagger_ui)
         .merge(base_routes)
-        .merge(info_routes)
-        .merge(aws_sagemaker_route);
+        .merge(info_routes);

     #[cfg(feature = "google")]
     {
@@ -93,7 +93,6 @@ pub struct Args {
     // max_batch_size: Option<usize>,
     revision: Option<String>,
    validation_workers: usize,
-    messages_api_enabled: bool,
    disable_grammar_support: bool,
    max_client_batch_size: usize,
    usage_stats_level: UsageStatsLevel,
@@ -117,7 +116,6 @@ impl Args {
        // max_batch_size: Option<usize>,
        revision: Option<String>,
        validation_workers: usize,
-        messages_api_enabled: bool,
        disable_grammar_support: bool,
        max_client_batch_size: usize,
        usage_stats_level: UsageStatsLevel,
@@ -138,7 +136,6 @@ impl Args {
            // max_batch_size,
            revision,
            validation_workers,
-            messages_api_enabled,
            disable_grammar_support,
            max_client_batch_size,
            usage_stats_level,
@@ -172,6 +172,8 @@ def check_openapi(check: bool):
             # allow for trailing whitespace since it's not significant
             # and the precommit hook will remove it
             "lint",
+            "--skip-rule",
+            "security-defined",
             filename,
         ],
         capture_output=True,
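The two new arguments make the OpenAPI linter skip its `security-defined` rule when checking the generated spec. A hedged sketch of the resulting invocation; only `"lint"`, `"--skip-rule"`, `"security-defined"` and `filename` come from the diff, while the `redocly` command name and the surrounding call shape are assumptions about the rest of `check_openapi()`.

```python
import subprocess

# Hypothetical reconstruction of the lint call with the new flags.
filename = "docs/openapi.json"
result = subprocess.run(
    ["redocly", "lint", "--skip-rule", "security-defined", filename],
    capture_output=True,
)
if result.returncode != 0:
    print(result.stderr.decode())
```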