diff --git a/router/src/lib.rs b/router/src/lib.rs index fc5670a0..07360e78 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -37,7 +37,7 @@ pub struct HubTokenizerConfig { } impl HubTokenizerConfig { - pub fn from_file(filename: &str) -> Self { + pub fn from_file(filename: &std::path::Path) -> Self { let content = std::fs::read_to_string(filename).unwrap(); serde_json::from_str(&content).unwrap_or_default() } diff --git a/router/src/main.rs b/router/src/main.rs index 495fd5bc..2a080468 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -154,12 +154,6 @@ async fn main() -> Result<(), RouterError> { let local_path = Path::new(&tokenizer_name); let local_model = local_path.exists() && local_path.is_dir(); - // Load tokenizer config - // This will be used to format the chat template - let local_tokenizer_config_path = - tokenizer_config_path.unwrap_or("tokenizer_config.json".to_string()); - let local_tokenizer_config = Path::new(&local_tokenizer_config_path).exists(); - // Shared API builder initialization let api_builder = || { let mut builder = ApiBuilder::new() @@ -230,24 +224,35 @@ async fn main() -> Result<(), RouterError> { }; // Load tokenizer config if found locally, or check if we can get it from the API if needed - let tokenizer_config = if local_tokenizer_config { + let tokenizer_config = if let Some(path) = tokenizer_config_path { + tracing::info!("Using local tokenizer config from user specified path"); + HubTokenizerConfig::from_file(&std::path::PathBuf::from(path)) + } else if local_model { tracing::info!("Using local tokenizer config"); - HubTokenizerConfig::from_file(&local_tokenizer_config_path) - } else if let Some(api) = api { - tracing::info!("Using the Hugging Face API to retrieve tokenizer config"); - get_tokenizer_config(&api.repo(Repo::with_revision( - tokenizer_name.to_string(), - RepoType::Model, - revision.unwrap_or_else(|| "main".to_string()), - ))) - .await - .unwrap_or_else(|| { - tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub."); - HubTokenizerConfig::default() - }) + HubTokenizerConfig::from_file(&local_path.join("tokenizer_config.json")) } else { - tracing::warn!("Could not find tokenizer config locally and no revision specified"); - HubTokenizerConfig::default() + match api { + Some(api) => { + tracing::info!("Using the Hugging Face API to retrieve tokenizer config"); + let repo = Repo::with_revision( + tokenizer_name.to_string(), + RepoType::Model, + revision.unwrap_or("main".to_string()), + ); + get_tokenizer_config(&api.repo(repo)) + .await + .unwrap_or_else(|| { + tracing::warn!( + "Could not retrieve tokenizer config from the Hugging Face hub." + ); + HubTokenizerConfig::default() + }) + } + None => { + tracing::warn!("Could not find tokenizer config locally and no API specified"); + HubTokenizerConfig::default() + } + } }; if tokenizer.is_none() {