misc(deps): add pyo3 to dependencies
parent 9cee00eec3
commit 512225474a
@@ -4187,6 +4187,7 @@ dependencies = [
  "hf-hub",
  "log",
  "pkg-config",
+ "pyo3",
  "text-generation-router",
  "thiserror",
  "tokenizers",
@@ -13,6 +13,7 @@ cxx = "1.0"
 hashbrown = "0.14"
 hf-hub = { workspace = true }
 log = { version = "0.4", features = [] }
+pyo3 = { workspace = true }
 text-generation-router = { path = "../../router" }
 tokenizers = { workspace = true }
 tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
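Both manifests pull `pyo3` with `{ workspace = true }`, so its version and feature flags are resolved in the workspace root manifest (not shown in this diff). A minimal sketch to smoke-test the embedded interpreter before the backend depends on it at runtime, assuming pyo3's `auto-initialize` feature is enabled and `transformers` is installed in the active Python environment:

```rust
// Minimal sketch, assuming pyo3 0.21+ with `auto-initialize` and a Python
// environment that has `transformers` installed.
use pyo3::prelude::*;

fn main() -> PyResult<()> {
    Python::with_gil(|py| {
        // Fail fast if the Python side of the tokenizer path is unavailable.
        let version: String = py
            .import_bound("transformers")?
            .getattr("__version__")?
            .extract()?;
        println!("transformers {version}");
        Ok(())
    })
}
```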
@@ -3,12 +3,13 @@ use std::path::{Path, PathBuf};
 use clap::Parser;
 use hf_hub::api::tokio::{Api, ApiBuilder};
 use hf_hub::{Cache, Repo, RepoType};
+use pyo3::types::IntoPyDict;
 use tokenizers::Tokenizer;
 use tracing::info;
 
 use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
 use text_generation_backends_trtllm::TensorRtLlmBackendV2;
-use text_generation_router::server::get_base_tokenizer;
+use text_generation_router::server::{get_base_tokenizer, get_hub_model_info};
 use text_generation_router::usage_stats::UsageStatsLevel;
 use text_generation_router::{server, HubTokenizerConfig};
@@ -129,6 +130,7 @@ async fn get_tokenizer(
         tokenizer_config_filename,
         _preprocessor_config_filename,
         _processor_config_filename,
+        _model_info,
     ) = match api {
         Type::None => (
             Some(local_path.join("tokenizer.json")),
@@ -136,12 +138,13 @@ async fn get_tokenizer(
             Some(local_path.join("tokenizer_config.json")),
             Some(local_path.join("preprocessor_config.json")),
             Some(local_path.join("processor_config.json")),
+            None,
         ),
         Type::Api(api) => {
             let api_repo = api.repo(Repo::with_revision(
                 tokenizer_name.to_string(),
                 RepoType::Model,
-                revision.unwrap_or_else(|| "main").to_string(),
+                revision.clone().unwrap_or("main").to_string(),
             ));
 
             let tokenizer_filename = match api_repo.get("tokenizer.json").await {
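For context on the `Type::Api` arm: `api_repo.get(...)` resolves a file through hf-hub's async client, downloading it into the local Hugging Face cache (or reusing a cached copy) and returning its path. A self-contained sketch of that fetch pattern, assuming hf-hub with the `tokio` feature; `gpt2` is a placeholder model id:

```rust
use hf_hub::api::tokio::ApiBuilder;
use hf_hub::{Repo, RepoType};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let api = ApiBuilder::new().build()?;
    // Same shape as the diff: repo name + repo type + resolved revision string.
    let api_repo = api.repo(Repo::with_revision(
        "gpt2".to_string(),
        RepoType::Model,
        "main".to_string(),
    ));
    // Downloads tokenizer.json (or reuses the cache) and returns its local path.
    let tokenizer_path = api_repo.get("tokenizer.json").await?;
    println!("{}", tokenizer_path.display());
    Ok(())
}
```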
@@ -153,19 +156,26 @@ async fn get_tokenizer(
             let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
             let processor_config_filename = api_repo.get("processor_config.json").await.ok();
 
+            let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {
+                Some(model_info)
+            } else {
+                tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
+                None
+            };
             (
                 tokenizer_filename,
                 config_filename,
                 tokenizer_config_filename,
                 preprocessor_config_filename,
                 processor_config_filename,
+                model_info,
             )
         }
         Type::Cache(cache) => {
             let repo = cache.repo(Repo::with_revision(
                 tokenizer_name.to_string(),
                 RepoType::Model,
-                revision.clone().unwrap_or_else(|| "main").to_string(),
+                revision.clone().unwrap_or("main").to_string(),
             ));
             (
                 repo.get("tokenizer.json"),
@@ -173,6 +183,7 @@ async fn get_tokenizer(
                 repo.get("tokenizer_config.json"),
                 repo.get("preprocessor_config.json"),
                 repo.get("processor_config.json"),
+                None,
             )
         }
     };
@@ -184,8 +195,43 @@ async fn get_tokenizer(
     } else {
         tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
     };
+    let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
+        tracing::warn!("Could not find tokenizer config locally and no API specified");
+        HubTokenizerConfig::default()
+    });
 
-    tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok())
+    let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
+        use pyo3::prelude::*;
+        let convert = pyo3::Python::with_gil(|py| -> PyResult<()> {
+            let transformers = py.import_bound("transformers")?;
+            let auto = transformers.getattr("AutoTokenizer")?;
+            let from_pretrained = auto.getattr("from_pretrained")?;
+            let args = (tokenizer_name.to_string(),);
+            let kwargs = [(
+                "revision",
+                revision.clone().unwrap_or_else(|| "main"),
+            )]
+            .into_py_dict_bound(py);
+            let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
+            let save = tokenizer.getattr("save_pretrained")?;
+            let args = ("out".to_string(),);
+            save.call1(args)?;
+            Ok(())
+        })
+        .inspect_err(|err| {
+            tracing::error!("Failed to import python tokenizer {err}");
+        });
+        let filename = if convert.is_ok() {
+            // If we have correctly loaded and resaved with transformers
+            // We might have modified the tokenizer.json according to transformers
+            "out/tokenizer.json".into()
+        } else {
+            filename
+        };
+        Tokenizer::from_file(filename).ok()
+    });
+
+    tokenizer
 }
 
 #[tokio::main]
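The new block round-trips the tokenizer through Python: `AutoTokenizer.from_pretrained` loads it, `save_pretrained` re-serializes it (which can normalize `tokenizer.json`), and only then does the Rust `tokenizers` crate read the result, falling back to the original file if the Python step fails. A standalone sketch of that conversion using the same pyo3 0.21-style bound API as the diff; the `model_id` parameter and the `out` directory are illustrative placeholders:

```rust
use pyo3::prelude::*;
use pyo3::types::IntoPyDict;
use tokenizers::Tokenizer;

fn convert_and_load(model_id: &str, revision: &str) -> Option<Tokenizer> {
    // Step 1: let transformers load and re-save the tokenizer.
    let converted = Python::with_gil(|py| -> PyResult<()> {
        let transformers = py.import_bound("transformers")?;
        let auto = transformers.getattr("AutoTokenizer")?;
        // Python: AutoTokenizer.from_pretrained(model_id, revision=revision)
        let kwargs = [("revision", revision)].into_py_dict_bound(py);
        let tokenizer = auto.call_method("from_pretrained", (model_id,), Some(&kwargs))?;
        // Python: tokenizer.save_pretrained("out") -- rewrites out/tokenizer.json
        tokenizer.call_method1("save_pretrained", ("out",))?;
        Ok(())
    });

    // Step 2: load the (possibly normalized) file with the Rust crate.
    match converted {
        Ok(()) => Tokenizer::from_file("out/tokenizer.json").ok(),
        Err(err) => {
            eprintln!("transformers conversion failed: {err}");
            None
        }
    }
}
```

The fallback in the diff differs slightly from this sketch: there, a failed conversion logs via `tracing::error!` and reuses the originally downloaded `tokenizer.json` path rather than returning `None` outright.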