add trust_remote_code in tokenizer to fix baichuan issue (#2725)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
b1f9044d6c
commit
97f7a22f0b
|
@ -27,6 +27,7 @@ pub enum Tokenizer {
|
|||
Python {
|
||||
tokenizer_name: String,
|
||||
revision: Option<String>,
|
||||
trust_remote_code: bool,
|
||||
},
|
||||
Rust(tokenizers::Tokenizer),
|
||||
}
|
||||
|
@ -38,15 +39,20 @@ impl<'a> PyTokenizer<'a> {
|
|||
py: Python<'a>,
|
||||
tokenizer_name: String,
|
||||
revision: Option<String>,
|
||||
trust_remote_code: bool,
|
||||
) -> PyResult<PyTokenizer<'a>> {
|
||||
let transformers = py.import_bound("transformers")?;
|
||||
let auto = transformers.getattr("AutoTokenizer")?;
|
||||
let from_pretrained = auto.getattr("from_pretrained")?;
|
||||
let args = (tokenizer_name,);
|
||||
let kwargs = if let Some(rev) = &revision {
|
||||
[("revision", rev.to_string())].into_py_dict_bound(py)
|
||||
[
|
||||
("revision", rev.to_string().into_py(py)),
|
||||
("trust_remote_code", trust_remote_code.into_py(py)),
|
||||
]
|
||||
.into_py_dict_bound(py)
|
||||
} else {
|
||||
pyo3::types::PyDict::new_bound(py)
|
||||
[("trust_remote_code", trust_remote_code.into_py(py))].into_py_dict_bound(py)
|
||||
};
|
||||
let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
|
||||
tracing::info!("Loaded a python tokenizer");
|
||||
|
|
|
@ -1829,6 +1829,7 @@ pub async fn run(
|
|||
Tokenizer::Python {
|
||||
tokenizer_name: tokenizer_name.clone(),
|
||||
revision: revision.clone(),
|
||||
trust_remote_code,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
@ -439,9 +439,11 @@ fn tokenizer_worker(
|
|||
Tokenizer::Python {
|
||||
tokenizer_name,
|
||||
revision,
|
||||
trust_remote_code,
|
||||
} => {
|
||||
pyo3::Python::with_gil(|py| -> pyo3::PyResult<()> {
|
||||
let tokenizer = PyTokenizer::from_py(py, tokenizer_name, revision)?;
|
||||
let tokenizer =
|
||||
PyTokenizer::from_py(py, tokenizer_name, revision, trust_remote_code)?;
|
||||
// Loop over requests
|
||||
while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
|
||||
receiver.blocking_recv()
|
||||
|
|
Loading…
Reference in New Issue