add trust_remote_code in tokenizer to fix baichuan issue (#2725)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
b1f9044d6c
commit
97f7a22f0b
|
@ -27,6 +27,7 @@ pub enum Tokenizer {
|
||||||
Python {
|
Python {
|
||||||
tokenizer_name: String,
|
tokenizer_name: String,
|
||||||
revision: Option<String>,
|
revision: Option<String>,
|
||||||
|
trust_remote_code: bool,
|
||||||
},
|
},
|
||||||
Rust(tokenizers::Tokenizer),
|
Rust(tokenizers::Tokenizer),
|
||||||
}
|
}
|
||||||
|
@ -38,15 +39,20 @@ impl<'a> PyTokenizer<'a> {
|
||||||
py: Python<'a>,
|
py: Python<'a>,
|
||||||
tokenizer_name: String,
|
tokenizer_name: String,
|
||||||
revision: Option<String>,
|
revision: Option<String>,
|
||||||
|
trust_remote_code: bool,
|
||||||
) -> PyResult<PyTokenizer<'a>> {
|
) -> PyResult<PyTokenizer<'a>> {
|
||||||
let transformers = py.import_bound("transformers")?;
|
let transformers = py.import_bound("transformers")?;
|
||||||
let auto = transformers.getattr("AutoTokenizer")?;
|
let auto = transformers.getattr("AutoTokenizer")?;
|
||||||
let from_pretrained = auto.getattr("from_pretrained")?;
|
let from_pretrained = auto.getattr("from_pretrained")?;
|
||||||
let args = (tokenizer_name,);
|
let args = (tokenizer_name,);
|
||||||
let kwargs = if let Some(rev) = &revision {
|
let kwargs = if let Some(rev) = &revision {
|
||||||
[("revision", rev.to_string())].into_py_dict_bound(py)
|
[
|
||||||
|
("revision", rev.to_string().into_py(py)),
|
||||||
|
("trust_remote_code", trust_remote_code.into_py(py)),
|
||||||
|
]
|
||||||
|
.into_py_dict_bound(py)
|
||||||
} else {
|
} else {
|
||||||
pyo3::types::PyDict::new_bound(py)
|
[("trust_remote_code", trust_remote_code.into_py(py))].into_py_dict_bound(py)
|
||||||
};
|
};
|
||||||
let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
|
let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
|
||||||
tracing::info!("Loaded a python tokenizer");
|
tracing::info!("Loaded a python tokenizer");
|
||||||
|
|
|
@ -1829,6 +1829,7 @@ pub async fn run(
|
||||||
Tokenizer::Python {
|
Tokenizer::Python {
|
||||||
tokenizer_name: tokenizer_name.clone(),
|
tokenizer_name: tokenizer_name.clone(),
|
||||||
revision: revision.clone(),
|
revision: revision.clone(),
|
||||||
|
trust_remote_code,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -439,9 +439,11 @@ fn tokenizer_worker(
|
||||||
Tokenizer::Python {
|
Tokenizer::Python {
|
||||||
tokenizer_name,
|
tokenizer_name,
|
||||||
revision,
|
revision,
|
||||||
|
trust_remote_code,
|
||||||
} => {
|
} => {
|
||||||
pyo3::Python::with_gil(|py| -> pyo3::PyResult<()> {
|
pyo3::Python::with_gil(|py| -> pyo3::PyResult<()> {
|
||||||
let tokenizer = PyTokenizer::from_py(py, tokenizer_name, revision)?;
|
let tokenizer =
|
||||||
|
PyTokenizer::from_py(py, tokenizer_name, revision, trust_remote_code)?;
|
||||||
// Loop over requests
|
// Loop over requests
|
||||||
while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
|
while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
|
||||||
receiver.blocking_recv()
|
receiver.blocking_recv()
|
||||||
|
|
Loading…
Reference in New Issue