diff --git a/Cargo.lock b/Cargo.lock
index f826ea34..3baff665 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -773,9 +773,9 @@ dependencies = [
 
 [[package]]
 name = "futures-channel"
-version = "0.3.29"
+version = "0.3.30"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ff4dd66668b557604244583e3e1e1eada8c5c2e96a6d0d6653ede395b78bbacb"
+checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78"
 dependencies = [
  "futures-core",
  "futures-sink",
@@ -783,9 +783,9 @@ dependencies = [
 
 [[package]]
 name = "futures-core"
-version = "0.3.29"
+version = "0.3.30"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c"
+checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d"
 
 [[package]]
 name = "futures-executor"
@@ -800,15 +800,15 @@ dependencies = [
 
 [[package]]
 name = "futures-io"
-version = "0.3.29"
+version = "0.3.30"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa"
+checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1"
 
 [[package]]
 name = "futures-macro"
-version = "0.3.29"
+version = "0.3.30"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb"
+checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -817,21 +817,21 @@ dependencies = [
 
 [[package]]
 name = "futures-sink"
-version = "0.3.29"
+version = "0.3.30"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e36d3378ee38c2a36ad710c5d30c2911d752cb941c00c72dbabfb786a7970817"
+checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5"
 
 [[package]]
 name = "futures-task"
-version = "0.3.29"
+version = "0.3.30"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2"
+checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004"
 
 [[package]]
 name = "futures-util"
-version = "0.3.29"
+version = "0.3.30"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104"
+checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48"
 dependencies = [
  "futures-channel",
  "futures-core",
@@ -1373,6 +1373,15 @@ dependencies = [
  "unicase",
 ]
 
+[[package]]
+name = "minijinja"
+version = "1.0.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "208758577ef2c86cf5dd3e85730d161413ec3284e2d73b2ef65d9a24d9971bcb"
+dependencies = [
+ "serde",
+]
+
 [[package]]
 name = "minimal-lexical"
 version = "0.2.1"
@@ -2807,10 +2816,12 @@ dependencies = [
  "axum-tracing-opentelemetry",
  "clap",
  "futures",
+ "futures-util",
  "hf-hub",
  "init-tracing-opentelemetry",
  "metrics",
  "metrics-exporter-prometheus",
+ "minijinja",
  "ngrok",
  "nohash-hasher",
  "opentelemetry",
diff --git a/router/Cargo.toml b/router/Cargo.toml
index 5ccdb0cd..f6f16dae 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -43,6 +43,8 @@ utoipa = { version = "3.5.0", features = ["axum_extras"] }
 utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
 ngrok = { version = "0.13.1", features = ["axum"], optional = true }
 init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
+minijinja = "1.0.10"
+futures-util = "0.3.30" [build-dependencies] vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] } diff --git a/router/src/infer.rs b/router/src/infer.rs index b4094c1b..6de07982 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -1,8 +1,10 @@ /// Batching and inference logic use crate::validation::{Validation, ValidationError}; +use crate::HubTokenizerConfig; +use crate::{ChatRequest, GenerateRequest, GenerateStreamResponse, PrefillToken}; use crate::{Entry, Queue, Token}; -use crate::{GenerateRequest, PrefillToken}; use futures::future::try_join_all; +use minijinja::{Environment, ErrorKind, Template}; use nohash_hasher::IntMap; use std::sync::{ atomic::{AtomicBool, Ordering}, @@ -13,7 +15,7 @@ use text_generation_client::{ }; use thiserror::Error; use tokio::sync::mpsc::error::SendError; -use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError}; +use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError}; use tokio::time::Instant; use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::StreamExt; @@ -30,6 +32,8 @@ pub struct Infer { shared: Arc, /// Inference limit limit_concurrent_requests: Arc, + /// Chat template + template: Option>, } /// Infer shared state @@ -52,6 +56,7 @@ impl Infer { window_size: Option, speculate: u32, generation_health: Arc, + tokenizer_config: HubTokenizerConfig, ) -> Self { // Infer shared state let queue = Queue::new(requires_padding, 16, window_size, speculate); @@ -74,11 +79,21 @@ impl Infer { // Inference limit with a semaphore let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); + let template = tokenizer_config.chat_template.map(|t| { + let env = Box::new(Environment::new()); + let template_str = t.into_boxed_str(); + // leaking env and template_str as read-only, static resources for performance. + Box::leak(env) + .template_from_str(Box::leak(template_str)) + .unwrap() + }); + Self { validation, queue, shared, limit_concurrent_requests: semaphore, + template, } } @@ -87,14 +102,7 @@ impl Infer { pub(crate) async fn generate_stream( &self, request: GenerateRequest, - ) -> Result< - ( - OwnedSemaphorePermit, - u32, - UnboundedReceiverStream>, - ), - InferError, - > { + ) -> Result { // Limit concurrent requests by acquiring a permit from the semaphore let permit = self .clone() @@ -139,6 +147,20 @@ impl Infer { )) } + /// Apply the chat template to the chat request + #[instrument(skip_all)] + pub(crate) fn apply_chat_template(&self, chat: ChatRequest) -> Result { + self.template + .as_ref() + .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? 
+            .render(chat)
+            .map_err(|e| {
+                metrics::increment_counter!("tgi_request_failure", "err" => "template");
+                tracing::error!("{e}");
+                InferError::TemplateError(e)
+            })
+    }
+
     /// Add a new request to the queue and return a InferResponse
     #[instrument(skip_all)]
     pub(crate) async fn generate(
@@ -550,9 +572,9 @@ fn send_responses(
     let mut iterator = tokens_
         .ids
         .into_iter()
-        .zip(tokens_.logprobs.into_iter())
-        .zip(tokens_.texts.into_iter())
-        .zip(tokens_.is_special.into_iter())
+        .zip(tokens_.logprobs)
+        .zip(tokens_.texts)
+        .zip(tokens_.is_special)
         .enumerate()
         .peekable();
     while let Some((i, (((id, logprob), text), special))) = iterator.next() {
@@ -665,6 +687,8 @@ pub enum InferError {
     ValidationError(#[from] ValidationError),
     #[error("Incomplete generation")]
     IncompleteGeneration,
+    #[error("Template error: {0}")]
+    TemplateError(#[from] minijinja::Error),
 }
 
 impl InferError {
@@ -674,6 +698,7 @@ impl InferError {
             InferError::Overloaded(_) => "overloaded",
             InferError::ValidationError(_) => "validation",
             InferError::IncompleteGeneration => "incomplete_generation",
+            InferError::TemplateError(_) => "template_error",
         }
     }
 }
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 898fcd04..f6f8276f 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -5,12 +5,21 @@ mod queue;
 pub mod server;
 mod validation;
 
-use infer::Infer;
+use infer::{Infer, InferError, InferStreamResponse};
 use queue::{Entry, Queue};
 use serde::{Deserialize, Serialize};
+use tokio::sync::OwnedSemaphorePermit;
+use tokio_stream::wrappers::UnboundedReceiverStream;
 use utoipa::ToSchema;
 use validation::Validation;
 
+/// Type alias for generation responses
+pub(crate) type GenerateStreamResponse = (
+    OwnedSemaphorePermit,
+    u32, // input_length
+    UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
+);
+
 /// Hub type
 #[derive(Clone, Debug, Deserialize)]
 pub struct HubModelInfo {
@@ -20,6 +29,19 @@ pub struct HubModelInfo {
     pub pipeline_tag: Option<String>,
 }
 
+#[derive(Clone, Deserialize, Default)]
+pub struct HubTokenizerConfig {
+    #[serde(default)]
+    pub chat_template: Option<String>,
+}
+
+impl HubTokenizerConfig {
+    pub fn from_file(filename: &str) -> Self {
+        let content = std::fs::read_to_string(filename).unwrap();
+        serde_json::from_str(&content).unwrap_or_default()
+    }
+}
+
 #[derive(Clone, Debug, Serialize, ToSchema)]
 pub struct Info {
     /// Model info
@@ -152,7 +174,7 @@ fn default_parameters() -> GenerateParameters {
         top_k: None,
         top_p: None,
         typical_p: None,
-        do_sample: false,
+        do_sample: true,
         max_new_tokens: default_max_new_tokens(),
         return_full_text: None,
         stop: Vec::new(),
@@ -165,6 +187,193 @@ fn default_parameters() -> GenerateParameters {
     }
 }
 
+#[derive(Clone, Deserialize, Serialize)]
+pub(crate) struct ChatCompletion {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub system_fingerprint: String,
+    pub choices: Vec<ChatCompletionComplete>,
+    pub usage: Usage,
+}
+
+#[derive(Clone, Deserialize, Serialize)]
+pub(crate) struct ChatCompletionComplete {
+    pub index: u32,
+    pub message: Message,
+    pub logprobs: Option<Vec<f32>>,
+    pub finish_reason: String,
+}
+
+#[derive(Clone, Deserialize, Serialize)]
+pub(crate) struct Usage {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub total_tokens: u32,
+}
+
+impl ChatCompletion {
+    pub(crate) fn new(
+        model: String,
+        system_fingerprint: String,
+        output: String,
+        created: u64,
+        details: Details,
+        return_logprobs: bool,
+    ) -> Self {
+        Self {
+            id: String::new(),
+            object: "text_completion".into(),
+            created,
+            model,
+            system_fingerprint,
+            choices: vec![ChatCompletionComplete {
+                index: 0,
+                message: Message {
+                    role: "assistant".into(),
+                    content: output,
+                },
+                logprobs: return_logprobs
+                    .then(|| details.tokens.iter().map(|t| t.logprob).collect()),
+                finish_reason: details.finish_reason.to_string(),
+            }],
+            usage: Usage {
+                prompt_tokens: details.prefill.len() as u32,
+                completion_tokens: details.generated_tokens,
+                total_tokens: details.prefill.len() as u32 + details.generated_tokens,
+            },
+        }
+    }
+}
+
+#[derive(Clone, Deserialize, Serialize)]
+pub(crate) struct ChatCompletionChunk {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub system_fingerprint: String,
+    pub choices: Vec<ChatCompletionChoice>,
+}
+
+#[derive(Clone, Deserialize, Serialize)]
+pub(crate) struct ChatCompletionChoice {
+    pub index: u32,
+    pub delta: ChatCompletionDelta,
+    pub logprobs: Option<f32>,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub(crate) struct ChatCompletionDelta {
+    pub role: String,
+    pub content: String,
+}
+
+impl ChatCompletionChunk {
+    pub(crate) fn new(
+        model: String,
+        system_fingerprint: String,
+        delta: String,
+        created: u64,
+        index: u32,
+        logprobs: Option<f32>,
+        finish_reason: Option<String>,
+    ) -> Self {
+        Self {
+            id: String::new(),
+            object: "text_completion".to_string(),
+            created,
+            model,
+            system_fingerprint,
+            choices: vec![ChatCompletionChoice {
+                index,
+                delta: ChatCompletionDelta {
+                    role: "assistant".to_string(),
+                    content: delta,
+                },
+                logprobs,
+                finish_reason,
+            }],
+        }
+    }
+}
+
+fn default_request_messages() -> Vec<Message> {
+    vec![Message {
+        role: "user".to_string(),
+        content: "My name is David and I".to_string(),
+    }]
+}
+
+#[derive(Clone, Deserialize, ToSchema, Serialize)]
+pub(crate) struct ChatRequest {
+    /// UNUSED
+    #[schema(example = "bigscience/bloom-560m")]
+    /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
+    pub model: String, /* NOTE: UNUSED */
+
+    /// A list of messages comprising the conversation so far.
+    #[serde(default = "default_request_messages")]
+    pub messages: Vec<Message>,
+
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
+    /// decreasing the model's likelihood to repeat the same line verbatim.
+    #[serde(default)]
+    pub frequency_penalty: Option<f32>,
+
+    /// UNUSED
+    /// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
+    /// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
+    /// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
+    /// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
+    /// result in a ban or exclusive selection of the relevant token.
+    #[serde(default)]
+    pub logit_bias: Option<Vec<f32>>,
+
+    /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
+    /// output token returned in the content of message.
+    #[serde(default)]
+    pub logprobs: Option<bool>,
+
+    /// UNUSED
+    /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
+    /// an associated log probability. logprobs must be set to true if this parameter is used.
+    #[serde(default)]
+    pub top_logprobs: Option<u32>,
+
+    /// The maximum number of tokens that can be generated in the chat completion.
+    #[serde(default)]
+    pub max_tokens: Option<u32>,
+
+    /// UNUSED
+    /// How many chat completion choices to generate for each input message. Note that you will be charged based on the
+    /// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
+    #[serde(default)]
+    pub n: Option<u32>,
+
+    /// UNUSED
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
+    /// increasing the model's likelihood to talk about new topics
+    #[serde(default)]
+    pub presence_penalty: Option<f32>,
+
+    #[serde(default = "bool::default")]
+    pub stream: bool,
+
+    #[schema(nullable = true, example = 42)]
+    pub seed: Option<u64>,
+}
+
+#[derive(Clone, Deserialize, ToSchema, Serialize)]
+pub(crate) struct Message {
+    #[schema(example = "user")]
+    pub role: String,
+    #[schema(example = "My name is David and I")]
+    pub content: String,
+}
+
 #[derive(Clone, Debug, Deserialize, ToSchema)]
 pub(crate) struct GenerateRequest {
     #[schema(example = "My name is Olivier and I")]
@@ -227,6 +436,16 @@ pub(crate) enum FinishReason {
     StopSequence,
 }
 
+impl std::fmt::Display for FinishReason {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            FinishReason::Length => write!(f, "length"),
+            FinishReason::EndOfSequenceToken => write!(f, "eos_token"),
+            FinishReason::StopSequence => write!(f, "stop_sequence"),
+        }
+    }
+}
+
 #[derive(Serialize, ToSchema)]
 pub(crate) struct BestOfSequence {
     #[schema(example = "test")]
@@ -279,6 +498,7 @@ pub(crate) struct StreamDetails {
 
 #[derive(Serialize, ToSchema)]
 pub(crate) struct StreamResponse {
+    pub index: u32,
     pub token: Token,
     #[serde(skip_serializing_if = "Vec::is_empty")]
     pub top_tokens: Vec<Token>,
diff --git a/router/src/main.rs b/router/src/main.rs
index 4637c77c..f5d44305 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -8,13 +8,12 @@ use opentelemetry::sdk::trace::Sampler;
 use opentelemetry::sdk::Resource;
 use opentelemetry::{global, KeyValue};
 use opentelemetry_otlp::WithExportConfig;
-/// Text Generation Inference webserver entrypoint
 use std::fs::File;
 use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::Path;
 use text_generation_client::{ClientError, ShardedClient};
-use text_generation_router::{server, HubModelInfo};
+use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
 use thiserror::Error;
 use tokenizers::Tokenizer;
 use tower_http::cors::AllowOrigin;
@@ -55,6 +54,8 @@ struct Args {
     #[clap(default_value = "bigscience/bloom", long, env)]
     tokenizer_name: String,
     #[clap(long, env)]
+    tokenizer_config_path: Option<String>,
+    #[clap(long, env)]
     revision: Option<String>,
     #[clap(default_value = "2", long, env)]
     validation_workers: usize,
@@ -92,6 +93,7 @@ async fn main() -> Result<(), RouterError> {
         port,
         master_shard_uds_path,
         tokenizer_name,
+        tokenizer_config_path,
         revision,
         validation_workers,
         json_output,
@@ -149,40 +151,64 @@ async fn main() -> Result<(), RouterError> {
     let local_path = Path::new(&tokenizer_name);
     let local_model = local_path.exists() && local_path.is_dir();
 
-    let (tokenizer, model_info) = if local_model {
-        // Get Model info
-        let model_info = HubModelInfo {
-            model_id: tokenizer_name.clone(),
-            sha: None,
-            pipeline_tag: None,
-        };
+    // Load tokenizer config
+    // This will be used to format the chat template
+    let local_tokenizer_config_path =
+        tokenizer_config_path.unwrap_or("tokenizer_config.json".to_string());
+    let local_tokenizer_config = Path::new(&local_tokenizer_config_path).exists();
 
-        // Load local tokenizer
-        let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
-
-        (tokenizer, model_info)
-    } else {
+    // Shared API builder initialization
+    let api_builder = || {
         let mut builder = ApiBuilder::new()
             .with_progress(false)
             .with_token(authorization_token);
 
-        if let Some(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE").ok() {
+        if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
             builder = builder.with_cache_dir(cache_dir.into());
         }
 
-        if revision.is_none() {
-            tracing::warn!("`--revision` is not set");
-            tracing::warn!("We strongly advise to set it to a known supported commit.");
-        }
+        builder
+    };
 
-        let api = builder.build().unwrap();
+    // Decide if we need to use the API based on the revision and local path
+    let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
+
+    // Initialize API if needed
+    let api = if use_api {
+        tracing::info!("Using the Hugging Face API");
+        match api_builder().build() {
+            Ok(api) => Some(api),
+            Err(_) => {
+                tracing::warn!("Unable to build the Hugging Face API");
+                None
+            }
+        }
+    } else {
+        None
+    };
+
+    // Load tokenizer and model info
+    let (tokenizer, model_info) = if local_model {
+        let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
+        let model_info = HubModelInfo {
+            model_id: tokenizer_name.to_string(),
+            sha: None,
+            pipeline_tag: None,
+        };
+
+        (tokenizer, model_info)
+    } else if let Some(api) = api.clone() {
         let api_repo = api.repo(Repo::with_revision(
-            tokenizer_name.clone(),
+            tokenizer_name.to_string(),
             RepoType::Model,
-            revision.clone().unwrap_or("main".to_string()),
+            revision.clone().unwrap_or_else(|| "main".to_string()),
         ));
 
-        // Get Model info
+        let tokenizer = match api_repo.get("tokenizer.json").await {
+            Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(),
+            Err(_) => get_base_tokenizer(&api, &api_repo).await,
+        };
+
         let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
             tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
             HubModelInfo {
@@ -192,12 +218,33 @@ async fn main() -> Result<(), RouterError> {
             }
         });
 
-        let tokenizer = match api_repo.get("tokenizer.json").await {
-            Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(),
-            Err(_) => get_base_tokenizer(&api, &api_repo).await,
-        };
-
         (tokenizer, model_info)
+    } else {
+        // No API and no local model
+        return Err(RouterError::ArgumentValidation(
+            "No local model found and no revision specified".to_string(),
+        ));
+    };
+
+    // Load tokenizer config if found locally, or check if we can get it from the API if needed
+    let tokenizer_config = if local_tokenizer_config {
+        tracing::info!("Using local tokenizer config");
+        HubTokenizerConfig::from_file(&local_tokenizer_config_path)
+    } else if let Some(api) = api {
+        tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
+        get_tokenizer_config(&api.repo(Repo::with_revision(
+            tokenizer_name.to_string(),
+            RepoType::Model,
+            revision.unwrap_or_else(|| "main".to_string()),
+        )))
+        .await
+        .unwrap_or_else(|| {
+            tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub.");
+            HubTokenizerConfig::default()
+        })
+    } else {
+        tracing::warn!("Could not find tokenizer config locally and no revision specified");
+        HubTokenizerConfig::default()
     };
 
     if tokenizer.is_none() {
@@ -297,6 +344,7 @@ async fn main() -> Result<(), RouterError> {
         ngrok,
         ngrok_authtoken,
         ngrok_edge,
+        tokenizer_config,
     )
     .await?;
     Ok(())
@@ -401,6 +449,20 @@ pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokenizer> {
get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option Option { + let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?; + + // Open the file in read-only mode with buffer. + let file = File::open(tokenizer_config_filename).ok()?; + let reader = BufReader::new(file); + + // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'. + let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader).ok()?; + + Some(tokenizer_config) +} + #[derive(Debug, Error)] enum RouterError { #[error("Argument validation error: {0}")] diff --git a/router/src/server.rs b/router/src/server.rs index 3db5c7cd..fe1827c4 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -2,10 +2,11 @@ use crate::health::Health; use crate::infer::{InferError, InferResponse, InferStreamResponse}; use crate::validation::ValidationError; +use crate::HubTokenizerConfig; use crate::{ - BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason, - GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken, - StreamDetails, StreamResponse, Token, Validation, + BestOfSequence, ChatCompletion, ChatCompletionChunk, ChatRequest, CompatGenerateRequest, + Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, + HubModelInfo, Infer, Info, PrefillToken, StreamDetails, StreamResponse, Token, Validation, }; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; @@ -343,6 +344,21 @@ async fn generate_stream( HeaderMap, Sse>>, ) { + let on_message_callback = |stream_token: StreamResponse| { + let event = Event::default(); + event.json_data(stream_token).unwrap() + }; + let (headers, response_stream) = + generate_stream_internal(infer, Json(req), on_message_callback).await; + let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); + (headers, sse) +} + +async fn generate_stream_internal( + infer: Infer, + Json(req): Json, + on_message_callback: impl Fn(StreamResponse) -> Event, +) -> (HeaderMap, impl Stream>) { let span = tracing::Span::current(); let start_time = Instant::now(); metrics::increment_counter!("tgi_request_count"); @@ -385,8 +401,10 @@ async fn generate_stream( match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await { // Keep permit as long as generate_stream lives Ok((_permit, _input_length, mut response_stream)) => { + let mut index = 0; // Server-Sent Event stream while let Some(response) = response_stream.next().await { + index += 1; match response { Ok(response) => { match response { @@ -401,13 +419,14 @@ async fn generate_stream( // StreamResponse let stream_token = StreamResponse { + index, token, top_tokens, generated_text: None, details: None, }; - - yield Ok(Event::default().json_data(stream_token).unwrap()) + let event = on_message_callback(stream_token); + yield Ok(event); } // Yield event for last token and compute timings InferStreamResponse::End { @@ -463,13 +482,16 @@ async fn generate_stream( tracing::info!(parent: &span, "Success"); let stream_token = StreamResponse { + index, token, top_tokens, generated_text: Some(output_text), details }; - yield Ok(Event::default().json_data(stream_token).unwrap()); + + let event = on_message_callback(stream_token); + yield Ok(event); break; } } @@ -500,7 +522,154 @@ async fn generate_stream( } }; - (headers, Sse::new(stream).keep_alive(KeepAlive::default())) + (headers, stream) +} + +/// Generate tokens +#[utoipa::path( + post, + tag = "Text 
Generation Inference", + path = "/v1/chat/completions", + request_body = ChatRequest, + responses( + (status = 200, description = "Generated Text", body = GenerateResponse), + (status = 424, description = "Generation Error", body = ErrorResponse, + example = json ! ({"error": "Request failed during generation"})), + (status = 429, description = "Model is overloaded", body = ErrorResponse, + example = json ! ({"error": "Model is overloaded"})), + (status = 422, description = "Input validation error", body = ErrorResponse, + example = json ! ({"error": "Input validation error"})), + (status = 500, description = "Incomplete generation", body = ErrorResponse, + example = json ! ({"error": "Incomplete generation"})), + ) + )] +#[instrument( + skip_all, + fields( + // parameters = ? req.parameters, + total_time, + validation_time, + queue_time, + inference_time, + time_per_token, + seed, + ) + )] +async fn chat_completions( + Extension(infer): Extension, + Extension(info): Extension, + Json(req): Json, +) -> Result)> { + metrics::increment_counter!("tgi_request_count"); + + let stream = req.stream; + let max_new_tokens = req.max_tokens.or(Some(100)); + let repetition_penalty = req + .frequency_penalty + // rescale frequency_penalty from (-2.0, 2.0) to (0.0, 4.0) + .map(|x| x + 2.0); + let logprobs = req.logprobs.unwrap_or(false); + let seed = req.seed; + + // apply chat template to flatten the request into a single input + let inputs = match infer.apply_chat_template(req) { + Ok(inputs) => inputs, + Err(err) => { + metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + tracing::error!("{err}"); + return Err(( + StatusCode::UNPROCESSABLE_ENTITY, + Json(ErrorResponse { + error: err.to_string(), + error_type: err.error_type().to_string(), + }), + )); + } + }; + + // build the request passing some parameters + let generate_request = GenerateRequest { + inputs: inputs.to_string(), + parameters: GenerateParameters { + best_of: None, + temperature: None, + repetition_penalty, + top_k: None, + top_p: None, + typical_p: None, + do_sample: true, + max_new_tokens, + return_full_text: None, + stop: Vec::new(), + truncate: None, + watermark: false, + details: true, + decoder_input_details: true, + seed, + top_n_tokens: None, + }, + }; + + // static values that will be returned in all cases + let model_id = info.model_id.clone(); + let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native")); + + // switch on stream + if stream { + // pass this callback to the stream generation and build the required event structure + let on_message_callback = move |stream_token: StreamResponse| { + let event = Event::default(); + + let current_time = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_else(|_| std::time::Duration::from_secs(0)) + .as_secs(); + + event + .json_data(ChatCompletionChunk::new( + model_id.clone(), + system_fingerprint.clone(), + stream_token.token.text, + current_time, + stream_token.index, + logprobs.then_some(stream_token.token.logprob), + stream_token.details.map(|d| d.finish_reason.to_string()), + )) + .map_or_else( + |e| { + println!("Failed to serialize ChatCompletionChunk: {:?}", e); + Event::default() + }, + |data| data, + ) + }; + + let (headers, response_stream) = + generate_stream_internal(infer, Json(generate_request), on_message_callback).await; + let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); + Ok((headers, sse).into_response()) + } else { + let (headers, Json(generation)) = 
+            generate(Extension(infer), Json(generate_request)).await?;
+
+        let current_time = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+            .as_secs();
+
+        // build the complete response object with the full text
+        let response = ChatCompletion::new(
+            model_id,
+            system_fingerprint,
+            generation.generated_text,
+            current_time,
+            generation.details.unwrap(),
+            logprobs,
+        );
+
+        // wrap generation inside a Vec to match api-inference
+        Ok((headers, Json(response)).into_response())
+    }
+}
 
 /// Prometheus metrics scrape endpoint
 #[utoipa::path(
     get,
     tag = "Text Generation Inference",
@@ -538,6 +707,7 @@ pub async fn run(
     ngrok: bool,
     ngrok_authtoken: Option<String>,
     ngrok_edge: Option<String>,
+    tokenizer_config: HubTokenizerConfig,
 ) -> Result<(), axum::BoxError> {
     // OpenAPI documentation
     #[derive(OpenApi)]
@@ -604,6 +774,7 @@ pub async fn run(
         shard_info.window_size,
         shard_info.speculate,
         generation_health,
+        tokenizer_config,
     );
 
     // Duration buckets
@@ -693,6 +864,7 @@ pub async fn run(
         .route("/info", get(get_model_info))
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
+        .route("/v1/chat/completions", post(chat_completions))
         // AWS Sagemaker route
         .route("/invocations", post(compat_generate))
         // Base Health route
@@ -822,6 +994,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
             InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
             InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
             InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
+            InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
         };
 
         (
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 64f25c82..370e9588 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -376,7 +376,7 @@ type TokenizerRequest = (
     Span,
 );
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub(crate) struct ValidGenerateRequest {
     pub inputs: String,
     pub input_length: u32,
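
Editorial note (not part of the patch): the sketch below illustrates the minijinja rendering that `Infer::apply_chat_template` relies on. The template string and the standalone `Message` struct here are hypothetical stand-ins for a model's `chat_template` from `tokenizer_config.json`; only the `Environment`/`render` usage mirrors the code above, and the router compiles the template once and reuses it for every request instead of registering it per call.

```rust
use minijinja::{context, Environment};
use serde::Serialize;

// Hypothetical message type; the router serializes its own `Message` struct from lib.rs.
#[derive(Serialize)]
struct Message {
    role: String,
    content: String,
}

fn main() -> Result<(), minijinja::Error> {
    let mut env = Environment::new();
    // Stand-in template; a real model ships its own `chat_template` in tokenizer_config.json.
    env.add_template(
        "chat",
        "{% for message in messages %}<|{{ message.role }}|>: {{ message.content }}\n{% endfor %}",
    )?;

    let messages = vec![Message {
        role: "user".to_string(),
        content: "My name is David and I".to_string(),
    }];

    // The router renders the whole ChatRequest as the template context;
    // most chat templates only read the `messages` field.
    let prompt = env
        .get_template("chat")?
        .render(context! { messages => messages })?;
    println!("{prompt}");
    Ok(())
}
```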
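A hypothetical client call against the new route, also not part of the patch: it assumes a router listening on `localhost:3000` (the default port) and the `reqwest` (with its `json` feature), `serde_json`, and `tokio` crates; the request field names follow the `ChatRequest` struct added in lib.rs.

```rust
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // Minimal non-streaming request; `model` is accepted but unused by the router.
    let body = json!({
        "model": "ignored",
        "messages": [{ "role": "user", "content": "My name is David and I" }],
        "max_tokens": 32,
        "stream": false
    });

    let completion: serde_json::Value = reqwest::Client::new()
        .post("http://localhost:3000/v1/chat/completions")
        .json(&body)
        .send()
        .await?
        .json()
        .await?;

    // The non-streaming path returns a ChatCompletion object; the assistant text
    // lives at choices[0].message.content.
    println!("{}", completion["choices"][0]["message"]["content"]);
    Ok(())
}
```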