diff --git a/backends/v2/src/main.rs b/backends/v2/src/main.rs index bc00666c..ab4b7ce1 100644 --- a/backends/v2/src/main.rs +++ b/backends/v2/src/main.rs @@ -44,6 +44,8 @@ struct Args { tokenizer_config_path: Option, #[clap(long, env)] revision: Option, + #[clap(long, env, value_enum)] + trust_remote_code: bool, #[clap(default_value = "2", long, env)] validation_workers: usize, #[clap(long, env)] @@ -99,6 +101,7 @@ async fn main() -> Result<(), RouterError> { tokenizer_name, tokenizer_config_path, revision, + trust_remote_code, validation_workers, api_key, json_output, @@ -181,6 +184,7 @@ async fn main() -> Result<(), RouterError> { tokenizer_name, tokenizer_config_path, revision, + trust_remote_code, hostname, port, cors_allow_origin, diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs index 769168c0..bc4bdb93 100644 --- a/backends/v3/src/main.rs +++ b/backends/v3/src/main.rs @@ -44,6 +44,8 @@ struct Args { tokenizer_config_path: Option, #[clap(long, env)] revision: Option, + #[clap(long, env, value_enum)] + trust_remote_code: bool, #[clap(default_value = "2", long, env)] validation_workers: usize, #[clap(long, env)] @@ -99,6 +101,7 @@ async fn main() -> Result<(), RouterError> { tokenizer_name, tokenizer_config_path, revision, + trust_remote_code, validation_workers, api_key, json_output, @@ -181,6 +184,7 @@ async fn main() -> Result<(), RouterError> { tokenizer_name, tokenizer_config_path, revision, + trust_remote_code, hostname, port, cors_allow_origin, diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 9ac6ea49..71bbcbd8 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1509,6 +1509,10 @@ fn spawn_webserver( router_args.push(revision.to_string()) } + if args.trust_remote_code { + router_args.push("--trust-remote-code".to_string()); + } + if args.json_output { router_args.push("--json-output".to_string()); } diff --git a/router/src/server.rs b/router/src/server.rs index 5abca058..eb1d2544 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1609,6 +1609,7 @@ pub async fn run( tokenizer_name: String, tokenizer_config_path: Option, revision: Option, + trust_remote_code: bool, hostname: String, port: u16, cors_allow_origin: Option>, @@ -1768,10 +1769,13 @@ pub async fn run( let auto = transformers.getattr("AutoTokenizer")?; let from_pretrained = auto.getattr("from_pretrained")?; let args = (tokenizer_name.to_string(),); - let kwargs = [( - "revision", - revision.clone().unwrap_or_else(|| "main".to_string()), - )] + let kwargs = [ + ( + "revision", + (revision.clone().unwrap_or_else(|| "main".to_string())).into_py(py), + ), + ("trust_remote_code", trust_remote_code.into_py(py)), + ] .into_py_dict_bound(py); let tokenizer = from_pretrained.call(args, Some(&kwargs))?; let save = tokenizer.getattr("save_pretrained")?;