feat: supports openai chat completions API (#1427)
This PR adds support to make TGI a drop in replacement for OpenAI clients by exposing the same HTTP interface. Notes - TGI inits a single model at startup so the `model` field is unused in HTTP requests. - `max_tokens` and `stream` should work as expected but other params may be (unimplemented or not supported) General approach - fetch the `tokenizer_config` at startup from the hub - pass `tokenizer_config` into `Infer` so we have it at request time - use the `chat_template` on the config to format chat request - parse jinja template and render chat string - pass inputs into existing generate function - wrap generation output in expected structure before returning # How to test ### Streaming curl ```bash curl localhost:3000/v1/chat/completions \ -X POST \ -d '{ "model": "tgi", "messages": [ { "role": "system", "content": "You are a helpful assistant." }, { "role": "user", "content": "What is deep learning?" } ], "stream": true, "max_tokens": 20 }' \ -H 'Content-Type: application/json' ``` It is also possible to use the `openai` python library and change the base url ### 🌊 STREAMING REQUEST ```python from openai import OpenAI # init the client but point it to TGI client = OpenAI( base_url="http://localhost:3000/v1", api_key="not needed for a local LLM" ) chat_completion = client.chat.completions.create( model="tgi", messages=[ {"role": "system", "content": "You are a helpful assistant." }, {"role": "user", "content": "What is deep learning?"} ], stream=True ) # iterate and print stream for message in chat_completion: print(message) # ChatCompletionChunk(id='', choices=[Choice(delta=ChoiceDelta(content=' that', function_call=None, role='assistant', tool_calls=None), finish_reason=None, index=2, logprobs=None)], created=1704486761, model='', object='text_completion', system_fingerprint='') ``` ### 🚗 SYNCHRONOUS REQUEST ```python from openai import OpenAI # init the client but point it to TGI client = OpenAI( base_url="http://localhost:3000/v1", api_key="not needed for a local LLM" ) chat_completion = client.chat.completions.create( model="tgi", messages=[ {"role": "system", "content": "You are a helpful assistant." }, {"role": "user", "content": "What is deep learning?"} ], stream=False ) print(chat_completion) # ChatCompletion(id='', choices=[Choice(finish_reason=None, index=0, logprobs=None, message=ChatCompletionMessage(content='\nDeep learning is a new field of research that has been gaining traction in the last ...', role='assistant', function_call=None, tool_calls=None))], created=1704486762, model='', object='text_completion', system_fingerprint='', usage=CompletionUsage(completion_tokens=100, prompt_tokens=76, total_tokens=176)) ``` ## How to run dev ```bash cd text-generation-inference/server MASTER_ADDR=127.0.0.1 MASTER_PORT=5555 text-generation-server serve --trust-remote-code gpt2 ``` ***note many of the existing `chat_templates` use non standard `jinja` (ie. adding a `raise` to the template) which will throw an error when parsing; hence using `upstage/SOLAR-10.7B-Instruct-v1.0` since it has a valid template ```bash cd text-generation-inference/router cargo run -- --tokenizer-name upstage/SOLAR-10.7B-Instruct-v1.0 ``` trigger ```bash curl localhost:3000/v1/chat/completions \ -X POST \ -d '{ "model": "gpt-3.5-turbo", "messages": [ { "role": "system", "content": "You are a helpful assistant." }, { "role": "user", "content": "What is the IP address of the Google DNS servers?" } ], "stream": true, "max_tokens": 20, "logprobs": true }' \ -H 'Content-Type: application/json' ``` ^ supports `stream: true` and `stream: false` requests
This commit is contained in:
parent
ac08b4ef9c
commit
0eabc83541
|
@ -773,9 +773,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-channel"
|
name = "futures-channel"
|
||||||
version = "0.3.29"
|
version = "0.3.30"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "ff4dd66668b557604244583e3e1e1eada8c5c2e96a6d0d6653ede395b78bbacb"
|
checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"futures-sink",
|
"futures-sink",
|
||||||
|
@ -783,9 +783,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-core"
|
name = "futures-core"
|
||||||
version = "0.3.29"
|
version = "0.3.30"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c"
|
checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-executor"
|
name = "futures-executor"
|
||||||
|
@ -800,15 +800,15 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-io"
|
name = "futures-io"
|
||||||
version = "0.3.29"
|
version = "0.3.30"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa"
|
checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-macro"
|
name = "futures-macro"
|
||||||
version = "0.3.29"
|
version = "0.3.30"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb"
|
checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
|
@ -817,21 +817,21 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-sink"
|
name = "futures-sink"
|
||||||
version = "0.3.29"
|
version = "0.3.30"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "e36d3378ee38c2a36ad710c5d30c2911d752cb941c00c72dbabfb786a7970817"
|
checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-task"
|
name = "futures-task"
|
||||||
version = "0.3.29"
|
version = "0.3.30"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2"
|
checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-util"
|
name = "futures-util"
|
||||||
version = "0.3.29"
|
version = "0.3.30"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104"
|
checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"futures-channel",
|
"futures-channel",
|
||||||
"futures-core",
|
"futures-core",
|
||||||
|
@ -1373,6 +1373,15 @@ dependencies = [
|
||||||
"unicase",
|
"unicase",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "minijinja"
|
||||||
|
version = "1.0.10"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "208758577ef2c86cf5dd3e85730d161413ec3284e2d73b2ef65d9a24d9971bcb"
|
||||||
|
dependencies = [
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "minimal-lexical"
|
name = "minimal-lexical"
|
||||||
version = "0.2.1"
|
version = "0.2.1"
|
||||||
|
@ -2807,10 +2816,12 @@ dependencies = [
|
||||||
"axum-tracing-opentelemetry",
|
"axum-tracing-opentelemetry",
|
||||||
"clap",
|
"clap",
|
||||||
"futures",
|
"futures",
|
||||||
|
"futures-util",
|
||||||
"hf-hub",
|
"hf-hub",
|
||||||
"init-tracing-opentelemetry",
|
"init-tracing-opentelemetry",
|
||||||
"metrics",
|
"metrics",
|
||||||
"metrics-exporter-prometheus",
|
"metrics-exporter-prometheus",
|
||||||
|
"minijinja",
|
||||||
"ngrok",
|
"ngrok",
|
||||||
"nohash-hasher",
|
"nohash-hasher",
|
||||||
"opentelemetry",
|
"opentelemetry",
|
||||||
|
|
|
@ -43,6 +43,8 @@ utoipa = { version = "3.5.0", features = ["axum_extras"] }
|
||||||
utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
|
utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
|
||||||
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
|
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
|
||||||
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
|
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
|
||||||
|
minijinja = "1.0.10"
|
||||||
|
futures-util = "0.3.30"
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
|
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
/// Batching and inference logic
|
/// Batching and inference logic
|
||||||
use crate::validation::{Validation, ValidationError};
|
use crate::validation::{Validation, ValidationError};
|
||||||
|
use crate::HubTokenizerConfig;
|
||||||
|
use crate::{ChatRequest, GenerateRequest, GenerateStreamResponse, PrefillToken};
|
||||||
use crate::{Entry, Queue, Token};
|
use crate::{Entry, Queue, Token};
|
||||||
use crate::{GenerateRequest, PrefillToken};
|
|
||||||
use futures::future::try_join_all;
|
use futures::future::try_join_all;
|
||||||
|
use minijinja::{Environment, ErrorKind, Template};
|
||||||
use nohash_hasher::IntMap;
|
use nohash_hasher::IntMap;
|
||||||
use std::sync::{
|
use std::sync::{
|
||||||
atomic::{AtomicBool, Ordering},
|
atomic::{AtomicBool, Ordering},
|
||||||
|
@ -13,7 +15,7 @@ use text_generation_client::{
|
||||||
};
|
};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokio::sync::mpsc::error::SendError;
|
use tokio::sync::mpsc::error::SendError;
|
||||||
use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
|
use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
use tokio_stream::StreamExt;
|
use tokio_stream::StreamExt;
|
||||||
|
@ -30,6 +32,8 @@ pub struct Infer {
|
||||||
shared: Arc<Shared>,
|
shared: Arc<Shared>,
|
||||||
/// Inference limit
|
/// Inference limit
|
||||||
limit_concurrent_requests: Arc<Semaphore>,
|
limit_concurrent_requests: Arc<Semaphore>,
|
||||||
|
/// Chat template
|
||||||
|
template: Option<Template<'static, 'static>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Infer shared state
|
/// Infer shared state
|
||||||
|
@ -52,6 +56,7 @@ impl Infer {
|
||||||
window_size: Option<u32>,
|
window_size: Option<u32>,
|
||||||
speculate: u32,
|
speculate: u32,
|
||||||
generation_health: Arc<AtomicBool>,
|
generation_health: Arc<AtomicBool>,
|
||||||
|
tokenizer_config: HubTokenizerConfig,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
// Infer shared state
|
// Infer shared state
|
||||||
let queue = Queue::new(requires_padding, 16, window_size, speculate);
|
let queue = Queue::new(requires_padding, 16, window_size, speculate);
|
||||||
|
@ -74,11 +79,21 @@ impl Infer {
|
||||||
// Inference limit with a semaphore
|
// Inference limit with a semaphore
|
||||||
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
||||||
|
|
||||||
|
let template = tokenizer_config.chat_template.map(|t| {
|
||||||
|
let env = Box::new(Environment::new());
|
||||||
|
let template_str = t.into_boxed_str();
|
||||||
|
// leaking env and template_str as read-only, static resources for performance.
|
||||||
|
Box::leak(env)
|
||||||
|
.template_from_str(Box::leak(template_str))
|
||||||
|
.unwrap()
|
||||||
|
});
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
validation,
|
validation,
|
||||||
queue,
|
queue,
|
||||||
shared,
|
shared,
|
||||||
limit_concurrent_requests: semaphore,
|
limit_concurrent_requests: semaphore,
|
||||||
|
template,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -87,14 +102,7 @@ impl Infer {
|
||||||
pub(crate) async fn generate_stream(
|
pub(crate) async fn generate_stream(
|
||||||
&self,
|
&self,
|
||||||
request: GenerateRequest,
|
request: GenerateRequest,
|
||||||
) -> Result<
|
) -> Result<GenerateStreamResponse, InferError> {
|
||||||
(
|
|
||||||
OwnedSemaphorePermit,
|
|
||||||
u32,
|
|
||||||
UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
|
|
||||||
),
|
|
||||||
InferError,
|
|
||||||
> {
|
|
||||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
// Limit concurrent requests by acquiring a permit from the semaphore
|
||||||
let permit = self
|
let permit = self
|
||||||
.clone()
|
.clone()
|
||||||
|
@ -139,6 +147,20 @@ impl Infer {
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Apply the chat template to the chat request
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
pub(crate) fn apply_chat_template(&self, chat: ChatRequest) -> Result<String, InferError> {
|
||||||
|
self.template
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
|
||||||
|
.render(chat)
|
||||||
|
.map_err(|e| {
|
||||||
|
metrics::increment_counter!("tgi_request_failure", "err" => "template");
|
||||||
|
tracing::error!("{e}");
|
||||||
|
InferError::TemplateError(e)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
/// Add a new request to the queue and return a InferResponse
|
/// Add a new request to the queue and return a InferResponse
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
pub(crate) async fn generate(
|
pub(crate) async fn generate(
|
||||||
|
@ -550,9 +572,9 @@ fn send_responses(
|
||||||
let mut iterator = tokens_
|
let mut iterator = tokens_
|
||||||
.ids
|
.ids
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.zip(tokens_.logprobs.into_iter())
|
.zip(tokens_.logprobs)
|
||||||
.zip(tokens_.texts.into_iter())
|
.zip(tokens_.texts)
|
||||||
.zip(tokens_.is_special.into_iter())
|
.zip(tokens_.is_special)
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.peekable();
|
.peekable();
|
||||||
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
|
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
|
||||||
|
@ -665,6 +687,8 @@ pub enum InferError {
|
||||||
ValidationError(#[from] ValidationError),
|
ValidationError(#[from] ValidationError),
|
||||||
#[error("Incomplete generation")]
|
#[error("Incomplete generation")]
|
||||||
IncompleteGeneration,
|
IncompleteGeneration,
|
||||||
|
#[error("Template error: {0}")]
|
||||||
|
TemplateError(#[from] minijinja::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl InferError {
|
impl InferError {
|
||||||
|
@ -674,6 +698,7 @@ impl InferError {
|
||||||
InferError::Overloaded(_) => "overloaded",
|
InferError::Overloaded(_) => "overloaded",
|
||||||
InferError::ValidationError(_) => "validation",
|
InferError::ValidationError(_) => "validation",
|
||||||
InferError::IncompleteGeneration => "incomplete_generation",
|
InferError::IncompleteGeneration => "incomplete_generation",
|
||||||
|
InferError::TemplateError(_) => "template_error",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,12 +5,21 @@ mod queue;
|
||||||
pub mod server;
|
pub mod server;
|
||||||
mod validation;
|
mod validation;
|
||||||
|
|
||||||
use infer::Infer;
|
use infer::{Infer, InferError, InferStreamResponse};
|
||||||
use queue::{Entry, Queue};
|
use queue::{Entry, Queue};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tokio::sync::OwnedSemaphorePermit;
|
||||||
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
use validation::Validation;
|
use validation::Validation;
|
||||||
|
|
||||||
|
/// Type alias for generation responses
|
||||||
|
pub(crate) type GenerateStreamResponse = (
|
||||||
|
OwnedSemaphorePermit,
|
||||||
|
u32, // input_length
|
||||||
|
UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
|
||||||
|
);
|
||||||
|
|
||||||
/// Hub type
|
/// Hub type
|
||||||
#[derive(Clone, Debug, Deserialize)]
|
#[derive(Clone, Debug, Deserialize)]
|
||||||
pub struct HubModelInfo {
|
pub struct HubModelInfo {
|
||||||
|
@ -20,6 +29,19 @@ pub struct HubModelInfo {
|
||||||
pub pipeline_tag: Option<String>,
|
pub pipeline_tag: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, Default)]
|
||||||
|
pub struct HubTokenizerConfig {
|
||||||
|
#[serde(default)]
|
||||||
|
pub chat_template: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HubTokenizerConfig {
|
||||||
|
pub fn from_file(filename: &str) -> Self {
|
||||||
|
let content = std::fs::read_to_string(filename).unwrap();
|
||||||
|
serde_json::from_str(&content).unwrap_or_default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||||
pub struct Info {
|
pub struct Info {
|
||||||
/// Model info
|
/// Model info
|
||||||
|
@ -152,7 +174,7 @@ fn default_parameters() -> GenerateParameters {
|
||||||
top_k: None,
|
top_k: None,
|
||||||
top_p: None,
|
top_p: None,
|
||||||
typical_p: None,
|
typical_p: None,
|
||||||
do_sample: false,
|
do_sample: true,
|
||||||
max_new_tokens: default_max_new_tokens(),
|
max_new_tokens: default_max_new_tokens(),
|
||||||
return_full_text: None,
|
return_full_text: None,
|
||||||
stop: Vec::new(),
|
stop: Vec::new(),
|
||||||
|
@ -165,6 +187,193 @@ fn default_parameters() -> GenerateParameters {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, Serialize)]
|
||||||
|
pub(crate) struct ChatCompletion {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub created: u64,
|
||||||
|
pub model: String,
|
||||||
|
pub system_fingerprint: String,
|
||||||
|
pub choices: Vec<ChatCompletionComplete>,
|
||||||
|
pub usage: Usage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, Serialize)]
|
||||||
|
pub(crate) struct ChatCompletionComplete {
|
||||||
|
pub index: u32,
|
||||||
|
pub message: Message,
|
||||||
|
pub logprobs: Option<Vec<f32>>,
|
||||||
|
pub finish_reason: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, Serialize)]
|
||||||
|
pub(crate) struct Usage {
|
||||||
|
pub prompt_tokens: u32,
|
||||||
|
pub completion_tokens: u32,
|
||||||
|
pub total_tokens: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChatCompletion {
|
||||||
|
pub(crate) fn new(
|
||||||
|
model: String,
|
||||||
|
system_fingerprint: String,
|
||||||
|
output: String,
|
||||||
|
created: u64,
|
||||||
|
details: Details,
|
||||||
|
return_logprobs: bool,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
id: String::new(),
|
||||||
|
object: "text_completion".into(),
|
||||||
|
created,
|
||||||
|
model,
|
||||||
|
system_fingerprint,
|
||||||
|
choices: vec![ChatCompletionComplete {
|
||||||
|
index: 0,
|
||||||
|
message: Message {
|
||||||
|
role: "assistant".into(),
|
||||||
|
content: output,
|
||||||
|
},
|
||||||
|
logprobs: return_logprobs
|
||||||
|
.then(|| details.tokens.iter().map(|t| t.logprob).collect()),
|
||||||
|
finish_reason: details.finish_reason.to_string(),
|
||||||
|
}],
|
||||||
|
usage: Usage {
|
||||||
|
prompt_tokens: details.prefill.len() as u32,
|
||||||
|
completion_tokens: details.generated_tokens,
|
||||||
|
total_tokens: details.prefill.len() as u32 + details.generated_tokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, Serialize)]
|
||||||
|
pub(crate) struct ChatCompletionChunk {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub created: u64,
|
||||||
|
pub model: String,
|
||||||
|
pub system_fingerprint: String,
|
||||||
|
pub choices: Vec<ChatCompletionChoice>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, Serialize)]
|
||||||
|
pub(crate) struct ChatCompletionChoice {
|
||||||
|
pub index: u32,
|
||||||
|
pub delta: ChatCompletionDelta,
|
||||||
|
pub logprobs: Option<f32>,
|
||||||
|
pub finish_reason: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub(crate) struct ChatCompletionDelta {
|
||||||
|
pub role: String,
|
||||||
|
pub content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChatCompletionChunk {
|
||||||
|
pub(crate) fn new(
|
||||||
|
model: String,
|
||||||
|
system_fingerprint: String,
|
||||||
|
delta: String,
|
||||||
|
created: u64,
|
||||||
|
index: u32,
|
||||||
|
logprobs: Option<f32>,
|
||||||
|
finish_reason: Option<String>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
id: String::new(),
|
||||||
|
object: "text_completion".to_string(),
|
||||||
|
created,
|
||||||
|
model,
|
||||||
|
system_fingerprint,
|
||||||
|
choices: vec![ChatCompletionChoice {
|
||||||
|
index,
|
||||||
|
delta: ChatCompletionDelta {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: delta,
|
||||||
|
},
|
||||||
|
logprobs,
|
||||||
|
finish_reason,
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_request_messages() -> Vec<Message> {
|
||||||
|
vec![Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "My name is David and I".to_string(),
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||||
|
pub(crate) struct ChatRequest {
|
||||||
|
/// UNUSED
|
||||||
|
#[schema(example = "bigscience/blomm-560m")]
|
||||||
|
/// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
|
||||||
|
pub model: String, /* NOTE: UNUSED */
|
||||||
|
|
||||||
|
/// A list of messages comprising the conversation so far.
|
||||||
|
#[serde(default = "default_request_messages")]
|
||||||
|
pub messages: Vec<Message>,
|
||||||
|
|
||||||
|
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
|
||||||
|
/// decreasing the model's likelihood to repeat the same line verbatim.
|
||||||
|
#[serde(default)]
|
||||||
|
pub frequency_penalty: Option<f32>,
|
||||||
|
|
||||||
|
/// UNUSED
|
||||||
|
/// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
|
||||||
|
/// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
|
||||||
|
/// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
|
||||||
|
/// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
|
||||||
|
/// result in a ban or exclusive selection of the relevant token.
|
||||||
|
#[serde(default)]
|
||||||
|
pub logit_bias: Option<Vec<f32>>,
|
||||||
|
|
||||||
|
/// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
|
||||||
|
/// output token returned in the content of message.
|
||||||
|
#[serde(default)]
|
||||||
|
pub logprobs: Option<bool>,
|
||||||
|
|
||||||
|
/// UNUSED
|
||||||
|
/// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
|
||||||
|
/// an associated log probability. logprobs must be set to true if this parameter is used.
|
||||||
|
#[serde(default)]
|
||||||
|
pub top_logprobs: Option<u32>,
|
||||||
|
|
||||||
|
/// The maximum number of tokens that can be generated in the chat completion.
|
||||||
|
#[serde(default)]
|
||||||
|
pub max_tokens: Option<u32>,
|
||||||
|
|
||||||
|
/// UNUSED
|
||||||
|
/// How many chat completion choices to generate for each input message. Note that you will be charged based on the
|
||||||
|
/// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
|
||||||
|
#[serde(default)]
|
||||||
|
pub n: Option<u32>,
|
||||||
|
|
||||||
|
/// UNUSED
|
||||||
|
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
|
||||||
|
/// increasing the model's likelihood to talk about new topics
|
||||||
|
#[serde(default)]
|
||||||
|
pub presence_penalty: Option<f32>,
|
||||||
|
|
||||||
|
#[serde(default = "bool::default")]
|
||||||
|
pub stream: bool,
|
||||||
|
|
||||||
|
#[schema(nullable = true, example = 42)]
|
||||||
|
pub seed: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||||
|
pub(crate) struct Message {
|
||||||
|
#[schema(example = "user")]
|
||||||
|
pub role: String,
|
||||||
|
#[schema(example = "My name is David and I")]
|
||||||
|
pub content: String,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
||||||
pub(crate) struct GenerateRequest {
|
pub(crate) struct GenerateRequest {
|
||||||
#[schema(example = "My name is Olivier and I")]
|
#[schema(example = "My name is Olivier and I")]
|
||||||
|
@ -227,6 +436,16 @@ pub(crate) enum FinishReason {
|
||||||
StopSequence,
|
StopSequence,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for FinishReason {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
FinishReason::Length => write!(f, "length"),
|
||||||
|
FinishReason::EndOfSequenceToken => write!(f, "eos_token"),
|
||||||
|
FinishReason::StopSequence => write!(f, "stop_sequence"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
#[derive(Serialize, ToSchema)]
|
||||||
pub(crate) struct BestOfSequence {
|
pub(crate) struct BestOfSequence {
|
||||||
#[schema(example = "test")]
|
#[schema(example = "test")]
|
||||||
|
@ -279,6 +498,7 @@ pub(crate) struct StreamDetails {
|
||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
#[derive(Serialize, ToSchema)]
|
||||||
pub(crate) struct StreamResponse {
|
pub(crate) struct StreamResponse {
|
||||||
|
pub index: u32,
|
||||||
pub token: Token,
|
pub token: Token,
|
||||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||||
pub top_tokens: Vec<Token>,
|
pub top_tokens: Vec<Token>,
|
||||||
|
|
|
@ -8,13 +8,12 @@ use opentelemetry::sdk::trace::Sampler;
|
||||||
use opentelemetry::sdk::Resource;
|
use opentelemetry::sdk::Resource;
|
||||||
use opentelemetry::{global, KeyValue};
|
use opentelemetry::{global, KeyValue};
|
||||||
use opentelemetry_otlp::WithExportConfig;
|
use opentelemetry_otlp::WithExportConfig;
|
||||||
/// Text Generation Inference webserver entrypoint
|
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::BufReader;
|
use std::io::BufReader;
|
||||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use text_generation_client::{ClientError, ShardedClient};
|
use text_generation_client::{ClientError, ShardedClient};
|
||||||
use text_generation_router::{server, HubModelInfo};
|
use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
use tower_http::cors::AllowOrigin;
|
use tower_http::cors::AllowOrigin;
|
||||||
|
@ -55,6 +54,8 @@ struct Args {
|
||||||
#[clap(default_value = "bigscience/bloom", long, env)]
|
#[clap(default_value = "bigscience/bloom", long, env)]
|
||||||
tokenizer_name: String,
|
tokenizer_name: String,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
|
tokenizer_config_path: Option<String>,
|
||||||
|
#[clap(long, env)]
|
||||||
revision: Option<String>,
|
revision: Option<String>,
|
||||||
#[clap(default_value = "2", long, env)]
|
#[clap(default_value = "2", long, env)]
|
||||||
validation_workers: usize,
|
validation_workers: usize,
|
||||||
|
@ -92,6 +93,7 @@ async fn main() -> Result<(), RouterError> {
|
||||||
port,
|
port,
|
||||||
master_shard_uds_path,
|
master_shard_uds_path,
|
||||||
tokenizer_name,
|
tokenizer_name,
|
||||||
|
tokenizer_config_path,
|
||||||
revision,
|
revision,
|
||||||
validation_workers,
|
validation_workers,
|
||||||
json_output,
|
json_output,
|
||||||
|
@ -149,40 +151,64 @@ async fn main() -> Result<(), RouterError> {
|
||||||
let local_path = Path::new(&tokenizer_name);
|
let local_path = Path::new(&tokenizer_name);
|
||||||
let local_model = local_path.exists() && local_path.is_dir();
|
let local_model = local_path.exists() && local_path.is_dir();
|
||||||
|
|
||||||
let (tokenizer, model_info) = if local_model {
|
// Load tokenizer config
|
||||||
// Get Model info
|
// This will be used to format the chat template
|
||||||
let model_info = HubModelInfo {
|
let local_tokenizer_config_path =
|
||||||
model_id: tokenizer_name.clone(),
|
tokenizer_config_path.unwrap_or("tokenizer_config.json".to_string());
|
||||||
sha: None,
|
let local_tokenizer_config = Path::new(&local_tokenizer_config_path).exists();
|
||||||
pipeline_tag: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Load local tokenizer
|
// Shared API builder initialization
|
||||||
let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
|
let api_builder = || {
|
||||||
|
|
||||||
(tokenizer, model_info)
|
|
||||||
} else {
|
|
||||||
let mut builder = ApiBuilder::new()
|
let mut builder = ApiBuilder::new()
|
||||||
.with_progress(false)
|
.with_progress(false)
|
||||||
.with_token(authorization_token);
|
.with_token(authorization_token);
|
||||||
|
|
||||||
if let Some(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE").ok() {
|
if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
|
||||||
builder = builder.with_cache_dir(cache_dir.into());
|
builder = builder.with_cache_dir(cache_dir.into());
|
||||||
}
|
}
|
||||||
|
|
||||||
if revision.is_none() {
|
builder
|
||||||
tracing::warn!("`--revision` is not set");
|
};
|
||||||
tracing::warn!("We strongly advise to set it to a known supported commit.");
|
|
||||||
}
|
|
||||||
|
|
||||||
let api = builder.build().unwrap();
|
// Decide if we need to use the API based on the revision and local path
|
||||||
|
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
|
||||||
|
|
||||||
|
// Initialize API if needed
|
||||||
|
let api = if use_api {
|
||||||
|
tracing::info!("Using the Hugging Face API");
|
||||||
|
match api_builder().build() {
|
||||||
|
Ok(api) => Some(api),
|
||||||
|
Err(_) => {
|
||||||
|
tracing::warn!("Unable to build the Hugging Face API");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
// Load tokenizer and model info
|
||||||
|
let (tokenizer, model_info) = if local_model {
|
||||||
|
let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
|
||||||
|
let model_info = HubModelInfo {
|
||||||
|
model_id: tokenizer_name.to_string(),
|
||||||
|
sha: None,
|
||||||
|
pipeline_tag: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
(tokenizer, model_info)
|
||||||
|
} else if let Some(api) = api.clone() {
|
||||||
let api_repo = api.repo(Repo::with_revision(
|
let api_repo = api.repo(Repo::with_revision(
|
||||||
tokenizer_name.clone(),
|
tokenizer_name.to_string(),
|
||||||
RepoType::Model,
|
RepoType::Model,
|
||||||
revision.clone().unwrap_or("main".to_string()),
|
revision.clone().unwrap_or_else(|| "main".to_string()),
|
||||||
));
|
));
|
||||||
|
|
||||||
// Get Model info
|
let tokenizer = match api_repo.get("tokenizer.json").await {
|
||||||
|
Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(),
|
||||||
|
Err(_) => get_base_tokenizer(&api, &api_repo).await,
|
||||||
|
};
|
||||||
|
|
||||||
let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
|
let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
|
||||||
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
||||||
HubModelInfo {
|
HubModelInfo {
|
||||||
|
@ -192,12 +218,33 @@ async fn main() -> Result<(), RouterError> {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
let tokenizer = match api_repo.get("tokenizer.json").await {
|
(tokenizer, model_info)
|
||||||
Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(),
|
} else {
|
||||||
Err(_) => get_base_tokenizer(&api, &api_repo).await,
|
// No API and no local model
|
||||||
|
return Err(RouterError::ArgumentValidation(
|
||||||
|
"No local model found and no revision specified".to_string(),
|
||||||
|
));
|
||||||
};
|
};
|
||||||
|
|
||||||
(tokenizer, model_info)
|
// Load tokenizer config if found locally, or check if we can get it from the API if needed
|
||||||
|
let tokenizer_config = if local_tokenizer_config {
|
||||||
|
tracing::info!("Using local tokenizer config");
|
||||||
|
HubTokenizerConfig::from_file(&local_tokenizer_config_path)
|
||||||
|
} else if let Some(api) = api {
|
||||||
|
tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
|
||||||
|
get_tokenizer_config(&api.repo(Repo::with_revision(
|
||||||
|
tokenizer_name.to_string(),
|
||||||
|
RepoType::Model,
|
||||||
|
revision.unwrap_or_else(|| "main".to_string()),
|
||||||
|
)))
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|| {
|
||||||
|
tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub.");
|
||||||
|
HubTokenizerConfig::default()
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
tracing::warn!("Could not find tokenizer config locally and no revision specified");
|
||||||
|
HubTokenizerConfig::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
if tokenizer.is_none() {
|
if tokenizer.is_none() {
|
||||||
|
@ -297,6 +344,7 @@ async fn main() -> Result<(), RouterError> {
|
||||||
ngrok,
|
ngrok,
|
||||||
ngrok_authtoken,
|
ngrok_authtoken,
|
||||||
ngrok_edge,
|
ngrok_edge,
|
||||||
|
tokenizer_config,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -401,6 +449,20 @@ pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokeniz
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// get tokenizer_config from the Huggingface Hub
|
||||||
|
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
|
||||||
|
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;
|
||||||
|
|
||||||
|
// Open the file in read-only mode with buffer.
|
||||||
|
let file = File::open(tokenizer_config_filename).ok()?;
|
||||||
|
let reader = BufReader::new(file);
|
||||||
|
|
||||||
|
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
||||||
|
let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader).ok()?;
|
||||||
|
|
||||||
|
Some(tokenizer_config)
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
enum RouterError {
|
enum RouterError {
|
||||||
#[error("Argument validation error: {0}")]
|
#[error("Argument validation error: {0}")]
|
||||||
|
|
|
@ -2,10 +2,11 @@
|
||||||
use crate::health::Health;
|
use crate::health::Health;
|
||||||
use crate::infer::{InferError, InferResponse, InferStreamResponse};
|
use crate::infer::{InferError, InferResponse, InferStreamResponse};
|
||||||
use crate::validation::ValidationError;
|
use crate::validation::ValidationError;
|
||||||
|
use crate::HubTokenizerConfig;
|
||||||
use crate::{
|
use crate::{
|
||||||
BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
|
BestOfSequence, ChatCompletion, ChatCompletionChunk, ChatRequest, CompatGenerateRequest,
|
||||||
GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
|
Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
|
||||||
StreamDetails, StreamResponse, Token, Validation,
|
HubModelInfo, Infer, Info, PrefillToken, StreamDetails, StreamResponse, Token, Validation,
|
||||||
};
|
};
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::{HeaderMap, Method, StatusCode};
|
use axum::http::{HeaderMap, Method, StatusCode};
|
||||||
|
@ -343,6 +344,21 @@ async fn generate_stream(
|
||||||
HeaderMap,
|
HeaderMap,
|
||||||
Sse<impl Stream<Item = Result<Event, Infallible>>>,
|
Sse<impl Stream<Item = Result<Event, Infallible>>>,
|
||||||
) {
|
) {
|
||||||
|
let on_message_callback = |stream_token: StreamResponse| {
|
||||||
|
let event = Event::default();
|
||||||
|
event.json_data(stream_token).unwrap()
|
||||||
|
};
|
||||||
|
let (headers, response_stream) =
|
||||||
|
generate_stream_internal(infer, Json(req), on_message_callback).await;
|
||||||
|
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
||||||
|
(headers, sse)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn generate_stream_internal(
|
||||||
|
infer: Infer,
|
||||||
|
Json(req): Json<GenerateRequest>,
|
||||||
|
on_message_callback: impl Fn(StreamResponse) -> Event,
|
||||||
|
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
|
||||||
let span = tracing::Span::current();
|
let span = tracing::Span::current();
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
metrics::increment_counter!("tgi_request_count");
|
metrics::increment_counter!("tgi_request_count");
|
||||||
|
@ -385,8 +401,10 @@ async fn generate_stream(
|
||||||
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
|
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
|
||||||
// Keep permit as long as generate_stream lives
|
// Keep permit as long as generate_stream lives
|
||||||
Ok((_permit, _input_length, mut response_stream)) => {
|
Ok((_permit, _input_length, mut response_stream)) => {
|
||||||
|
let mut index = 0;
|
||||||
// Server-Sent Event stream
|
// Server-Sent Event stream
|
||||||
while let Some(response) = response_stream.next().await {
|
while let Some(response) = response_stream.next().await {
|
||||||
|
index += 1;
|
||||||
match response {
|
match response {
|
||||||
Ok(response) => {
|
Ok(response) => {
|
||||||
match response {
|
match response {
|
||||||
|
@ -401,13 +419,14 @@ async fn generate_stream(
|
||||||
|
|
||||||
// StreamResponse
|
// StreamResponse
|
||||||
let stream_token = StreamResponse {
|
let stream_token = StreamResponse {
|
||||||
|
index,
|
||||||
token,
|
token,
|
||||||
top_tokens,
|
top_tokens,
|
||||||
generated_text: None,
|
generated_text: None,
|
||||||
details: None,
|
details: None,
|
||||||
};
|
};
|
||||||
|
let event = on_message_callback(stream_token);
|
||||||
yield Ok(Event::default().json_data(stream_token).unwrap())
|
yield Ok(event);
|
||||||
}
|
}
|
||||||
// Yield event for last token and compute timings
|
// Yield event for last token and compute timings
|
||||||
InferStreamResponse::End {
|
InferStreamResponse::End {
|
||||||
|
@ -463,13 +482,16 @@ async fn generate_stream(
|
||||||
tracing::info!(parent: &span, "Success");
|
tracing::info!(parent: &span, "Success");
|
||||||
|
|
||||||
let stream_token = StreamResponse {
|
let stream_token = StreamResponse {
|
||||||
|
index,
|
||||||
token,
|
token,
|
||||||
top_tokens,
|
top_tokens,
|
||||||
generated_text: Some(output_text),
|
generated_text: Some(output_text),
|
||||||
details
|
details
|
||||||
};
|
};
|
||||||
|
|
||||||
yield Ok(Event::default().json_data(stream_token).unwrap());
|
|
||||||
|
let event = on_message_callback(stream_token);
|
||||||
|
yield Ok(event);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -500,7 +522,154 @@ async fn generate_stream(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
(headers, Sse::new(stream).keep_alive(KeepAlive::default()))
|
(headers, stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate tokens
|
||||||
|
#[utoipa::path(
|
||||||
|
post,
|
||||||
|
tag = "Text Generation Inference",
|
||||||
|
path = "/v1/chat/completions",
|
||||||
|
request_body = ChatRequest,
|
||||||
|
responses(
|
||||||
|
(status = 200, description = "Generated Text", body = GenerateResponse),
|
||||||
|
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||||
|
example = json ! ({"error": "Request failed during generation"})),
|
||||||
|
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
||||||
|
example = json ! ({"error": "Model is overloaded"})),
|
||||||
|
(status = 422, description = "Input validation error", body = ErrorResponse,
|
||||||
|
example = json ! ({"error": "Input validation error"})),
|
||||||
|
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
||||||
|
example = json ! ({"error": "Incomplete generation"})),
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
#[instrument(
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
// parameters = ? req.parameters,
|
||||||
|
total_time,
|
||||||
|
validation_time,
|
||||||
|
queue_time,
|
||||||
|
inference_time,
|
||||||
|
time_per_token,
|
||||||
|
seed,
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
async fn chat_completions(
|
||||||
|
Extension(infer): Extension<Infer>,
|
||||||
|
Extension(info): Extension<Info>,
|
||||||
|
Json(req): Json<ChatRequest>,
|
||||||
|
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||||
|
metrics::increment_counter!("tgi_request_count");
|
||||||
|
|
||||||
|
let stream = req.stream;
|
||||||
|
let max_new_tokens = req.max_tokens.or(Some(100));
|
||||||
|
let repetition_penalty = req
|
||||||
|
.frequency_penalty
|
||||||
|
// rescale frequency_penalty from (-2.0, 2.0) to (0.0, 4.0)
|
||||||
|
.map(|x| x + 2.0);
|
||||||
|
let logprobs = req.logprobs.unwrap_or(false);
|
||||||
|
let seed = req.seed;
|
||||||
|
|
||||||
|
// apply chat template to flatten the request into a single input
|
||||||
|
let inputs = match infer.apply_chat_template(req) {
|
||||||
|
Ok(inputs) => inputs,
|
||||||
|
Err(err) => {
|
||||||
|
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||||
|
tracing::error!("{err}");
|
||||||
|
return Err((
|
||||||
|
StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
|
Json(ErrorResponse {
|
||||||
|
error: err.to_string(),
|
||||||
|
error_type: err.error_type().to_string(),
|
||||||
|
}),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// build the request passing some parameters
|
||||||
|
let generate_request = GenerateRequest {
|
||||||
|
inputs: inputs.to_string(),
|
||||||
|
parameters: GenerateParameters {
|
||||||
|
best_of: None,
|
||||||
|
temperature: None,
|
||||||
|
repetition_penalty,
|
||||||
|
top_k: None,
|
||||||
|
top_p: None,
|
||||||
|
typical_p: None,
|
||||||
|
do_sample: true,
|
||||||
|
max_new_tokens,
|
||||||
|
return_full_text: None,
|
||||||
|
stop: Vec::new(),
|
||||||
|
truncate: None,
|
||||||
|
watermark: false,
|
||||||
|
details: true,
|
||||||
|
decoder_input_details: true,
|
||||||
|
seed,
|
||||||
|
top_n_tokens: None,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
// static values that will be returned in all cases
|
||||||
|
let model_id = info.model_id.clone();
|
||||||
|
let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
|
||||||
|
|
||||||
|
// switch on stream
|
||||||
|
if stream {
|
||||||
|
// pass this callback to the stream generation and build the required event structure
|
||||||
|
let on_message_callback = move |stream_token: StreamResponse| {
|
||||||
|
let event = Event::default();
|
||||||
|
|
||||||
|
let current_time = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||||
|
.as_secs();
|
||||||
|
|
||||||
|
event
|
||||||
|
.json_data(ChatCompletionChunk::new(
|
||||||
|
model_id.clone(),
|
||||||
|
system_fingerprint.clone(),
|
||||||
|
stream_token.token.text,
|
||||||
|
current_time,
|
||||||
|
stream_token.index,
|
||||||
|
logprobs.then_some(stream_token.token.logprob),
|
||||||
|
stream_token.details.map(|d| d.finish_reason.to_string()),
|
||||||
|
))
|
||||||
|
.map_or_else(
|
||||||
|
|e| {
|
||||||
|
println!("Failed to serialize ChatCompletionChunk: {:?}", e);
|
||||||
|
Event::default()
|
||||||
|
},
|
||||||
|
|data| data,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
let (headers, response_stream) =
|
||||||
|
generate_stream_internal(infer, Json(generate_request), on_message_callback).await;
|
||||||
|
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
||||||
|
Ok((headers, sse).into_response())
|
||||||
|
} else {
|
||||||
|
let (headers, Json(generation)) =
|
||||||
|
generate(Extension(infer), Json(generate_request)).await?;
|
||||||
|
|
||||||
|
let current_time = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||||
|
.as_secs();
|
||||||
|
|
||||||
|
// build the complete response object with the full text
|
||||||
|
let response = ChatCompletion::new(
|
||||||
|
generation.generated_text,
|
||||||
|
model_id,
|
||||||
|
system_fingerprint,
|
||||||
|
current_time,
|
||||||
|
generation.details.unwrap(),
|
||||||
|
logprobs,
|
||||||
|
);
|
||||||
|
|
||||||
|
// wrap generation inside a Vec to match api-inference
|
||||||
|
Ok((headers, Json(response)).into_response())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Prometheus metrics scrape endpoint
|
/// Prometheus metrics scrape endpoint
|
||||||
|
@ -538,6 +707,7 @@ pub async fn run(
|
||||||
ngrok: bool,
|
ngrok: bool,
|
||||||
ngrok_authtoken: Option<String>,
|
ngrok_authtoken: Option<String>,
|
||||||
ngrok_edge: Option<String>,
|
ngrok_edge: Option<String>,
|
||||||
|
tokenizer_config: HubTokenizerConfig,
|
||||||
) -> Result<(), axum::BoxError> {
|
) -> Result<(), axum::BoxError> {
|
||||||
// OpenAPI documentation
|
// OpenAPI documentation
|
||||||
#[derive(OpenApi)]
|
#[derive(OpenApi)]
|
||||||
|
@ -604,6 +774,7 @@ pub async fn run(
|
||||||
shard_info.window_size,
|
shard_info.window_size,
|
||||||
shard_info.speculate,
|
shard_info.speculate,
|
||||||
generation_health,
|
generation_health,
|
||||||
|
tokenizer_config,
|
||||||
);
|
);
|
||||||
|
|
||||||
// Duration buckets
|
// Duration buckets
|
||||||
|
@ -693,6 +864,7 @@ pub async fn run(
|
||||||
.route("/info", get(get_model_info))
|
.route("/info", get(get_model_info))
|
||||||
.route("/generate", post(generate))
|
.route("/generate", post(generate))
|
||||||
.route("/generate_stream", post(generate_stream))
|
.route("/generate_stream", post(generate_stream))
|
||||||
|
.route("/v1/chat/completions", post(chat_completions))
|
||||||
// AWS Sagemaker route
|
// AWS Sagemaker route
|
||||||
.route("/invocations", post(compat_generate))
|
.route("/invocations", post(compat_generate))
|
||||||
// Base Health route
|
// Base Health route
|
||||||
|
@ -822,6 +994,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
||||||
InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
|
InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
|
||||||
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
|
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
};
|
};
|
||||||
|
|
||||||
(
|
(
|
||||||
|
|
|
@ -376,7 +376,7 @@ type TokenizerRequest = (
|
||||||
Span,
|
Span,
|
||||||
);
|
);
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
pub(crate) struct ValidGenerateRequest {
|
pub(crate) struct ValidGenerateRequest {
|
||||||
pub inputs: String,
|
pub inputs: String,
|
||||||
pub input_length: u32,
|
pub input_length: u32,
|
||||||
|
|
Loading…
Reference in New Issue