From ee1cf51ce796e4b034eedaf3e909b4c902eae70c Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 1 Feb 2024 09:39:32 -0500 Subject: [PATCH] fix: tokenizer config should use local model path when possible (#1518) This PR fixes the issue with loading a local tokenizer config. Previously the default functionality would look in the current working directory. Now if a local model path is specified we will check that directory for the tokenizer_config. ## Examples of valid commands uses tokenizer_config from hub ``` text-generation-launcher --model-id HuggingFaceH4/zephyr-7b-beta ``` use tokenizer_config from local model path ``` text-generation-launcher \ --model-id ~/.cache/huggingface/hub/models--HuggingFaceH4--zephyr-7b-beta/snapshots/dc24cabd13eacd3ae3a5fe574bd645483a335a4a/ ``` use specific tokenizer_config file ``` text-generation-launcher \ --model-id ~/.cache/huggingface/hub/models--HuggingFaceH4--zephyr-7b-beta/snapshots/dc24cabd13eacd3ae3a5fe574bd645483a335a4a/ \ --tokenizer-config-path ~/.cache/huggingface/hub/models--HuggingFaceH4--zephyr-7b-beta/snapshots/dc24cabd13eacd3ae3a5fe574bd645483a335a4a/tokenizer_config.json ``` --------- Co-authored-by: Nicolas Patry --- router/src/lib.rs | 2 +- router/src/main.rs | 49 +++++++++++++++++++++++++--------------------- 2 files changed, 28 insertions(+), 23 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index fc5670a0..07360e78 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -37,7 +37,7 @@ pub struct HubTokenizerConfig { } impl HubTokenizerConfig { - pub fn from_file(filename: &str) -> Self { + pub fn from_file(filename: &std::path::Path) -> Self { let content = std::fs::read_to_string(filename).unwrap(); serde_json::from_str(&content).unwrap_or_default() } diff --git a/router/src/main.rs b/router/src/main.rs index 495fd5bc..2a080468 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -154,12 +154,6 @@ async fn main() -> Result<(), RouterError> { let local_path = Path::new(&tokenizer_name); let local_model = local_path.exists() && local_path.is_dir(); - // Load tokenizer config - // This will be used to format the chat template - let local_tokenizer_config_path = - tokenizer_config_path.unwrap_or("tokenizer_config.json".to_string()); - let local_tokenizer_config = Path::new(&local_tokenizer_config_path).exists(); - // Shared API builder initialization let api_builder = || { let mut builder = ApiBuilder::new() @@ -230,24 +224,35 @@ async fn main() -> Result<(), RouterError> { }; // Load tokenizer config if found locally, or check if we can get it from the API if needed - let tokenizer_config = if local_tokenizer_config { + let tokenizer_config = if let Some(path) = tokenizer_config_path { + tracing::info!("Using local tokenizer config from user specified path"); + HubTokenizerConfig::from_file(&std::path::PathBuf::from(path)) + } else if local_model { tracing::info!("Using local tokenizer config"); - HubTokenizerConfig::from_file(&local_tokenizer_config_path) - } else if let Some(api) = api { - tracing::info!("Using the Hugging Face API to retrieve tokenizer config"); - get_tokenizer_config(&api.repo(Repo::with_revision( - tokenizer_name.to_string(), - RepoType::Model, - revision.unwrap_or_else(|| "main".to_string()), - ))) - .await - .unwrap_or_else(|| { - tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub."); - HubTokenizerConfig::default() - }) + HubTokenizerConfig::from_file(&local_path.join("tokenizer_config.json")) } else { - tracing::warn!("Could not find tokenizer config locally and no revision specified"); - HubTokenizerConfig::default() + match api { + Some(api) => { + tracing::info!("Using the Hugging Face API to retrieve tokenizer config"); + let repo = Repo::with_revision( + tokenizer_name.to_string(), + RepoType::Model, + revision.unwrap_or("main".to_string()), + ); + get_tokenizer_config(&api.repo(repo)) + .await + .unwrap_or_else(|| { + tracing::warn!( + "Could not retrieve tokenizer config from the Hugging Face hub." + ); + HubTokenizerConfig::default() + }) + } + None => { + tracing::warn!("Could not find tokenizer config locally and no API specified"); + HubTokenizerConfig::default() + } + } }; if tokenizer.is_none() {