Adding support for `HF_HUB_OFFLINE` in the router. (#1789)
# What does this PR do?

Fixes # (issue)

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@OlivierDehaene OR @Narsil
This commit is contained in: parent 23d82b8fb6 · commit 4c698fa6c2
```diff
@@ -73,9 +73,9 @@ pub struct HubTokenizerConfig {
 }
 
 impl HubTokenizerConfig {
-    pub fn from_file(filename: &std::path::Path) -> Self {
-        let content = std::fs::read_to_string(filename).unwrap();
-        serde_json::from_str(&content).unwrap_or_default()
+    pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
+        let content = std::fs::read_to_string(filename).ok()?;
+        serde_json::from_str(&content).ok()
     }
 }
```
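A side note on the `from_file` change above: returning `Option` moves the failure decision to the caller instead of panicking inside the helper, and the `AsRef<Path>` bound accepts any path-like argument. A minimal self-contained sketch of the same pattern (assumes `serde`/`serde_json` with the derive feature; the `chat_template` field is illustrative, not the full struct):

```rust
use serde::Deserialize;

#[derive(Debug, Default, Deserialize)]
pub struct HubTokenizerConfig {
    // Illustrative field; the real struct carries more entries.
    pub chat_template: Option<String>,
}

impl HubTokenizerConfig {
    // New shape: any path-like argument, and None on I/O or parse errors
    // instead of a panic via unwrap().
    pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
        let content = std::fs::read_to_string(filename).ok()?;
        serde_json::from_str(&content).ok()
    }
}

fn main() {
    // Works with &str, &Path, or PathBuf thanks to AsRef<Path>;
    // a missing file now yields a default config rather than a crash.
    let config = HubTokenizerConfig::from_file("tokenizer_config.json").unwrap_or_default();
    println!("{config:?}");
}
```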
```diff
@@ -1,7 +1,7 @@
 use axum::http::HeaderValue;
 use clap::Parser;
 use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
-use hf_hub::{Repo, RepoType};
+use hf_hub::{Cache, Repo, RepoType};
 use opentelemetry::sdk::propagation::TraceContextPropagator;
 use opentelemetry::sdk::trace;
 use opentelemetry::sdk::trace::Sampler;
```
```diff
@@ -11,7 +11,7 @@ use opentelemetry_otlp::WithExportConfig;
 use std::fs::File;
 use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
-use std::path::Path;
+use std::path::{Path, PathBuf};
 use text_generation_client::{ClientError, ShardedClient};
 use text_generation_router::config::Config;
 use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
```
```diff
@@ -162,7 +162,6 @@ async fn main() -> Result<(), RouterError> {
     // Tokenizer instance
     // This will only be used to validate payloads
     let local_path = Path::new(&tokenizer_name);
-    let local_model = local_path.exists() && local_path.is_dir();
 
     // Shared API builder initialization
     let api_builder = || {
```
```diff
@@ -181,112 +180,113 @@ async fn main() -> Result<(), RouterError> {
     let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
 
     // Initialize API if needed
+    #[derive(Clone)]
+    enum Type {
+        Api(Api),
+        Cache(Cache),
+        None,
+    }
     let api = if use_api {
-        tracing::info!("Using the Hugging Face API");
-        match api_builder().build() {
-            Ok(api) => Some(api),
-            Err(_) => {
-                tracing::warn!("Unable to build the Hugging Face API");
-                None
+        if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
+            let cache = Cache::default();
+            tracing::warn!("Offline mode active using cache defaults");
+            Type::Cache(cache)
+        } else {
+            tracing::info!("Using the Hugging Face API");
+            match api_builder().build() {
+                Ok(api) => Type::Api(api),
+                Err(_) => {
+                    tracing::warn!("Unable to build the Hugging Face API");
+                    Type::None
+                }
             }
         }
     } else {
-        None
+        Type::None
     };
 
     // Load tokenizer and model info
-    let (tokenizer, model_info, config) = if local_model {
-        let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
-        let model_info = HubModelInfo {
-            model_id: tokenizer_name.to_string(),
-            sha: None,
-            pipeline_tag: None,
-        };
-        let config: Option<Config> = std::fs::read_to_string(local_path.join("config.json"))
-            .ok()
-            .as_ref()
-            .and_then(|c| serde_json::from_str(c).ok());
-
-        (tokenizer, model_info, config)
-    } else if let Some(api) = api.clone() {
-        let api_repo = api.repo(Repo::with_revision(
-            tokenizer_name.to_string(),
-            RepoType::Model,
-            revision.clone().unwrap_or_else(|| "main".to_string()),
-        ));
-
-        let tokenizer = match api_repo.get("tokenizer.json").await {
-            Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(),
-            Err(_) => get_base_tokenizer(&api, &api_repo).await,
-        };
-
-        let config: Option<Config> = api_repo.get("config.json").await.ok().and_then(|filename| {
-            std::fs::read_to_string(filename)
-                .ok()
-                .as_ref()
-                .and_then(|c| {
-                    let config: Result<Config, _> = serde_json::from_str(c);
-                    if let Err(err) = &config {
-                        tracing::warn!("Could not parse config {err:?}");
-                    }
-                    config.ok()
-                })
-        });
-
-        let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
-            tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
-            HubModelInfo {
-                model_id: tokenizer_name.to_string(),
-                sha: None,
-                pipeline_tag: None,
-            }
-        });
-
-        (tokenizer, model_info, config)
-    } else {
-        // No API and no local model
-        return Err(RouterError::ArgumentValidation(
-            "No local model found and no revision specified".to_string(),
-        ));
-    };
-
-    tracing::info!("Using config {config:?}");
-
-    // Load tokenizer config if found locally, or check if we can get it from the API if needed
-    let tokenizer_config = if let Some(path) = tokenizer_config_path {
-        tracing::info!(
-            "Using local tokenizer config from user specified path {}",
-            path
-        );
-        HubTokenizerConfig::from_file(&std::path::PathBuf::from(path))
-    } else if local_model {
-        tracing::info!("Using local tokenizer config");
-        HubTokenizerConfig::from_file(&local_path.join("tokenizer_config.json"))
-    } else {
-        match api {
-            Some(api) => {
-                tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
-                let repo = Repo::with_revision(
-                    tokenizer_name.to_string(),
-                    RepoType::Model,
-                    revision.unwrap_or("main".to_string()),
-                );
-                get_tokenizer_config(&api.repo(repo))
-                    .await
-                    .unwrap_or_else(|| {
-                        tracing::warn!(
-                            "Could not retrieve tokenizer config from the Hugging Face hub."
-                        );
-                        HubTokenizerConfig::default()
-                    })
-            }
-            None => {
-                tracing::warn!("Could not find tokenizer config locally and no API specified");
-                HubTokenizerConfig::default()
-            }
-        }
-    };
+    let (tokenizer_filename, config_filename, tokenizer_config_filename, model_info) = match api {
+        Type::None => (
+            Some(local_path.join("tokenizer.json")),
+            Some(local_path.join("config.json")),
+            Some(local_path.join("tokenizer_config.json")),
+            None,
+        ),
+        Type::Api(api) => {
+            let api_repo = api.repo(Repo::with_revision(
+                tokenizer_name.to_string(),
+                RepoType::Model,
+                revision.clone().unwrap_or_else(|| "main".to_string()),
+            ));
+
+            let tokenizer_filename = match api_repo.get("tokenizer.json").await {
+                Ok(tokenizer_filename) => Some(tokenizer_filename),
+                Err(_) => get_base_tokenizer(&api, &api_repo).await,
+            };
+            let config_filename = api_repo.get("config.json").await.ok();
+            let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
+
+            let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
+                Some(model_info)
+            } else {
+                tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
+                None
+            };
+            (
+                tokenizer_filename,
+                config_filename,
+                tokenizer_config_filename,
+                model_info,
+            )
+        }
+        Type::Cache(cache) => {
+            let repo = cache.repo(Repo::with_revision(
+                tokenizer_name.to_string(),
+                RepoType::Model,
+                revision.clone().unwrap_or_else(|| "main".to_string()),
+            ));
+            (
+                repo.get("tokenizer.json"),
+                repo.get("config.json"),
+                repo.get("tokenizer_config.json"),
+                None,
+            )
+        }
+    };
+    let tokenizer: Option<Tokenizer> =
+        tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok());
+    let config: Option<Config> = config_filename.and_then(|filename| {
+        std::fs::read_to_string(filename)
+            .ok()
+            .as_ref()
+            .and_then(|c| {
+                let config: Result<Config, _> = serde_json::from_str(c);
+                if let Err(err) = &config {
+                    tracing::warn!("Could not parse config {err:?}");
+                }
+                config.ok()
+            })
+    });
+    let model_info = model_info.unwrap_or_else(|| HubModelInfo {
+        model_id: tokenizer_name.to_string(),
+        sha: None,
+        pipeline_tag: None,
+    });
+
+    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
+    let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
+    {
+        HubTokenizerConfig::from_file(filename)
+    } else {
+        tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
+    };
+    let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
+        tracing::warn!("Could not find tokenizer config locally and no API specified");
+        HubTokenizerConfig::default()
+    });
 
+    tracing::info!("Using config {config:?}");
     if tokenizer.is_none() {
         tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
         tracing::warn!("Rust input length validation and truncation is disabled");
```
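The hunk above is the core of the PR: a local `Type` enum selects between a live `Api` client, the local `Cache`, or plain local files, and `HF_HUB_OFFLINE=1` forces the cache path before any network client is built. A condensed, standalone sketch of that gating logic (`Cache` and `Api` here are stand-ins so it compiles without the `hf-hub` crate; run it as `HF_HUB_OFFLINE=1 cargo run` to hit the offline branch):

```rust
// Stand-ins for hf_hub::Cache and hf_hub::api::tokio::Api.
#[derive(Default)]
struct Cache;
struct Api;

enum Type {
    Api(Api),
    Cache(Cache),
    None,
}

fn select_source(use_api: bool, api_built: Option<Api>) -> Type {
    if use_api {
        // HF_HUB_OFFLINE=1 short-circuits to the local cache before any
        // network client would be constructed.
        if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
            Type::Cache(Cache::default())
        } else {
            match api_built {
                Some(api) => Type::Api(api),
                None => Type::None,
            }
        }
    } else {
        // A fully local model directory needs neither the cache nor the API.
        Type::None
    }
}

fn main() {
    match select_source(true, Some(Api)) {
        Type::Cache(_) => println!("offline: resolving files from the local cache"),
        Type::Api(_) => println!("online: resolving files via the Hub API"),
        Type::None => println!("local: reading files straight from disk"),
    }
}
```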
```diff
@@ -483,7 +483,7 @@ pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
 }
 
 /// get base tokenizer
-pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokenizer> {
+pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
     let config_filename = api_repo.get("config.json").await.ok()?;
 
     // Open the file in read-only mode with buffer.
```
```diff
@@ -500,8 +500,7 @@ pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokenizer> {
             "main".to_string(),
         ));
 
-        let tokenizer_filename = api_base_repo.get("tokenizer.json").await.ok()?;
-        Tokenizer::from_file(tokenizer_filename).ok()
+        api_base_repo.get("tokenizer.json").await.ok()
     } else {
         None
     }
```
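With `get_base_tokenizer` now returning `Option<PathBuf>`, every resolution branch (local directory, Hub API, offline cache) hands back file paths, and a single `Tokenizer::from_file` call near the end of `main` does the parsing. A small sketch of that "resolve paths first, parse once" shape (all names here are hypothetical stand-ins, not the router's API):

```rust
use std::path::PathBuf;

// Hypothetical resolvers; each returns the location of tokenizer.json
// rather than a parsed tokenizer, mirroring the refactor above.
fn from_local_dir() -> Option<PathBuf> {
    Some(PathBuf::from("my-model/tokenizer.json"))
}

fn from_cache_or_api() -> Option<PathBuf> {
    None // e.g. offline with nothing cached
}

// A single shared parsing step at the end; `parse_tokenizer` stands in
// for tokenizers::Tokenizer::from_file.
fn parse_tokenizer(path: PathBuf) -> Option<String> {
    std::fs::read_to_string(path).ok()
}

fn main() {
    let tokenizer_filename = from_local_dir().or_else(from_cache_or_api);
    let tokenizer = tokenizer_filename.and_then(parse_tokenizer);
    if tokenizer.is_none() {
        eprintln!("no fast tokenizer found; input length validation disabled");
    }
}
```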