From 982ce3227b780a8029b1c24546058c8f31f61a78 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Thu, 13 Jul 2023 18:59:38 +0200 Subject: [PATCH] feat(router): explicit warning if revision is not set (#608) --- launcher/src/main.rs | 16 ++++++---------- router/src/main.rs | 32 +++++++++++++++++++++++++------- 2 files changed, 31 insertions(+), 17 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 8b34dfe3..54fb1368 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -760,16 +760,6 @@ fn spawn_shards( status_sender: mpsc::Sender, running: Arc, ) -> Result<(), LauncherError> { - if args.trust_remote_code { - tracing::warn!( - "`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.", - args.model_id - ); - if args.revision.is_none() { - tracing::warn!("Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision."); - } - } - // Start shard processes for rank in 0..num_shard { let model_id = args.model_id.clone(); @@ -1025,6 +1015,12 @@ fn main() -> Result<(), LauncherError> { "`validation_workers` must be > 0".to_string(), )); } + if args.trust_remote_code { + tracing::warn!( + "`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.", + args.model_id + ); + } let num_shard = find_num_shards(args.sharded, args.num_shard)?; if num_shard > 1 { diff --git a/router/src/main.rs b/router/src/main.rs index 57ddd5ba..178c249c 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -49,8 +49,8 @@ struct Args { master_shard_uds_path: String, #[clap(default_value = "bigscience/bloom", long, env)] tokenizer_name: String, - #[clap(default_value = "main", long, env)] - revision: String, + #[clap(long, env)] + revision: Option, #[clap(default_value = "2", long, env)] validation_workers: usize, #[clap(long, env)] @@ -147,7 +147,7 @@ fn main() -> Result<(), RouterError> { // Download and instantiate tokenizer // We need to download it outside of the Tokio runtime let params = FromPretrainedParameters { - revision: revision.clone(), + revision: revision.clone().unwrap_or("main".to_string()), auth_token: authorization_token.clone(), ..Default::default() }; @@ -175,7 +175,7 @@ fn main() -> Result<(), RouterError> { sha: None, pipeline_tag: None, }, - false => get_model_info(&tokenizer_name, &revision, authorization_token) + false => get_model_info(&tokenizer_name, revision, authorization_token) .await .unwrap_or_else(|| { tracing::warn!("Could not retrieve model info from the Hugging Face hub."); @@ -316,9 +316,18 @@ fn init_logging(otlp_endpoint: Option, json_output: bool) { /// get model info from the Huggingface Hub pub async fn get_model_info( model_id: &str, - revision: &str, + revision: Option, token: Option, ) -> Option { + let revision = match revision { + None => { + tracing::warn!("`--revision` is not set"); + tracing::warn!("We strongly advise to set it to a known supported commit."); + "main".to_string() + } + Some(revision) => revision, + }; + let client = reqwest::Client::new(); // Poor man's urlencode let revision = revision.replace('/', "%2F"); @@ -331,9 +340,18 @@ pub async fn get_model_info( let response = builder.send().await.ok()?; if response.status().is_success() { - return serde_json::from_str(&response.text().await.ok()?).ok(); + let hub_model_info: HubModelInfo = + serde_json::from_str(&response.text().await.ok()?).ok()?; + if let Some(sha) = &hub_model_info.sha { + tracing::info!( + "Serving revision {sha} of model {}", + hub_model_info.model_id + ); + } + Some(hub_model_info) + } else { + None } - None } #[derive(Debug, Error)]