feat(router): explicit warning if revision is not set (#608)

This commit is contained in:
OlivierDehaene 2023-07-13 18:59:38 +02:00 committed by GitHub
parent b7327205a6
commit 982ce3227b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 17 deletions

View File

@ -760,16 +760,6 @@ fn spawn_shards(
status_sender: mpsc::Sender<ShardStatus>, status_sender: mpsc::Sender<ShardStatus>,
running: Arc<AtomicBool>, running: Arc<AtomicBool>,
) -> Result<(), LauncherError> { ) -> Result<(), LauncherError> {
if args.trust_remote_code {
tracing::warn!(
"`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.",
args.model_id
);
if args.revision.is_none() {
tracing::warn!("Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.");
}
}
// Start shard processes // Start shard processes
for rank in 0..num_shard { for rank in 0..num_shard {
let model_id = args.model_id.clone(); let model_id = args.model_id.clone();
@ -1025,6 +1015,12 @@ fn main() -> Result<(), LauncherError> {
"`validation_workers` must be > 0".to_string(), "`validation_workers` must be > 0".to_string(),
)); ));
} }
if args.trust_remote_code {
tracing::warn!(
"`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.",
args.model_id
);
}
let num_shard = find_num_shards(args.sharded, args.num_shard)?; let num_shard = find_num_shards(args.sharded, args.num_shard)?;
if num_shard > 1 { if num_shard > 1 {

View File

@ -49,8 +49,8 @@ struct Args {
master_shard_uds_path: String, master_shard_uds_path: String,
#[clap(default_value = "bigscience/bloom", long, env)] #[clap(default_value = "bigscience/bloom", long, env)]
tokenizer_name: String, tokenizer_name: String,
#[clap(default_value = "main", long, env)] #[clap(long, env)]
revision: String, revision: Option<String>,
#[clap(default_value = "2", long, env)] #[clap(default_value = "2", long, env)]
validation_workers: usize, validation_workers: usize,
#[clap(long, env)] #[clap(long, env)]
@ -147,7 +147,7 @@ fn main() -> Result<(), RouterError> {
// Download and instantiate tokenizer // Download and instantiate tokenizer
// We need to download it outside of the Tokio runtime // We need to download it outside of the Tokio runtime
let params = FromPretrainedParameters { let params = FromPretrainedParameters {
revision: revision.clone(), revision: revision.clone().unwrap_or("main".to_string()),
auth_token: authorization_token.clone(), auth_token: authorization_token.clone(),
..Default::default() ..Default::default()
}; };
@ -175,7 +175,7 @@ fn main() -> Result<(), RouterError> {
sha: None, sha: None,
pipeline_tag: None, pipeline_tag: None,
}, },
false => get_model_info(&tokenizer_name, &revision, authorization_token) false => get_model_info(&tokenizer_name, revision, authorization_token)
.await .await
.unwrap_or_else(|| { .unwrap_or_else(|| {
tracing::warn!("Could not retrieve model info from the Hugging Face hub."); tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
@ -316,9 +316,18 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
/// get model info from the Huggingface Hub /// get model info from the Huggingface Hub
pub async fn get_model_info( pub async fn get_model_info(
model_id: &str, model_id: &str,
revision: &str, revision: Option<String>,
token: Option<String>, token: Option<String>,
) -> Option<HubModelInfo> { ) -> Option<HubModelInfo> {
let revision = match revision {
None => {
tracing::warn!("`--revision` is not set");
tracing::warn!("We strongly advise to set it to a known supported commit.");
"main".to_string()
}
Some(revision) => revision,
};
let client = reqwest::Client::new(); let client = reqwest::Client::new();
// Poor man's urlencode // Poor man's urlencode
let revision = revision.replace('/', "%2F"); let revision = revision.replace('/', "%2F");
@ -331,9 +340,18 @@ pub async fn get_model_info(
let response = builder.send().await.ok()?; let response = builder.send().await.ok()?;
if response.status().is_success() { if response.status().is_success() {
return serde_json::from_str(&response.text().await.ok()?).ok(); let hub_model_info: HubModelInfo =
serde_json::from_str(&response.text().await.ok()?).ok()?;
if let Some(sha) = &hub_model_info.sha {
tracing::info!(
"Serving revision {sha} of model {}",
hub_model_info.model_id
);
}
Some(hub_model_info)
} else {
None
} }
None
} }
#[derive(Debug, Error)] #[derive(Debug, Error)]