From 230f2a415ac3b011197a7d7ac4e6d1827f33d72d Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 26 Jun 2024 14:12:01 +0200 Subject: [PATCH] refacto --- backends/v3/src/backend.rs | 28 +- backends/v3/src/client/client.rs | 2 +- backends/v3/src/client/mod.rs | 1 - backends/v3/src/client/sharded_client.rs | 13 +- backends/v3/src/lib.rs | 13 +- backends/v3/src/main.rs | 372 ++--------------------- backends/v3/src/queue.rs | 36 ++- router/src/infer/chat_template.rs | 1 - router/src/infer/mod.rs | 21 +- router/src/infer/v2/scheduler.rs | 2 +- router/src/lib.rs | 2 +- router/src/logging.rs | 81 +++++ router/src/server.rs | 280 ++++++++++++++++- router/src/validation.rs | 2 +- 14 files changed, 430 insertions(+), 424 deletions(-) create mode 100644 router/src/logging.rs diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index bfe587f4..dbc27c31 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -1,21 +1,17 @@ +use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient}; /// Batching and inference logic use crate::queue::{Entry, Queue}; -use text_generation_router::infer::{ - GeneratedText, InferError, InferStreamResponse, Backend, -}; +use async_trait::async_trait; +use nohash_hasher::IntMap; +use std::sync::Arc; +use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; use text_generation_router::validation::ValidGenerateRequest; use text_generation_router::{FinishReason, PrefillToken, Token}; -use nohash_hasher::IntMap; -use std::sync::{ - Arc, -}; -use crate::client::{Batch, CachedBatch, Generation, ShardedClient, ClientError, Health}; use tokio::sync::mpsc::error::SendError; use tokio::sync::{mpsc, Notify}; use tokio::time::Instant; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{info_span, instrument, Instrument, Span}; -use async_trait::async_trait; pub struct BackendV3 { /// Request queue @@ -94,9 +90,7 @@ impl Backend for BackendV3 { self.batching_task_notifier.notify_one(); // Return stream - Ok( - UnboundedReceiverStream::new(response_rx), - ) + Ok(UnboundedReceiverStream::new(response_rx)) } async fn health(&self, current_health: bool) -> bool { @@ -193,10 +187,9 @@ pub(crate) async fn batching_task( }); // Generate one token for this new batch to have the attention past in cache - let new_cached_batch = - prefill(&mut client, new_batch, &mut new_entries) - .instrument(span) - .await; + let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries) + .instrument(span) + .await; // Reset waiting counter waiting_tokens = 1; // Extend current batch with the new batch @@ -480,8 +473,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap) { impl From for GeneratedText { fn from(value: crate::client::GeneratedText) -> Self { - let v3_finish_reason = - crate::client::FinishReason::try_from(value.finish_reason).unwrap(); + let v3_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap(); let finish_reason = match v3_finish_reason { crate::client::FinishReason::Length => FinishReason::Length, crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken, diff --git a/backends/v3/src/client/client.rs b/backends/v3/src/client/client.rs index 242a82d9..c407687b 100644 --- a/backends/v3/src/client/client.rs +++ b/backends/v3/src/client/client.rs @@ -1,5 +1,4 @@ /// Single shard Client - use crate::client::{pb, Chunk}; use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64}; use base64::engine::general_purpose::STANDARD; @@ -20,6 +19,7 @@ pub struct Client { impl Client { /// Returns a client connected to the given url + #[allow(dead_code)] pub async fn connect(uri: Uri) -> Result { let channel = Channel::builder(uri).connect().await?; diff --git a/backends/v3/src/client/mod.rs b/backends/v3/src/client/mod.rs index 4099ff87..a5eb7625 100644 --- a/backends/v3/src/client/mod.rs +++ b/backends/v3/src/client/mod.rs @@ -1,6 +1,5 @@ //! Text Generation gRPC client library - use async_trait::async_trait; use thiserror::Error; use tonic::transport; diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs index 32365648..b4840deb 100644 --- a/backends/v3/src/client/sharded_client.rs +++ b/backends/v3/src/client/sharded_client.rs @@ -1,17 +1,17 @@ +use crate::client::{ClientError, Result}; /// Multi shard Client use crate::client::{Health, ShardInfo}; -use crate::client::{ClientError, Result}; -use crate::client::{Chunk, InfoResponse, Input}; -use async_trait::async_trait; -use futures::future::join_all; -use tonic::transport::Uri; -use tracing::instrument; use crate::client::client::{DecodeTimings, PrefillTimings}; use crate::client::{ Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; +use crate::client::{Chunk, InfoResponse, Input}; +use async_trait::async_trait; +use futures::future::join_all; +use tonic::transport::Uri; +use tracing::instrument; #[derive(Debug, Clone)] /// Text Generation Inference gRPC multi client @@ -35,6 +35,7 @@ impl ShardedClient { } /// Returns a client connected to the given uri + #[allow(dead_code)] pub async fn connect(uri: Uri) -> Result { let master_client = Client::connect(uri).await?; Self::from_master_client(master_client).await diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs index cd4b3b0a..0a9614aa 100644 --- a/backends/v3/src/lib.rs +++ b/backends/v3/src/lib.rs @@ -1,13 +1,13 @@ -mod block_allocator; -mod queue; mod backend; +mod block_allocator; mod client; +mod queue; +use crate::client::{ClientError, ShardedClient}; +pub(crate) use backend::BackendV3; use serde::Serialize; use thiserror::Error; use utoipa::ToSchema; -pub(crate) use backend::BackendV3; -use crate::client::{ShardedClient, ClientError}; #[derive(Clone, Debug, Serialize, ToSchema)] pub struct BackendInfo { @@ -31,7 +31,8 @@ pub struct BackendInfo { } pub async fn connect_backend( - max_input_tokens: usize, max_total_tokens: usize, + max_input_tokens: usize, + max_total_tokens: usize, master_shard_uds_path: String, waiting_served_ratio: f32, max_batch_prefill_tokens: u32, @@ -137,4 +138,4 @@ pub enum V3Error { Warmup(ClientError), #[error("Not enough memory to handle `max_total_tokens={0}`")] NotEnoughMemory(usize), -} \ No newline at end of file +} diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs index e6971ac8..a97d7a44 100644 --- a/backends/v3/src/main.rs +++ b/backends/v3/src/main.rs @@ -1,26 +1,7 @@ -use axum::http::HeaderValue; use clap::Parser; -use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo}; -use hf_hub::{Cache, Repo, RepoType}; -use opentelemetry::sdk::propagation::TraceContextPropagator; -use opentelemetry::sdk::trace; -use opentelemetry::sdk::trace::Sampler; -use opentelemetry::sdk::Resource; -use opentelemetry::{global, KeyValue}; -use opentelemetry_otlp::WithExportConfig; -use std::fs::File; -use std::io::BufReader; -use std::net::{IpAddr, Ipv4Addr, SocketAddr}; -use std::path::{Path, PathBuf}; -use text_generation_router::config::Config; -use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig}; -use thiserror::Error; -use tokenizers::Tokenizer; -use tower_http::cors::AllowOrigin; -use tracing_subscriber::layer::SubscriberExt; -use tracing_subscriber::util::SubscriberInitExt; -use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer}; +use text_generation_router::server; use text_generation_router_v3::{connect_backend, V3Error}; +use thiserror::Error; /// App Configuration #[derive(Parser, Debug)] @@ -121,7 +102,7 @@ async fn main() -> Result<(), RouterError> { } = args; // Launch Tokio runtime - init_logging(otlp_endpoint, otlp_service_name, json_output); + text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output); // Validate args if max_input_tokens >= max_total_tokens { @@ -148,218 +129,37 @@ async fn main() -> Result<(), RouterError> { } } - // CORS allowed origins - // map to go inside the option and then map to parse from String to HeaderValue - // Finally, convert to AllowOrigin - let cors_allow_origin: Option = cors_allow_origin.map(|cors_allow_origin| { - AllowOrigin::list( - cors_allow_origin - .iter() - .map(|origin| origin.parse::().unwrap()), - ) - }); - - // Parse Huggingface hub token - let authorization_token = std::env::var("HF_TOKEN") - .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")) - .ok(); - - // Tokenizer instance - // This will only be used to validate payloads - let local_path = Path::new(&tokenizer_name); - - // Shared API builder initialization - let api_builder = || { - let mut builder = ApiBuilder::new() - .with_progress(false) - .with_token(authorization_token); - - if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") { - builder = builder.with_cache_dir(cache_dir.into()); - } - - builder - }; - - // Decide if we need to use the API based on the revision and local path - let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir(); - - // Initialize API if needed - #[derive(Clone)] - enum Type { - Api(Api), - Cache(Cache), - None, - } - let api = if use_api { - if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) { - let cache = Cache::default(); - tracing::warn!("Offline mode active using cache defaults"); - Type::Cache(cache) - } else { - tracing::info!("Using the Hugging Face API"); - match api_builder().build() { - Ok(api) => Type::Api(api), - Err(_) => { - tracing::warn!("Unable to build the Hugging Face API"); - Type::None - } - } - } - } else { - Type::None - }; - - // Load tokenizer and model info - let ( - tokenizer_filename, - config_filename, - tokenizer_config_filename, - processor_config_filename, - model_info, - ) = match api { - Type::None => ( - Some(local_path.join("tokenizer.json")), - Some(local_path.join("config.json")), - Some(local_path.join("tokenizer_config.json")), - Some(local_path.join("processor_config.json")), - None, - ), - Type::Api(api) => { - let api_repo = api.repo(Repo::with_revision( - tokenizer_name.to_string(), - RepoType::Model, - revision.clone().unwrap_or_else(|| "main".to_string()), - )); - - let tokenizer_filename = match api_repo.get("tokenizer.json").await { - Ok(tokenizer_filename) => Some(tokenizer_filename), - Err(_) => get_base_tokenizer(&api, &api_repo).await, - }; - let config_filename = api_repo.get("config.json").await.ok(); - let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok(); - let processor_config_filename = api_repo.get("processor_config.json").await.ok(); - - let model_info = if let Some(model_info) = get_model_info(&api_repo).await { - Some(model_info) - } else { - tracing::warn!("Could not retrieve model info from the Hugging Face hub."); - None - }; - ( - tokenizer_filename, - config_filename, - tokenizer_config_filename, - processor_config_filename, - model_info, - ) - } - Type::Cache(cache) => { - let repo = cache.repo(Repo::with_revision( - tokenizer_name.to_string(), - RepoType::Model, - revision.clone().unwrap_or_else(|| "main".to_string()), - )); - ( - repo.get("tokenizer.json"), - repo.get("config.json"), - repo.get("tokenizer_config.json"), - repo.get("processor_config.json"), - None, - ) - } - }; - let tokenizer: Option = - tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok()); - let config: Option = config_filename.and_then(|filename| { - std::fs::read_to_string(filename) - .ok() - .as_ref() - .and_then(|c| { - let config: Result = serde_json::from_str(c); - if let Err(err) = &config { - tracing::warn!("Could not parse config {err:?}"); - } - config.ok() - }) - }); - let model_info = model_info.unwrap_or_else(|| HubModelInfo { - model_id: tokenizer_name.to_string(), - sha: None, - pipeline_tag: None, - }); - - // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'. - let tokenizer_config: Option = if let Some(filename) = tokenizer_config_path - { - HubTokenizerConfig::from_file(filename) - } else { - tokenizer_config_filename.and_then(HubTokenizerConfig::from_file) - }; - let tokenizer_config = tokenizer_config.unwrap_or_else(|| { - tracing::warn!("Could not find tokenizer config locally and no API specified"); - HubTokenizerConfig::default() - }); - - let processor_config = processor_config_filename - .and_then(HubProcessorConfig::from_file) - .unwrap_or_default(); - - tracing::info!("Using config {config:?}"); - if tokenizer.is_none() { - tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}"); - tracing::warn!("Rust input length validation and truncation is disabled"); - } - - // if pipeline-tag == text-generation we default to return_full_text = true - let compat_return_full_text = match &model_info.pipeline_tag { - None => { - tracing::warn!("no pipeline tag found for model {tokenizer_name}"); - true - } - Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation", - }; - - // Determine the server port based on the feature and environment variable. - let port = if cfg!(feature = "google") { - std::env::var("AIP_HTTP_PORT") - .map(|aip_http_port| aip_http_port.parse::().unwrap_or(port)) - .unwrap_or(port) - } else { - port - }; - - let addr = match hostname.parse() { - Ok(ip) => SocketAddr::new(ip, port), - Err(_) => { - tracing::warn!("Invalid hostname, defaulting to 0.0.0.0"); - SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port) - } - }; - - let (backend, backend_info) = connect_backend(max_input_tokens, max_total_tokens, master_shard_uds_path, waiting_served_ratio, max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, max_batch_size).await?; + let (backend, _backend_info) = connect_backend( + max_input_tokens, + max_total_tokens, + master_shard_uds_path, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + ) + .await?; // Run server server::run( backend, - model_info, - compat_return_full_text, max_concurrent_requests, max_best_of, max_stop_sequences, max_top_n_tokens, max_input_tokens, max_total_tokens, - tokenizer, - config, validation_workers, - addr, + tokenizer_name, + tokenizer_config_path, + revision, + hostname, + port, cors_allow_origin, ngrok, ngrok_authtoken, ngrok_edge, - tokenizer_config, - processor_config, messages_api_enabled, disable_grammar_support, max_client_batch_size, @@ -368,140 +168,6 @@ async fn main() -> Result<(), RouterError> { Ok(()) } -/// Init logging using env variables LOG_LEVEL and LOG_FORMAT: -/// - otlp_endpoint is an optional URL to an Open Telemetry collector -/// - otlp_service_name service name to appear in APM -/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO) -/// - LOG_FORMAT may be TEXT or JSON (default to TEXT) -/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms) -fn init_logging(otlp_endpoint: Option, otlp_service_name: String, json_output: bool) { - let mut layers = Vec::new(); - - // STDOUT/STDERR layer - let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string()); - let fmt_layer = tracing_subscriber::fmt::layer() - .with_file(true) - .with_ansi(ansi) - .with_line_number(true); - - let fmt_layer = match json_output { - true => fmt_layer.json().flatten_event(true).boxed(), - false => fmt_layer.boxed(), - }; - layers.push(fmt_layer); - - // OpenTelemetry tracing layer - if let Some(otlp_endpoint) = otlp_endpoint { - global::set_text_map_propagator(TraceContextPropagator::new()); - - let tracer = opentelemetry_otlp::new_pipeline() - .tracing() - .with_exporter( - opentelemetry_otlp::new_exporter() - .tonic() - .with_endpoint(otlp_endpoint), - ) - .with_trace_config( - trace::config() - .with_resource(Resource::new(vec![KeyValue::new( - "service.name", - otlp_service_name, - )])) - .with_sampler(Sampler::AlwaysOn), - ) - .install_batch(opentelemetry::runtime::Tokio); - - if let Ok(tracer) = tracer { - layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed()); - init_tracing_opentelemetry::init_propagator().unwrap(); - }; - } - - // Filter events with LOG_LEVEL - let varname = "LOG_LEVEL"; - let env_filter = if let Ok(log_level) = std::env::var(varname) { - // Override to avoid simple logs to be spammed with tokio level informations - let log_level = match &log_level[..] { - "warn" => "text_generation_launcher=warn,text_generation_router=warn", - "info" => "text_generation_launcher=info,text_generation_router=info", - "debug" => "text_generation_launcher=debug,text_generation_router=debug", - log_level => log_level, - }; - EnvFilter::builder() - .with_default_directive(LevelFilter::INFO.into()) - .parse_lossy(log_level) - } else { - EnvFilter::new("info") - }; - - tracing_subscriber::registry() - .with(env_filter) - .with(layers) - .init(); -} - -/// get model info from the Huggingface Hub -pub async fn get_model_info(api: &ApiRepo) -> Option { - let response = api.info_request().send().await.ok()?; - - if response.status().is_success() { - let hub_model_info: HubModelInfo = - serde_json::from_str(&response.text().await.ok()?).ok()?; - if let Some(sha) = &hub_model_info.sha { - tracing::info!( - "Serving revision {sha} of model {}", - hub_model_info.model_id - ); - } - Some(hub_model_info) - } else { - None - } -} - -/// get base tokenizer -pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option { - let config_filename = api_repo.get("config.json").await.ok()?; - - // Open the file in read-only mode with buffer. - let file = File::open(config_filename).ok()?; - let reader = BufReader::new(file); - - // Read the JSON contents of the file as an instance of `User`. - let config: serde_json::Value = serde_json::from_reader(reader).ok()?; - - if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") { - let api_base_repo = api.repo(Repo::with_revision( - base_model_id.to_string(), - RepoType::Model, - "main".to_string(), - )); - - api_base_repo.get("tokenizer.json").await.ok() - } else { - None - } -} - -/// get tokenizer_config from the Huggingface Hub -pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option { - let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?; - - // Open the file in read-only mode with buffer. - let file = File::open(tokenizer_config_filename).ok()?; - let reader = BufReader::new(file); - - // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'. - let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader) - .map_err(|e| { - tracing::warn!("Unable to parse tokenizer config: {}", e); - e - }) - .ok()?; - - Some(tokenizer_config) -} - #[derive(Debug, Error)] enum RouterError { #[error("Argument validation error: {0}")] diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index 3985cb85..65e31e20 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -1,14 +1,17 @@ use crate::block_allocator::{BlockAllocation, BlockAllocator}; -use text_generation_router::infer::InferError; -use text_generation_router::infer::InferStreamResponse; -use text_generation_router::validation::{Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters}; -use nohash_hasher::{BuildNoHashHasher, IntMap}; -use std::cmp::{max, min}; -use std::collections::VecDeque; +use crate::client; use crate::client::{ Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; -use crate::client as client; +use nohash_hasher::{BuildNoHashHasher, IntMap}; +use std::cmp::{max, min}; +use std::collections::VecDeque; +use text_generation_router::infer::InferError; +use text_generation_router::infer::InferStreamResponse; +use text_generation_router::validation::{ + Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, + ValidStoppingParameters, +}; use tokio::sync::{mpsc, oneshot}; use tokio::time::Instant; use tracing::{info_span, instrument, Instrument, Span}; @@ -335,10 +338,21 @@ impl State { id, prefill_logprobs: entry.request.decoder_input_details, input_chunks: Some(client::Input { - chunks: entry.request.inputs.clone().into_iter().map(|c| client::InputChunk { chunk: Some(match c { - Chunk::Text(text) => client::Chunk::Text(text), - Chunk::Image(image) => client::Chunk::Image(client::Image { data: image.data, mimetype: image.mimetype }) - })}).collect() + chunks: entry + .request + .inputs + .clone() + .into_iter() + .map(|c| client::InputChunk { + chunk: Some(match c { + Chunk::Text(text) => client::Chunk::Text(text), + Chunk::Image(image) => client::Chunk::Image(client::Image { + data: image.data, + mimetype: image.mimetype, + }), + }), + }) + .collect(), }), inputs: entry.request.inputs.chunks_to_string(), truncate: entry.request.truncate, diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index c1b65158..5ae833ed 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -74,7 +74,6 @@ impl ChatTemplate { } } - // tests #[cfg(test)] mod tests { diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 4cef5def..af357ade 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -3,23 +3,23 @@ mod chat_template; pub mod tool_grammar; use crate::validation::{ValidGenerateRequest, Validation, ValidationError}; +use crate::GrammarType; use crate::{ - ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, - HubTokenizerConfig, Message, PrefillToken, Token, + ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig, + Message, PrefillToken, Token, }; -use crate::{GrammarType}; +use async_trait::async_trait; +use chat_template::ChatTemplate; use futures::future::try_join_all; -use minijinja::{ErrorKind}; -use std::sync::Arc; +use minijinja::ErrorKind; use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; use thiserror::Error; use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError}; use tokio::time::Instant; use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::StreamExt; use tracing::instrument; -use chat_template::ChatTemplate; -use async_trait::async_trait; #[async_trait] pub trait Backend { @@ -86,7 +86,8 @@ impl Infer { #[instrument(skip_all)] pub(crate) async fn generate_stream<'a>( &'a self, - request: GenerateRequest) -> Result { + request: GenerateRequest, + ) -> Result { // Limit concurrent requests by acquiring a permit from the semaphore let permit = self .clone() @@ -106,9 +107,7 @@ impl Infer { })?; let input_length = valid_request.input_length; - let generation_stream = self - .backend - .schedule(valid_request)?; + let generation_stream = self.backend.schedule(valid_request)?; Ok((permit, input_length, generation_stream)) } diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs index 6123c7ac..89356582 100644 --- a/router/src/infer/v2/scheduler.rs +++ b/router/src/infer/v2/scheduler.rs @@ -1,7 +1,7 @@ /// Batching and inference logic use crate::infer::v2::queue::{Entry, Queue}; use crate::infer::{ - GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Backend, + Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, }; use crate::validation::ValidGenerateRequest; use crate::{FinishReason, PrefillToken, Token}; diff --git a/router/src/lib.rs b/router/src/lib.rs index d62f5257..2e8c6262 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -6,6 +6,7 @@ pub mod validation; #[cfg(feature = "kserve")] mod kserve; +pub mod logging; use serde::{Deserialize, Serialize}; use tracing::warn; @@ -158,7 +159,6 @@ pub struct Info { #[schema(example = "32")] pub max_client_batch_size: usize, - /// Router Info #[schema(example = "text-generation-router")] pub router: &'static str, diff --git a/router/src/logging.rs b/router/src/logging.rs new file mode 100644 index 00000000..5a98ef57 --- /dev/null +++ b/router/src/logging.rs @@ -0,0 +1,81 @@ +use opentelemetry::sdk::propagation::TraceContextPropagator; +use opentelemetry::sdk::trace; +use opentelemetry::sdk::trace::Sampler; +use opentelemetry::sdk::Resource; +use opentelemetry::{global, KeyValue}; +use opentelemetry_otlp::WithExportConfig; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer}; + +/// Init logging using env variables LOG_LEVEL and LOG_FORMAT: +/// - otlp_endpoint is an optional URL to an Open Telemetry collector +/// - otlp_service_name service name to appear in APM +/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO) +/// - LOG_FORMAT may be TEXT or JSON (default to TEXT) +/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms) +pub fn init_logging(otlp_endpoint: Option, otlp_service_name: String, json_output: bool) { + let mut layers = Vec::new(); + + // STDOUT/STDERR layer + let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string()); + let fmt_layer = tracing_subscriber::fmt::layer() + .with_file(true) + .with_ansi(ansi) + .with_line_number(true); + + let fmt_layer = match json_output { + true => fmt_layer.json().flatten_event(true).boxed(), + false => fmt_layer.boxed(), + }; + layers.push(fmt_layer); + + // OpenTelemetry tracing layer + if let Some(otlp_endpoint) = otlp_endpoint { + global::set_text_map_propagator(TraceContextPropagator::new()); + + let tracer = opentelemetry_otlp::new_pipeline() + .tracing() + .with_exporter( + opentelemetry_otlp::new_exporter() + .tonic() + .with_endpoint(otlp_endpoint), + ) + .with_trace_config( + trace::config() + .with_resource(Resource::new(vec![KeyValue::new( + "service.name", + otlp_service_name, + )])) + .with_sampler(Sampler::AlwaysOn), + ) + .install_batch(opentelemetry::runtime::Tokio); + + if let Ok(tracer) = tracer { + layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed()); + init_tracing_opentelemetry::init_propagator().unwrap(); + }; + } + + // Filter events with LOG_LEVEL + let varname = "LOG_LEVEL"; + let env_filter = if let Ok(log_level) = std::env::var(varname) { + // Override to avoid simple logs to be spammed with tokio level informations + let log_level = match &log_level[..] { + "warn" => "text_generation_launcher=warn,text_generation_router=warn", + "info" => "text_generation_launcher=info,text_generation_router=info", + "debug" => "text_generation_launcher=debug,text_generation_router=debug", + log_level => log_level, + }; + EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .parse_lossy(log_level) + } else { + EnvFilter::new("info") + }; + + tracing_subscriber::registry() + .with(env_filter) + .with(layers) + .init(); +} diff --git a/router/src/server.rs b/router/src/server.rs index a7cfd74e..382b35cf 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,7 +1,7 @@ /// HTTP Server logic use crate::config::Config; -use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, Backend}; use crate::infer::tool_grammar::ToolGrammar; +use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse}; #[cfg(feature = "kserve")] use crate::kserve::{ kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer, @@ -23,7 +23,7 @@ use crate::{ use crate::{FunctionDefinition, ToolCall, ToolType}; use async_stream::__private::AsyncStream; use axum::extract::Extension; -use axum::http::{HeaderMap, Method, StatusCode}; +use axum::http::{HeaderMap, HeaderValue, Method, StatusCode}; use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::{IntoResponse, Response}; use axum::routing::{get, post}; @@ -33,10 +33,15 @@ use futures::stream::StreamExt; use futures::stream::{FuturesOrdered, FuturesUnordered}; use futures::Stream; use futures::TryStreamExt; +use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo}; +use hf_hub::{Cache, Repo, RepoType}; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use serde_json::Value; use std::convert::Infallible; -use std::net::SocketAddr; +use std::fs::File; +use std::io::BufReader; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::path::{Path, PathBuf}; use thiserror::Error; use tokenizers::Tokenizer; use tokio::select; @@ -1395,24 +1400,22 @@ pub(crate) struct ComputeType(String); #[allow(clippy::too_many_arguments)] pub async fn run( backend: impl Backend + Send + Sync + 'static, - model_info: HubModelInfo, - compat_return_full_text: bool, max_concurrent_requests: usize, max_best_of: usize, max_stop_sequences: usize, max_top_n_tokens: u32, max_input_tokens: usize, max_total_tokens: usize, - tokenizer: Option, - config: Option, validation_workers: usize, - addr: SocketAddr, - allow_origin: Option, + tokenizer_name: String, + tokenizer_config_path: Option, + revision: Option, + hostname: String, + port: u16, + cors_allow_origin: Option>, ngrok: bool, _ngrok_authtoken: Option, _ngrok_edge: Option, - tokenizer_config: HubTokenizerConfig, - processor_config: HubProcessorConfig, messages_api_enabled: bool, grammar_support: bool, max_client_batch_size: usize, @@ -1485,6 +1488,195 @@ pub async fn run( )] struct ApiDoc; + // CORS allowed origins + // map to go inside the option and then map to parse from String to HeaderValue + // Finally, convert to AllowOrigin + let allow_origin: Option = cors_allow_origin.map(|cors_allow_origin| { + AllowOrigin::list( + cors_allow_origin + .iter() + .map(|origin| origin.parse::().unwrap()), + ) + }); + + // Parse Huggingface hub token + let authorization_token = std::env::var("HF_TOKEN") + .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")) + .ok(); + + // Tokenizer instance + // This will only be used to validate payloads + let local_path = Path::new(&tokenizer_name); + + // Shared API builder initialization + let api_builder = || { + let mut builder = ApiBuilder::new() + .with_progress(false) + .with_token(authorization_token); + + if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") { + builder = builder.with_cache_dir(cache_dir.into()); + } + + builder + }; + + // Decide if we need to use the API based on the revision and local path + let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir(); + + // Initialize API if needed + #[derive(Clone)] + enum Type { + Api(Api), + Cache(Cache), + None, + } + let api = if use_api { + if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) { + let cache = Cache::default(); + tracing::warn!("Offline mode active using cache defaults"); + Type::Cache(cache) + } else { + tracing::info!("Using the Hugging Face API"); + match api_builder().build() { + Ok(api) => Type::Api(api), + Err(_) => { + tracing::warn!("Unable to build the Hugging Face API"); + Type::None + } + } + } + } else { + Type::None + }; + + // Load tokenizer and model info + let ( + tokenizer_filename, + config_filename, + tokenizer_config_filename, + processor_config_filename, + model_info, + ) = match api { + Type::None => ( + Some(local_path.join("tokenizer.json")), + Some(local_path.join("config.json")), + Some(local_path.join("tokenizer_config.json")), + Some(local_path.join("processor_config.json")), + None, + ), + Type::Api(api) => { + let api_repo = api.repo(Repo::with_revision( + tokenizer_name.to_string(), + RepoType::Model, + revision.clone().unwrap_or_else(|| "main".to_string()), + )); + + let tokenizer_filename = match api_repo.get("tokenizer.json").await { + Ok(tokenizer_filename) => Some(tokenizer_filename), + Err(_) => get_base_tokenizer(&api, &api_repo).await, + }; + let config_filename = api_repo.get("config.json").await.ok(); + let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok(); + let processor_config_filename = api_repo.get("processor_config.json").await.ok(); + + let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await { + Some(model_info) + } else { + tracing::warn!("Could not retrieve model info from the Hugging Face hub."); + None + }; + ( + tokenizer_filename, + config_filename, + tokenizer_config_filename, + processor_config_filename, + model_info, + ) + } + Type::Cache(cache) => { + let repo = cache.repo(Repo::with_revision( + tokenizer_name.to_string(), + RepoType::Model, + revision.clone().unwrap_or_else(|| "main".to_string()), + )); + ( + repo.get("tokenizer.json"), + repo.get("config.json"), + repo.get("tokenizer_config.json"), + repo.get("processor_config.json"), + None, + ) + } + }; + let tokenizer: Option = + tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok()); + let config: Option = config_filename.and_then(|filename| { + std::fs::read_to_string(filename) + .ok() + .as_ref() + .and_then(|c| { + let config: Result = serde_json::from_str(c); + if let Err(err) = &config { + tracing::warn!("Could not parse config {err:?}"); + } + config.ok() + }) + }); + let model_info = model_info.unwrap_or_else(|| HubModelInfo { + model_id: tokenizer_name.to_string(), + sha: None, + pipeline_tag: None, + }); + + // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'. + let tokenizer_config: Option = if let Some(filename) = tokenizer_config_path + { + HubTokenizerConfig::from_file(filename) + } else { + tokenizer_config_filename.and_then(HubTokenizerConfig::from_file) + }; + let tokenizer_config = tokenizer_config.unwrap_or_else(|| { + tracing::warn!("Could not find tokenizer config locally and no API specified"); + HubTokenizerConfig::default() + }); + + let processor_config = processor_config_filename + .and_then(HubProcessorConfig::from_file) + .unwrap_or_default(); + + tracing::info!("Using config {config:?}"); + if tokenizer.is_none() { + tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}"); + tracing::warn!("Rust input length validation and truncation is disabled"); + } + + // if pipeline-tag == text-generation we default to return_full_text = true + let compat_return_full_text = match &model_info.pipeline_tag { + None => { + tracing::warn!("no pipeline tag found for model {tokenizer_name}"); + true + } + Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation", + }; + + // Determine the server port based on the feature and environment variable. + let port = if cfg!(feature = "google") { + std::env::var("AIP_HTTP_PORT") + .map(|aip_http_port| aip_http_port.parse::().unwrap_or(port)) + .unwrap_or(port) + } else { + port + }; + + let addr = match hostname.parse() { + Ok(ip) => SocketAddr::new(ip, port), + Err(_) => { + tracing::warn!("Invalid hostname, defaulting to 0.0.0.0"); + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port) + } + }; + // Create state let validation = Validation::new( validation_workers, @@ -1551,8 +1743,8 @@ pub async fn run( .unwrap() .set_buckets_for_metric(batch_size_matcher, &batch_size_buckets) .unwrap(); - // .set_buckets_for_metric(skipped_matcher, &skipped_buckets) - // .unwrap(); + // .set_buckets_for_metric(skipped_matcher, &skipped_buckets) + // .unwrap(); let prom_handle = builder .install_recorder() .expect("failed to install metrics recorder"); @@ -1750,6 +1942,68 @@ pub async fn run( Ok(()) } +/// get model info from the Huggingface Hub +pub async fn get_hub_model_info(api: &ApiRepo) -> Option { + let response = api.info_request().send().await.ok()?; + + if response.status().is_success() { + let hub_model_info: HubModelInfo = + serde_json::from_str(&response.text().await.ok()?).ok()?; + if let Some(sha) = &hub_model_info.sha { + tracing::info!( + "Serving revision {sha} of model {}", + hub_model_info.model_id + ); + } + Some(hub_model_info) + } else { + None + } +} + +/// get base tokenizer +pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option { + let config_filename = api_repo.get("config.json").await.ok()?; + + // Open the file in read-only mode with buffer. + let file = File::open(config_filename).ok()?; + let reader = BufReader::new(file); + + // Read the JSON contents of the file as an instance of `User`. + let config: serde_json::Value = serde_json::from_reader(reader).ok()?; + + if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") { + let api_base_repo = api.repo(Repo::with_revision( + base_model_id.to_string(), + RepoType::Model, + "main".to_string(), + )); + + api_base_repo.get("tokenizer.json").await.ok() + } else { + None + } +} + +/// get tokenizer_config from the Huggingface Hub +pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option { + let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?; + + // Open the file in read-only mode with buffer. + let file = File::open(tokenizer_config_filename).ok()?; + let reader = BufReader::new(file); + + // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'. + let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader) + .map_err(|e| { + tracing::warn!("Unable to parse tokenizer config: {}", e); + e + }) + .ok()?; + + Some(tokenizer_config) +} + /// Shutdown signal handler async fn shutdown_signal() { let ctrl_c = async { diff --git a/router/src/validation.rs b/router/src/validation.rs index cf9e107c..2a00b08e 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -638,7 +638,7 @@ pub struct Image { #[derive(Debug, Clone, Eq, PartialEq)] pub enum Chunk { Text(String), - Image(Image) + Image(Image), } /// Convert input chunks to a stringly-typed input for backwards