misc(deps): add pyo3 to dependencies

This commit is contained in:
Morgan Funtowicz 2024-10-28 17:23:32 +01:00
parent 9cee00eec3
commit 512225474a
3 changed files with 52 additions and 4 deletions

1
Cargo.lock generated
View File

@ -4187,6 +4187,7 @@ dependencies = [
"hf-hub",
"log",
"pkg-config",
"pyo3",
"text-generation-router",
"thiserror",
"tokenizers",

View File

@ -13,6 +13,7 @@ cxx = "1.0"
hashbrown = "0.14"
hf-hub = { workspace = true }
log = { version = "0.4", features = [] }
pyo3 = { workspace = true }
text-generation-router = { path = "../../router" }
tokenizers = { workspace = true }
tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }

View File

@ -3,12 +3,13 @@ use std::path::{Path, PathBuf};
use clap::Parser;
use hf_hub::api::tokio::{Api, ApiBuilder};
use hf_hub::{Cache, Repo, RepoType};
use pyo3::types::IntoPyDict;
use tokenizers::Tokenizer;
use tracing::info;
use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
use text_generation_backends_trtllm::TensorRtLlmBackendV2;
use text_generation_router::server::get_base_tokenizer;
use text_generation_router::server::{get_base_tokenizer, get_hub_model_info};
use text_generation_router::usage_stats::UsageStatsLevel;
use text_generation_router::{server, HubTokenizerConfig};
@ -129,6 +130,7 @@ async fn get_tokenizer(
tokenizer_config_filename,
_preprocessor_config_filename,
_processor_config_filename,
_model_info,
) = match api {
Type::None => (
Some(local_path.join("tokenizer.json")),
@ -136,12 +138,13 @@ async fn get_tokenizer(
Some(local_path.join("tokenizer_config.json")),
Some(local_path.join("preprocessor_config.json")),
Some(local_path.join("processor_config.json")),
None,
),
Type::Api(api) => {
let api_repo = api.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.unwrap_or_else(|| "main").to_string(),
revision.clone().unwrap_or("main").to_string(),
));
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
@ -153,19 +156,26 @@ async fn get_tokenizer(
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {
Some(model_info)
} else {
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
None
};
(
tokenizer_filename,
config_filename,
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
model_info,
)
}
Type::Cache(cache) => {
let repo = cache.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main").to_string(),
revision.clone().unwrap_or("main").to_string(),
));
(
repo.get("tokenizer.json"),
@ -173,6 +183,7 @@ async fn get_tokenizer(
repo.get("tokenizer_config.json"),
repo.get("preprocessor_config.json"),
repo.get("processor_config.json"),
None,
)
}
};
@ -184,8 +195,43 @@ async fn get_tokenizer(
} else {
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
};
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default()
});
tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok())
let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
use pyo3::prelude::*;
let convert = pyo3::Python::with_gil(|py| -> PyResult<()> {
let transformers = py.import_bound("transformers")?;
let auto = transformers.getattr("AutoTokenizer")?;
let from_pretrained = auto.getattr("from_pretrained")?;
let args = (tokenizer_name.to_string(),);
let kwargs = [(
"revision",
revision.clone().unwrap_or_else(|| "main"),
)]
.into_py_dict_bound(py);
let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
let save = tokenizer.getattr("save_pretrained")?;
let args = ("out".to_string(),);
save.call1(args)?;
Ok(())
})
.inspect_err(|err| {
tracing::error!("Failed to import python tokenizer {err}");
});
let filename = if convert.is_ok() {
// The tokenizer was successfully loaded and re-saved through transformers, so read the
// re-saved copy: transformers may have rewritten tokenizer.json during that round trip.
"out/tokenizer.json".into()
} else {
filename
};
Tokenizer::from_file(filename).ok()
});
tokenizer
}
#[tokio::main]