diff --git a/router/src/lib.rs b/router/src/lib.rs index 5e207a03..ecd8e2e0 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -73,9 +73,9 @@ pub struct HubTokenizerConfig { } impl HubTokenizerConfig { - pub fn from_file(filename: &std::path::Path) -> Self { - let content = std::fs::read_to_string(filename).unwrap(); - serde_json::from_str(&content).unwrap_or_default() + pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> { + let content = std::fs::read_to_string(filename).ok()?; + serde_json::from_str(&content).ok() } } diff --git a/router/src/main.rs b/router/src/main.rs index c7e3f90b..63347b78 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -1,7 +1,7 @@ use axum::http::HeaderValue; use clap::Parser; use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo}; -use hf_hub::{Repo, RepoType}; +use hf_hub::{Cache, Repo, RepoType}; use opentelemetry::sdk::propagation::TraceContextPropagator; use opentelemetry::sdk::trace; use opentelemetry::sdk::trace::Sampler; @@ -11,7 +11,7 @@ use opentelemetry_otlp::WithExportConfig; use std::fs::File; use std::io::BufReader; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; -use std::path::Path; +use std::path::{Path, PathBuf}; use text_generation_client::{ClientError, ShardedClient}; use text_generation_router::config::Config; use text_generation_router::{server, HubModelInfo, HubTokenizerConfig}; @@ -162,7 +162,6 @@ async fn main() -> Result<(), RouterError> { // Tokenizer instance // This will only be used to validate payloads let local_path = Path::new(&tokenizer_name); - let local_model = local_path.exists() && local_path.is_dir(); // Shared API builder initialization let api_builder = || { @@ -181,112 +180,113 @@ async fn main() -> Result<(), RouterError> { let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir(); // Initialize API if needed + #[derive(Clone)] + enum Type { + Api(Api), + Cache(Cache), + None, + } let api = if use_api { - tracing::info!("Using the Hugging Face API"); - match 
api_builder().build() { - Ok(api) => Some(api), - Err(_) => { - tracing::warn!("Unable to build the Hugging Face API"); - None + if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) { + let cache = Cache::default(); + tracing::warn!("Offline mode active using cache defaults"); + Type::Cache(cache) + } else { + tracing::info!("Using the Hugging Face API"); + match api_builder().build() { + Ok(api) => Type::Api(api), + Err(_) => { + tracing::warn!("Unable to build the Hugging Face API"); + Type::None + } } } } else { - None + Type::None }; // Load tokenizer and model info - let (tokenizer, model_info, config) = if local_model { - let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok(); - let model_info = HubModelInfo { - model_id: tokenizer_name.to_string(), - sha: None, - pipeline_tag: None, - }; - let config: Option<Config> = std::fs::read_to_string(local_path.join("config.json")) - .ok() - .as_ref() - .and_then(|c| serde_json::from_str(c).ok()); + let (tokenizer_filename, config_filename, tokenizer_config_filename, model_info) = match api { + Type::None => ( + Some(local_path.join("tokenizer.json")), + Some(local_path.join("config.json")), + Some(local_path.join("tokenizer_config.json")), + None, + ), + Type::Api(api) => { + let api_repo = api.repo(Repo::with_revision( + tokenizer_name.to_string(), + RepoType::Model, + revision.clone().unwrap_or_else(|| "main".to_string()), + )); - (tokenizer, model_info, config) - } else if let Some(api) = api.clone() { - let api_repo = api.repo(Repo::with_revision( - tokenizer_name.to_string(), - RepoType::Model, - revision.clone().unwrap_or_else(|| "main".to_string()), - )); + let tokenizer_filename = match api_repo.get("tokenizer.json").await { + Ok(tokenizer_filename) => Some(tokenizer_filename), + Err(_) => get_base_tokenizer(&api, &api_repo).await, + }; + let config_filename = api_repo.get("config.json").await.ok(); + let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok(); - let 
tokenizer = match api_repo.get("tokenizer.json").await { - Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(), - Err(_) => get_base_tokenizer(&api, &api_repo).await, - }; - - let config: Option<Config> = api_repo.get("config.json").await.ok().and_then(|filename| { - std::fs::read_to_string(filename) - .ok() - .as_ref() - .and_then(|c| { - let config: Result<Config, _> = serde_json::from_str(c); - if let Err(err) = &config { - tracing::warn!("Could not parse config {err:?}"); - } - config.ok() - }) - }); - - let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| { - tracing::warn!("Could not retrieve model info from the Hugging Face hub."); - HubModelInfo { - model_id: tokenizer_name.to_string(), - sha: None, - pipeline_tag: None, - } - }); - - (tokenizer, model_info, config) - } else { - // No API and no local model - return Err(RouterError::ArgumentValidation( - "No local model found and no revision specified".to_string(), - )); - }; - - tracing::info!("Using config {config:?}"); - - // Load tokenizer config if found locally, or check if we can get it from the API if needed - let tokenizer_config = if let Some(path) = tokenizer_config_path { - tracing::info!( - "Using local tokenizer config from user specified path {}", - path - ); - HubTokenizerConfig::from_file(&std::path::PathBuf::from(path)) - } else if local_model { - tracing::info!("Using local tokenizer config"); - HubTokenizerConfig::from_file(&local_path.join("tokenizer_config.json")) - } else { - match api { - Some(api) => { - tracing::info!("Using the Hugging Face API to retrieve tokenizer config"); - let repo = Repo::with_revision( - tokenizer_name.to_string(), - RepoType::Model, - revision.unwrap_or("main".to_string()), - ); - get_tokenizer_config(&api.repo(repo)) - .await - .unwrap_or_else(|| { - tracing::warn!( - "Could not retrieve tokenizer config from the Hugging Face hub." 
- ); - HubTokenizerConfig::default() - }) - } - None => { - tracing::warn!("Could not find tokenizer config locally and no API specified"); - HubTokenizerConfig::default() - } + let model_info = if let Some(model_info) = get_model_info(&api_repo).await { + Some(model_info) + } else { + tracing::warn!("Could not retrieve model info from the Hugging Face hub."); + None + }; + ( + tokenizer_filename, + config_filename, + tokenizer_config_filename, + model_info, + ) + } + Type::Cache(cache) => { + let repo = cache.repo(Repo::with_revision( + tokenizer_name.to_string(), + RepoType::Model, + revision.clone().unwrap_or_else(|| "main".to_string()), + )); + ( + repo.get("tokenizer.json"), + repo.get("config.json"), + repo.get("tokenizer_config.json"), + None, + ) + } }; + let tokenizer: Option<Tokenizer> = + tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok()); + let config: Option<Config> = config_filename.and_then(|filename| { + std::fs::read_to_string(filename) + .ok() + .as_ref() + .and_then(|c| { + let config: Result<Config, _> = serde_json::from_str(c); + if let Err(err) = &config { + tracing::warn!("Could not parse config {err:?}"); + } + config.ok() + }) + }); + let model_info = model_info.unwrap_or_else(|| HubModelInfo { + model_id: tokenizer_name.to_string(), + sha: None, + pipeline_tag: None, + }); + // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'. 
+ let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path + { + HubTokenizerConfig::from_file(filename) + } else { + tokenizer_config_filename.and_then(HubTokenizerConfig::from_file) + }; + let tokenizer_config = tokenizer_config.unwrap_or_else(|| { + tracing::warn!("Could not find tokenizer config locally and no API specified"); + HubTokenizerConfig::default() + }); + + tracing::info!("Using config {config:?}"); if tokenizer.is_none() { tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}"); tracing::warn!("Rust input length validation and truncation is disabled"); @@ -483,7 +483,7 @@ pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> { } /// get base tokenizer -pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokenizer> { +pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> { let config_filename = api_repo.get("config.json").await.ok()?; // Open the file in read-only mode with buffer. @@ -500,8 +500,7 @@ pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokenizer>