diff --git a/Cargo.lock b/Cargo.lock index 27c345b3..f826ea34 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -950,13 +950,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732" dependencies = [ "dirs 5.0.1", + "futures", "indicatif", "log", "native-tls", + "num_cpus", "rand", + "reqwest", "serde", "serde_json", "thiserror", + "tokio", "ureq", ] diff --git a/router/Cargo.toml b/router/Cargo.toml index 55af635a..5ccdb0cd 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -21,6 +21,7 @@ axum-tracing-opentelemetry = "0.14.1" text-generation-client = { path = "client" } clap = { version = "4.4.5", features = ["derive", "env"] } futures = "0.3.28" +hf-hub = { version = "0.3.0", features = ["tokio"] } metrics = "0.21.1" metrics-exporter-prometheus = { version = "0.12.1", features = [] } nohash-hasher = "0.2.0" @@ -41,7 +42,6 @@ tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } utoipa = { version = "3.5.0", features = ["axum_extras"] } utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] } ngrok = { version = "0.13.1", features = ["axum"], optional = true } -hf-hub = "0.3.1" init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } [build-dependencies] diff --git a/router/src/main.rs b/router/src/main.rs index d90632ef..4637c77c 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -1,19 +1,22 @@ -/// Text Generation Inference webserver entrypoint use axum::http::HeaderValue; use clap::Parser; +use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo}; +use hf_hub::{Repo, RepoType}; use opentelemetry::sdk::propagation::TraceContextPropagator; use opentelemetry::sdk::trace; use opentelemetry::sdk::trace::Sampler; use opentelemetry::sdk::Resource; use opentelemetry::{global, KeyValue}; use opentelemetry_otlp::WithExportConfig; +/// Text Generation Inference webserver entrypoint +use std::fs::File; +use std::io::BufReader; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::Path; -use std::time::Duration; use text_generation_client::{ClientError, ShardedClient}; use text_generation_router::{server, HubModelInfo}; use thiserror::Error; -use tokenizers::{FromPretrainedParameters, Tokenizer}; +use tokenizers::Tokenizer; use tower_http::cors::AllowOrigin; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; @@ -69,7 +72,8 @@ struct Args { ngrok_edge: Option, } -fn main() -> Result<(), RouterError> { +#[tokio::main] +async fn main() -> Result<(), RouterError> { // Get args let args = Args::parse(); // Pattern match configuration @@ -98,6 +102,9 @@ fn main() -> Result<(), RouterError> { ngrok_edge, } = args; + // Launch Tokio runtime + init_logging(otlp_endpoint, json_output); + // Validate args if max_input_length >= max_total_tokens { return Err(RouterError::ArgumentValidation( @@ -141,146 +148,158 @@ fn main() -> Result<(), RouterError> { // This will only be used to validate payloads let local_path = Path::new(&tokenizer_name); let local_model = local_path.exists() && local_path.is_dir(); - let tokenizer = if local_model { - // Load local tokenizer - Tokenizer::from_file(local_path.join("tokenizer.json")).ok() - } else { - // Download and instantiate tokenizer - // We need to download it outside of the Tokio runtime - let params = FromPretrainedParameters { - revision: revision.clone().unwrap_or("main".to_string()), - auth_token: authorization_token.clone(), - ..Default::default() + + let (tokenizer, model_info) = if local_model { + // Get Model info + let model_info = HubModelInfo { + model_id: tokenizer_name.clone(), + sha: None, + pipeline_tag: None, }; - Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok() + + // Load local tokenizer + let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok(); + + (tokenizer, model_info) + } else { + let mut builder = ApiBuilder::new() + .with_progress(false) + .with_token(authorization_token); + + if let Some(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE").ok() { + builder = builder.with_cache_dir(cache_dir.into()); + } + + if revision.is_none() { + tracing::warn!("`--revision` is not set"); + tracing::warn!("We strongly advise to set it to a known supported commit."); + } + + let api = builder.build().unwrap(); + let api_repo = api.repo(Repo::with_revision( + tokenizer_name.clone(), + RepoType::Model, + revision.clone().unwrap_or("main".to_string()), + )); + + // Get Model info + let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| { + tracing::warn!("Could not retrieve model info from the Hugging Face hub."); + HubModelInfo { + model_id: tokenizer_name.to_string(), + sha: None, + pipeline_tag: None, + } + }); + + let tokenizer = match api_repo.get("tokenizer.json").await { + Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(), + Err(_) => get_base_tokenizer(&api, &api_repo).await, + }; + + (tokenizer, model_info) }; - // Launch Tokio runtime - tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build()? - .block_on(async { - init_logging(otlp_endpoint, json_output); + if tokenizer.is_none() { + tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}"); + tracing::warn!("Rust input length validation and truncation is disabled"); + } - if tokenizer.is_none() { + // if pipeline-tag == text-generation we default to return_full_text = true + let compat_return_full_text = match &model_info.pipeline_tag { + None => { + tracing::warn!("no pipeline tag found for model {tokenizer_name}"); + false + } + Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation", + }; + + // Instantiate sharded client from the master unix socket + let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) + .await + .map_err(RouterError::Connection)?; + // Clear the cache; useful if the webserver rebooted + sharded_client + .clear_cache(None) + .await + .map_err(RouterError::Cache)?; + // Get info from the shard + let shard_info = sharded_client.info().await.map_err(RouterError::Info)?; + + // Warmup model + tracing::info!("Warming up model"); + let max_supported_batch_total_tokens = match sharded_client + .warmup( + max_input_length as u32, + max_batch_prefill_tokens, + max_total_tokens as u32, + ) + .await + .map_err(RouterError::Warmup)? + { + // Older models do not support automatic max-batch-total-tokens + None => { + let max_batch_total_tokens = max_batch_total_tokens + .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))); + tracing::warn!("Model does not support automatic max batch total tokens"); + max_batch_total_tokens + } + // Flash attention models return their max supported total tokens + Some(max_supported_batch_total_tokens) => { + // Warn if user added his own max-batch-total-tokens as we will ignore it + if max_batch_total_tokens.is_some() { tracing::warn!( - "Could not find a fast tokenizer implementation for {tokenizer_name}" + "`--max-batch-total-tokens` is deprecated for Flash \ + Attention models." ); - tracing::warn!("Rust input length validation and truncation is disabled"); + tracing::warn!( + "Inferred max batch total tokens: {max_supported_batch_total_tokens}" + ); + } + if max_total_tokens as u32 > max_supported_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_supported_batch_total_tokens}"))); } - // Get Model info - let model_info = match local_model { - true => HubModelInfo { - model_id: tokenizer_name.clone(), - sha: None, - pipeline_tag: None, - }, - false => get_model_info(&tokenizer_name, revision, authorization_token) - .await - .unwrap_or_else(|| { - tracing::warn!("Could not retrieve model info from the Hugging Face hub."); - HubModelInfo { - model_id: tokenizer_name.to_string(), - sha: None, - pipeline_tag: None, - } - }), - }; + max_supported_batch_total_tokens + } + }; + tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}"); + tracing::info!("Connected"); - // if pipeline-tag == text-generation we default to return_full_text = true - let compat_return_full_text = match &model_info.pipeline_tag { - None => { - tracing::warn!("no pipeline tag found for model {tokenizer_name}"); - false - } - Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation", - }; + let addr = match hostname.parse() { + Ok(ip) => SocketAddr::new(ip, port), + Err(_) => { + tracing::warn!("Invalid hostname, defaulting to 0.0.0.0"); + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port) + } + }; - // Instantiate sharded client from the master unix socket - let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) - .await - .map_err(RouterError::Connection)?; - // Clear the cache; useful if the webserver rebooted - sharded_client - .clear_cache(None) - .await - .map_err(RouterError::Cache)?; - // Get info from the shard - let shard_info = sharded_client.info().await.map_err(RouterError::Info)?; - - // Warmup model - tracing::info!("Warming up model"); - let max_supported_batch_total_tokens = match sharded_client - .warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32) - .await - .map_err(RouterError::Warmup)? - { - // Older models do not support automatic max-batch-total-tokens - None => { - let max_batch_total_tokens = max_batch_total_tokens.unwrap_or( - 16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)), - ); - tracing::warn!("Model does not support automatic max batch total tokens"); - max_batch_total_tokens - } - // Flash attention models return their max supported total tokens - Some(max_supported_batch_total_tokens) => { - // Warn if user added his own max-batch-total-tokens as we will ignore it - if max_batch_total_tokens.is_some() { - tracing::warn!( - "`--max-batch-total-tokens` is deprecated for Flash \ - Attention models." - ); - tracing::warn!( - "Inferred max batch total tokens: {max_supported_batch_total_tokens}" - ); - } - if max_total_tokens as u32 > max_supported_batch_total_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_supported_batch_total_tokens}"))); - } - - max_supported_batch_total_tokens - } - }; - tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}"); - tracing::info!("Connected"); - - let addr = match hostname.parse() { - Ok(ip) => SocketAddr::new(ip, port), - Err(_) => { - tracing::warn!("Invalid hostname, defaulting to 0.0.0.0"); - SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port) - } - }; - - // Run server - server::run( - model_info, - shard_info, - compat_return_full_text, - max_concurrent_requests, - max_best_of, - max_stop_sequences, - max_top_n_tokens, - max_input_length, - max_total_tokens, - waiting_served_ratio, - max_batch_prefill_tokens, - max_supported_batch_total_tokens, - max_waiting_tokens, - sharded_client, - tokenizer, - validation_workers, - addr, - cors_allow_origin, - ngrok, - ngrok_authtoken, - ngrok_edge, - ) - .await?; - Ok(()) - }) + // Run server + server::run( + model_info, + shard_info, + compat_return_full_text, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_length, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_supported_batch_total_tokens, + max_waiting_tokens, + sharded_client, + tokenizer, + validation_workers, + addr, + cors_allow_origin, + ngrok, + ngrok_authtoken, + ngrok_edge, + ) + .await?; + Ok(()) } /// Init logging using env variables LOG_LEVEL and LOG_FORMAT: @@ -339,30 +358,8 @@ fn init_logging(otlp_endpoint: Option, json_output: bool) { } /// get model info from the Huggingface Hub -pub async fn get_model_info( - model_id: &str, - revision: Option, - token: Option, -) -> Option { - let revision = match revision { - None => { - tracing::warn!("`--revision` is not set"); - tracing::warn!("We strongly advise to set it to a known supported commit."); - "main".to_string() - } - Some(revision) => revision, - }; - - let client = reqwest::Client::new(); - // Poor man's urlencode - let revision = revision.replace('/', "%2F"); - let url = format!("https://huggingface.co/api/models/{model_id}/revision/{revision}"); - let mut builder = client.get(url).timeout(Duration::from_secs(5)); - if let Some(token) = token { - builder = builder.bearer_auth(token); - } - - let response = builder.send().await.ok()?; +pub async fn get_model_info(api: &ApiRepo) -> Option { + let response = api.info_request().send().await.ok()?; if response.status().is_success() { let hub_model_info: HubModelInfo = @@ -379,6 +376,31 @@ pub async fn get_model_info( } } +/// get base tokenizer +pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option { + let config_filename = api_repo.get("config.json").await.ok()?; + + // Open the file in read-only mode with buffer. + let file = File::open(config_filename).ok()?; + let reader = BufReader::new(file); + + // Read the JSON contents of the file as an instance of `User`. + let config: serde_json::Value = serde_json::from_reader(reader).ok()?; + + if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") { + let api_base_repo = api.repo(Repo::with_revision( + base_model_id.to_string(), + RepoType::Model, + "main".to_string(), + )); + + let tokenizer_filename = api_base_repo.get("tokenizer.json").await.ok()?; + Tokenizer::from_file(tokenizer_filename).ok() + } else { + None + } +} + #[derive(Debug, Error)] enum RouterError { #[error("Argument validation error: {0}")]