misc(deps): add pyo3 to dependencies
commit 512225474a
parent 9cee00eec3
Cargo.lock
@@ -4187,6 +4187,7 @@ dependencies = [
  "hf-hub",
  "log",
  "pkg-config",
+ "pyo3",
  "text-generation-router",
  "thiserror",
  "tokenizers",
backends/trtllm/Cargo.toml
@@ -13,6 +13,7 @@ cxx = "1.0"
 hashbrown = "0.14"
 hf-hub = { workspace = true }
 log = { version = "0.4", features = [] }
+pyo3 = { workspace = true }
 text-generation-router = { path = "../../router" }
 tokenizers = { workspace = true }
 tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
backends/trtllm/src/main.rs
@@ -3,12 +3,13 @@ use std::path::{Path, PathBuf};
 use clap::Parser;
 use hf_hub::api::tokio::{Api, ApiBuilder};
 use hf_hub::{Cache, Repo, RepoType};
+use pyo3::types::IntoPyDict;
 use tokenizers::Tokenizer;
 use tracing::info;
 
 use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
 use text_generation_backends_trtllm::TensorRtLlmBackendV2;
-use text_generation_router::server::get_base_tokenizer;
+use text_generation_router::server::{get_base_tokenizer, get_hub_model_info};
 use text_generation_router::usage_stats::UsageStatsLevel;
 use text_generation_router::{server, HubTokenizerConfig};
 
@@ -129,6 +130,7 @@ async fn get_tokenizer(
         tokenizer_config_filename,
         _preprocessor_config_filename,
         _processor_config_filename,
+        _model_info,
     ) = match api {
         Type::None => (
             Some(local_path.join("tokenizer.json")),
@@ -136,12 +138,13 @@ async fn get_tokenizer(
             Some(local_path.join("tokenizer_config.json")),
             Some(local_path.join("preprocessor_config.json")),
             Some(local_path.join("processor_config.json")),
+            None,
         ),
         Type::Api(api) => {
             let api_repo = api.repo(Repo::with_revision(
                 tokenizer_name.to_string(),
                 RepoType::Model,
-                revision.unwrap_or_else(|| "main").to_string(),
+                revision.clone().unwrap_or("main").to_string(),
             ));
 
             let tokenizer_filename = match api_repo.get("tokenizer.json").await {
@@ -153,19 +156,26 @@ async fn get_tokenizer(
             let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
             let processor_config_filename = api_repo.get("processor_config.json").await.ok();
 
+            let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {
+                Some(model_info)
+            } else {
+                tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
+                None
+            };
             (
                 tokenizer_filename,
                 config_filename,
                 tokenizer_config_filename,
                 preprocessor_config_filename,
                 processor_config_filename,
+                model_info,
             )
         }
         Type::Cache(cache) => {
             let repo = cache.repo(Repo::with_revision(
                 tokenizer_name.to_string(),
                 RepoType::Model,
-                revision.clone().unwrap_or_else(|| "main").to_string(),
+                revision.clone().unwrap_or("main").to_string(),
             ));
             (
                 repo.get("tokenizer.json"),
@@ -173,6 +183,7 @@ async fn get_tokenizer(
                 repo.get("tokenizer_config.json"),
                 repo.get("preprocessor_config.json"),
                 repo.get("processor_config.json"),
+                None,
             )
         }
     };
@@ -184,8 +195,43 @@ async fn get_tokenizer(
     } else {
         tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
     };
+    let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
+        tracing::warn!("Could not find tokenizer config locally and no API specified");
+        HubTokenizerConfig::default()
+    });
 
-    tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok())
+    let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
+        use pyo3::prelude::*;
+        let convert = pyo3::Python::with_gil(|py| -> PyResult<()> {
+            let transformers = py.import_bound("transformers")?;
+            let auto = transformers.getattr("AutoTokenizer")?;
+            let from_pretrained = auto.getattr("from_pretrained")?;
+            let args = (tokenizer_name.to_string(),);
+            let kwargs = [(
+                "revision",
+                revision.clone().unwrap_or_else(|| "main"),
+            )]
+            .into_py_dict_bound(py);
+            let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
+            let save = tokenizer.getattr("save_pretrained")?;
+            let args = ("out".to_string(),);
+            save.call1(args)?;
+            Ok(())
+        })
+        .inspect_err(|err| {
+            tracing::error!("Failed to import python tokenizer {err}");
+        });
+        let filename = if convert.is_ok() {
+            // If we have correctly loaded and resaved with transformers
+            // We might have modified the tokenizer.json according to transformers
+            "out/tokenizer.json".into()
+        } else {
+            filename
+        };
+        Tokenizer::from_file(filename).ok()
+    });
+
+    tokenizer
 }
 
 #[tokio::main]
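Note on the last hunk: rather than failing when the Rust `tokenizers` crate cannot parse a repository's `tokenizer.json` directly, the change routes the tokenizer through Python's `transformers.AutoTokenizer` via pyo3, re-saves it, and then loads the re-serialized `out/tokenizer.json`; on any Python-side error it falls back to the original file. Below is a minimal standalone sketch of that round trip. It assumes pyo3 0.21/0.22 (the Bound API with `import_bound` and `into_py_dict_bound`) and a Python environment where `transformers` is importable; the function name `convert_tokenizer`, the model id, and the output directory are illustrative, not taken from the commit.

use pyo3::prelude::*;
use pyo3::types::IntoPyDict;

/// Load `model_id` through Python's `transformers.AutoTokenizer` and re-save it,
/// producing `out_dir/tokenizer.json` in a form the Rust `tokenizers` crate can read.
fn convert_tokenizer(model_id: &str, revision: &str, out_dir: &str) -> PyResult<()> {
    Python::with_gil(|py| {
        let transformers = py.import_bound("transformers")?;
        // AutoTokenizer.from_pretrained(model_id, revision=revision)
        let kwargs = [("revision", revision)].into_py_dict_bound(py);
        let tokenizer = transformers
            .getattr("AutoTokenizer")?
            .getattr("from_pretrained")?
            .call((model_id,), Some(&kwargs))?;
        // tokenizer.save_pretrained(out_dir) writes tokenizer.json among other files
        tokenizer.getattr("save_pretrained")?.call1((out_dir,))?;
        Ok(())
    })
}

fn main() -> PyResult<()> {
    // Initialize the embedded interpreter explicitly; alternatively build pyo3
    // with its `auto-initialize` feature and drop this call.
    pyo3::prepare_freethreaded_python();
    convert_tokenizer("gpt2", "main", "out")?;
    println!("tokenizer saved under out/");
    Ok(())
}

One design consequence worth noting: because the conversion error is only logged (via `inspect_err`) and the original path is kept as a fallback, adding the Python dependency does not make tokenizer loading any less reliable than before; it only widens the set of tokenizers that can be loaded.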