Add trust_remote_code to the tokenizer to fix the Baichuan issue (#2725)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi 2024-11-07 21:43:38 +08:00 committed by GitHub
parent b1f9044d6c
commit 97f7a22f0b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 12 additions and 3 deletions

View File

@@ -27,6 +27,7 @@ pub enum Tokenizer {
Python { Python {
tokenizer_name: String, tokenizer_name: String,
revision: Option<String>, revision: Option<String>,
trust_remote_code: bool,
}, },
Rust(tokenizers::Tokenizer), Rust(tokenizers::Tokenizer),
} }
@@ -38,15 +39,20 @@ impl<'a> PyTokenizer<'a> {
py: Python<'a>, py: Python<'a>,
tokenizer_name: String, tokenizer_name: String,
revision: Option<String>, revision: Option<String>,
trust_remote_code: bool,
) -> PyResult<PyTokenizer<'a>> { ) -> PyResult<PyTokenizer<'a>> {
let transformers = py.import_bound("transformers")?; let transformers = py.import_bound("transformers")?;
let auto = transformers.getattr("AutoTokenizer")?; let auto = transformers.getattr("AutoTokenizer")?;
let from_pretrained = auto.getattr("from_pretrained")?; let from_pretrained = auto.getattr("from_pretrained")?;
let args = (tokenizer_name,); let args = (tokenizer_name,);
let kwargs = if let Some(rev) = &revision { let kwargs = if let Some(rev) = &revision {
[("revision", rev.to_string())].into_py_dict_bound(py) [
("revision", rev.to_string().into_py(py)),
("trust_remote_code", trust_remote_code.into_py(py)),
]
.into_py_dict_bound(py)
} else { } else {
pyo3::types::PyDict::new_bound(py) [("trust_remote_code", trust_remote_code.into_py(py))].into_py_dict_bound(py)
}; };
let tokenizer = from_pretrained.call(args, Some(&kwargs))?; let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
tracing::info!("Loaded a python tokenizer"); tracing::info!("Loaded a python tokenizer");

View File

@@ -1829,6 +1829,7 @@ pub async fn run(
Tokenizer::Python { Tokenizer::Python {
tokenizer_name: tokenizer_name.clone(), tokenizer_name: tokenizer_name.clone(),
revision: revision.clone(), revision: revision.clone(),
trust_remote_code,
} }
} }
}; };

View File

@@ -439,9 +439,11 @@ fn tokenizer_worker(
Tokenizer::Python { Tokenizer::Python {
tokenizer_name, tokenizer_name,
revision, revision,
trust_remote_code,
} => { } => {
pyo3::Python::with_gil(|py| -> pyo3::PyResult<()> { pyo3::Python::with_gil(|py| -> pyo3::PyResult<()> {
let tokenizer = PyTokenizer::from_py(py, tokenizer_name, revision)?; let tokenizer =
PyTokenizer::from_py(py, tokenizer_name, revision, trust_remote_code)?;
// Loop over requests // Loop over requests
while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) = while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
receiver.blocking_recv() receiver.blocking_recv()