diff --git a/Cargo.lock b/Cargo.lock index c1251832..dc401c10 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4187,6 +4187,7 @@ dependencies = [ "hf-hub", "log", "pkg-config", + "pyo3", "text-generation-router", "thiserror", "tokenizers", diff --git a/backends/trtllm/Cargo.toml b/backends/trtllm/Cargo.toml index 97ef1a76..bed04183 100644 --- a/backends/trtllm/Cargo.toml +++ b/backends/trtllm/Cargo.toml @@ -13,6 +13,7 @@ cxx = "1.0" hashbrown = "0.14" hf-hub = { workspace = true } log = { version = "0.4", features = [] } +pyo3 = { workspace = true } text-generation-router = { path = "../../router" } tokenizers = { workspace = true } tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs index ec54ccce..5f1b70f0 100644 --- a/backends/trtllm/src/main.rs +++ b/backends/trtllm/src/main.rs @@ -3,12 +3,13 @@ use std::path::{Path, PathBuf}; use clap::Parser; use hf_hub::api::tokio::{Api, ApiBuilder}; use hf_hub::{Cache, Repo, RepoType}; +use pyo3::types::IntoPyDict; use tokenizers::Tokenizer; use tracing::info; use text_generation_backends_trtllm::errors::TensorRtLlmBackendError; use text_generation_backends_trtllm::TensorRtLlmBackendV2; -use text_generation_router::server::get_base_tokenizer; +use text_generation_router::server::{get_base_tokenizer, get_hub_model_info}; use text_generation_router::usage_stats::UsageStatsLevel; use text_generation_router::{server, HubTokenizerConfig}; @@ -129,6 +130,7 @@ async fn get_tokenizer( tokenizer_config_filename, _preprocessor_config_filename, _processor_config_filename, + _model_info, ) = match api { Type::None => ( Some(local_path.join("tokenizer.json")), @@ -136,12 +138,13 @@ async fn get_tokenizer( Some(local_path.join("tokenizer_config.json")), Some(local_path.join("preprocessor_config.json")), Some(local_path.join("processor_config.json")), + None, ), Type::Api(api) => { let api_repo = api.repo(Repo::with_revision( tokenizer_name.to_string(), RepoType::Model, - revision.unwrap_or_else(|| "main").to_string(), + revision.clone().unwrap_or("main").to_string(), )); let tokenizer_filename = match api_repo.get("tokenizer.json").await { @@ -153,19 +156,26 @@ async fn get_tokenizer( let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok(); let processor_config_filename = api_repo.get("processor_config.json").await.ok(); + let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await { + Some(model_info) + } else { + tracing::warn!("Could not retrieve model info from the Hugging Face hub."); + None + }; ( tokenizer_filename, config_filename, tokenizer_config_filename, preprocessor_config_filename, processor_config_filename, + model_info, ) } Type::Cache(cache) => { let repo = cache.repo(Repo::with_revision( tokenizer_name.to_string(), RepoType::Model, - revision.clone().unwrap_or_else(|| "main").to_string(), + revision.clone().unwrap_or("main").to_string(), )); ( repo.get("tokenizer.json"), @@ -173,6 +183,7 @@ async fn get_tokenizer( repo.get("tokenizer_config.json"), repo.get("preprocessor_config.json"), repo.get("processor_config.json"), + None, ) } }; @@ -184,8 +195,43 @@ async fn get_tokenizer( } else { tokenizer_config_filename.and_then(HubTokenizerConfig::from_file) }; + let tokenizer_config = tokenizer_config.unwrap_or_else(|| { + tracing::warn!("Could not find tokenizer config locally and no API specified"); + HubTokenizerConfig::default() + }); - tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok()) + let tokenizer: Option = tokenizer_filename.and_then(|filename| { + use pyo3::prelude::*; + let convert = pyo3::Python::with_gil(|py| -> PyResult<()> { + let transformers = py.import_bound("transformers")?; + let auto = transformers.getattr("AutoTokenizer")?; + let from_pretrained = auto.getattr("from_pretrained")?; + let args = (tokenizer_name.to_string(),); + let kwargs = [( + "revision", + revision.clone().unwrap_or_else(|| "main"), + )] + .into_py_dict_bound(py); + let tokenizer = from_pretrained.call(args, Some(&kwargs))?; + let save = tokenizer.getattr("save_pretrained")?; + let args = ("out".to_string(),); + save.call1(args)?; + Ok(()) + }) + .inspect_err(|err| { + tracing::error!("Failed to import python tokenizer {err}"); + }); + let filename = if convert.is_ok() { + // If we have correctly loaded and resaved with transformers + // We might have modified the tokenizer.json according to transformers + "out/tokenizer.json".into() + } else { + filename + }; + Tokenizer::from_file(filename).ok() + }); + + tokenizer } #[tokio::main]