Fixing "deadlock" when python prompts for trust_remote_code by always (#2664)

specifiying a value.
This commit is contained in:
Nicolas Patry 2024-10-25 06:39:21 +02:00 committed by GitHub
parent eab07f746c
commit ed87b464b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 20 additions and 4 deletions

View File

@ -44,6 +44,8 @@ struct Args {
tokenizer_config_path: Option<String>,
#[clap(long, env)]
revision: Option<String>,
#[clap(long, env, value_enum)]
trust_remote_code: bool,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
@ -99,6 +101,7 @@ async fn main() -> Result<(), RouterError> {
tokenizer_name,
tokenizer_config_path,
revision,
trust_remote_code,
validation_workers,
api_key,
json_output,
@ -181,6 +184,7 @@ async fn main() -> Result<(), RouterError> {
tokenizer_name,
tokenizer_config_path,
revision,
trust_remote_code,
hostname,
port,
cors_allow_origin,

View File

@ -44,6 +44,8 @@ struct Args {
tokenizer_config_path: Option<String>,
#[clap(long, env)]
revision: Option<String>,
#[clap(long, env, value_enum)]
trust_remote_code: bool,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
@ -99,6 +101,7 @@ async fn main() -> Result<(), RouterError> {
tokenizer_name,
tokenizer_config_path,
revision,
trust_remote_code,
validation_workers,
api_key,
json_output,
@ -181,6 +184,7 @@ async fn main() -> Result<(), RouterError> {
tokenizer_name,
tokenizer_config_path,
revision,
trust_remote_code,
hostname,
port,
cors_allow_origin,

View File

@ -1509,6 +1509,10 @@ fn spawn_webserver(
router_args.push(revision.to_string())
}
if args.trust_remote_code {
router_args.push("--trust-remote-code".to_string());
}
if args.json_output {
router_args.push("--json-output".to_string());
}

View File

@ -1609,6 +1609,7 @@ pub async fn run(
tokenizer_name: String,
tokenizer_config_path: Option<String>,
revision: Option<String>,
trust_remote_code: bool,
hostname: String,
port: u16,
cors_allow_origin: Option<Vec<String>>,
@ -1768,10 +1769,13 @@ pub async fn run(
let auto = transformers.getattr("AutoTokenizer")?;
let from_pretrained = auto.getattr("from_pretrained")?;
let args = (tokenizer_name.to_string(),);
let kwargs = [(
"revision",
revision.clone().unwrap_or_else(|| "main".to_string()),
)]
let kwargs = [
(
"revision",
(revision.clone().unwrap_or_else(|| "main".to_string())).into_py(py),
),
("trust_remote_code", trust_remote_code.into_py(py)),
]
.into_py_dict_bound(py);
let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
let save = tokenizer.getattr("save_pretrained")?;