fix: tokenizer config should use local model path when possible (#1518)
This PR fixes the issue with loading a local tokenizer config. Previously the default functionality would look in the current working directory. Now if a local model path is specified we will check that directory for the tokenizer_config. ## Examples of valid commands uses tokenizer_config from hub ``` text-generation-launcher --model-id HuggingFaceH4/zephyr-7b-beta ``` use tokenizer_config from local model path ``` text-generation-launcher \ --model-id ~/.cache/huggingface/hub/models--HuggingFaceH4--zephyr-7b-beta/snapshots/dc24cabd13eacd3ae3a5fe574bd645483a335a4a/ ``` use specific tokenizer_config file ``` text-generation-launcher \ --model-id ~/.cache/huggingface/hub/models--HuggingFaceH4--zephyr-7b-beta/snapshots/dc24cabd13eacd3ae3a5fe574bd645483a335a4a/ \ --tokenizer-config-path ~/.cache/huggingface/hub/models--HuggingFaceH4--zephyr-7b-beta/snapshots/dc24cabd13eacd3ae3a5fe574bd645483a335a4a/tokenizer_config.json ``` --------- Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
parent
1e03b61b5c
commit
ee1cf51ce7
|
@ -37,7 +37,7 @@ pub struct HubTokenizerConfig {
|
|||
}
|
||||
|
||||
impl HubTokenizerConfig {
|
||||
pub fn from_file(filename: &str) -> Self {
|
||||
pub fn from_file(filename: &std::path::Path) -> Self {
|
||||
let content = std::fs::read_to_string(filename).unwrap();
|
||||
serde_json::from_str(&content).unwrap_or_default()
|
||||
}
|
||||
|
|
|
@ -154,12 +154,6 @@ async fn main() -> Result<(), RouterError> {
|
|||
let local_path = Path::new(&tokenizer_name);
|
||||
let local_model = local_path.exists() && local_path.is_dir();
|
||||
|
||||
// Load tokenizer config
|
||||
// This will be used to format the chat template
|
||||
let local_tokenizer_config_path =
|
||||
tokenizer_config_path.unwrap_or("tokenizer_config.json".to_string());
|
||||
let local_tokenizer_config = Path::new(&local_tokenizer_config_path).exists();
|
||||
|
||||
// Shared API builder initialization
|
||||
let api_builder = || {
|
||||
let mut builder = ApiBuilder::new()
|
||||
|
@ -230,24 +224,35 @@ async fn main() -> Result<(), RouterError> {
|
|||
};
|
||||
|
||||
// Load tokenizer config if found locally, or check if we can get it from the API if needed
|
||||
let tokenizer_config = if local_tokenizer_config {
|
||||
let tokenizer_config = if let Some(path) = tokenizer_config_path {
|
||||
tracing::info!("Using local tokenizer config from user specified path");
|
||||
HubTokenizerConfig::from_file(&std::path::PathBuf::from(path))
|
||||
} else if local_model {
|
||||
tracing::info!("Using local tokenizer config");
|
||||
HubTokenizerConfig::from_file(&local_tokenizer_config_path)
|
||||
} else if let Some(api) = api {
|
||||
tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
|
||||
get_tokenizer_config(&api.repo(Repo::with_revision(
|
||||
tokenizer_name.to_string(),
|
||||
RepoType::Model,
|
||||
revision.unwrap_or_else(|| "main".to_string()),
|
||||
)))
|
||||
.await
|
||||
.unwrap_or_else(|| {
|
||||
tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub.");
|
||||
HubTokenizerConfig::default()
|
||||
})
|
||||
HubTokenizerConfig::from_file(&local_path.join("tokenizer_config.json"))
|
||||
} else {
|
||||
tracing::warn!("Could not find tokenizer config locally and no revision specified");
|
||||
HubTokenizerConfig::default()
|
||||
match api {
|
||||
Some(api) => {
|
||||
tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
|
||||
let repo = Repo::with_revision(
|
||||
tokenizer_name.to_string(),
|
||||
RepoType::Model,
|
||||
revision.unwrap_or("main".to_string()),
|
||||
);
|
||||
get_tokenizer_config(&api.repo(repo))
|
||||
.await
|
||||
.unwrap_or_else(|| {
|
||||
tracing::warn!(
|
||||
"Could not retrieve tokenizer config from the Hugging Face hub."
|
||||
);
|
||||
HubTokenizerConfig::default()
|
||||
})
|
||||
}
|
||||
None => {
|
||||
tracing::warn!("Could not find tokenizer config locally and no API specified");
|
||||
HubTokenizerConfig::default()
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if tokenizer.is_none() {
|
||||
|
|
Loading…
Reference in New Issue