feat(router): explicit warning if revision is not set (#608)
This commit is contained in:
parent
b7327205a6
commit
982ce3227b
|
@ -760,16 +760,6 @@ fn spawn_shards(
|
||||||
status_sender: mpsc::Sender<ShardStatus>,
|
status_sender: mpsc::Sender<ShardStatus>,
|
||||||
running: Arc<AtomicBool>,
|
running: Arc<AtomicBool>,
|
||||||
) -> Result<(), LauncherError> {
|
) -> Result<(), LauncherError> {
|
||||||
if args.trust_remote_code {
|
|
||||||
tracing::warn!(
|
|
||||||
"`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.",
|
|
||||||
args.model_id
|
|
||||||
);
|
|
||||||
if args.revision.is_none() {
|
|
||||||
tracing::warn!("Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start shard processes
|
// Start shard processes
|
||||||
for rank in 0..num_shard {
|
for rank in 0..num_shard {
|
||||||
let model_id = args.model_id.clone();
|
let model_id = args.model_id.clone();
|
||||||
|
@ -1025,6 +1015,12 @@ fn main() -> Result<(), LauncherError> {
|
||||||
"`validation_workers` must be > 0".to_string(),
|
"`validation_workers` must be > 0".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
if args.trust_remote_code {
|
||||||
|
tracing::warn!(
|
||||||
|
"`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.",
|
||||||
|
args.model_id
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
let num_shard = find_num_shards(args.sharded, args.num_shard)?;
|
let num_shard = find_num_shards(args.sharded, args.num_shard)?;
|
||||||
if num_shard > 1 {
|
if num_shard > 1 {
|
||||||
|
|
|
@ -49,8 +49,8 @@ struct Args {
|
||||||
master_shard_uds_path: String,
|
master_shard_uds_path: String,
|
||||||
#[clap(default_value = "bigscience/bloom", long, env)]
|
#[clap(default_value = "bigscience/bloom", long, env)]
|
||||||
tokenizer_name: String,
|
tokenizer_name: String,
|
||||||
#[clap(default_value = "main", long, env)]
|
#[clap(long, env)]
|
||||||
revision: String,
|
revision: Option<String>,
|
||||||
#[clap(default_value = "2", long, env)]
|
#[clap(default_value = "2", long, env)]
|
||||||
validation_workers: usize,
|
validation_workers: usize,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
|
@ -147,7 +147,7 @@ fn main() -> Result<(), RouterError> {
|
||||||
// Download and instantiate tokenizer
|
// Download and instantiate tokenizer
|
||||||
// We need to download it outside of the Tokio runtime
|
// We need to download it outside of the Tokio runtime
|
||||||
let params = FromPretrainedParameters {
|
let params = FromPretrainedParameters {
|
||||||
revision: revision.clone(),
|
revision: revision.clone().unwrap_or("main".to_string()),
|
||||||
auth_token: authorization_token.clone(),
|
auth_token: authorization_token.clone(),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
@ -175,7 +175,7 @@ fn main() -> Result<(), RouterError> {
|
||||||
sha: None,
|
sha: None,
|
||||||
pipeline_tag: None,
|
pipeline_tag: None,
|
||||||
},
|
},
|
||||||
false => get_model_info(&tokenizer_name, &revision, authorization_token)
|
false => get_model_info(&tokenizer_name, revision, authorization_token)
|
||||||
.await
|
.await
|
||||||
.unwrap_or_else(|| {
|
.unwrap_or_else(|| {
|
||||||
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
||||||
|
@ -316,9 +316,18 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
|
||||||
/// get model info from the Huggingface Hub
|
/// get model info from the Huggingface Hub
|
||||||
pub async fn get_model_info(
|
pub async fn get_model_info(
|
||||||
model_id: &str,
|
model_id: &str,
|
||||||
revision: &str,
|
revision: Option<String>,
|
||||||
token: Option<String>,
|
token: Option<String>,
|
||||||
) -> Option<HubModelInfo> {
|
) -> Option<HubModelInfo> {
|
||||||
|
let revision = match revision {
|
||||||
|
None => {
|
||||||
|
tracing::warn!("`--revision` is not set");
|
||||||
|
tracing::warn!("We strongly advise to set it to a known supported commit.");
|
||||||
|
"main".to_string()
|
||||||
|
}
|
||||||
|
Some(revision) => revision,
|
||||||
|
};
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
// Poor man's urlencode
|
// Poor man's urlencode
|
||||||
let revision = revision.replace('/', "%2F");
|
let revision = revision.replace('/', "%2F");
|
||||||
|
@ -331,10 +340,19 @@ pub async fn get_model_info(
|
||||||
let response = builder.send().await.ok()?;
|
let response = builder.send().await.ok()?;
|
||||||
|
|
||||||
if response.status().is_success() {
|
if response.status().is_success() {
|
||||||
return serde_json::from_str(&response.text().await.ok()?).ok();
|
let hub_model_info: HubModelInfo =
|
||||||
|
serde_json::from_str(&response.text().await.ok()?).ok()?;
|
||||||
|
if let Some(sha) = &hub_model_info.sha {
|
||||||
|
tracing::info!(
|
||||||
|
"Serving revision {sha} of model {}",
|
||||||
|
hub_model_info.model_id
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
Some(hub_model_info)
|
||||||
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
enum RouterError {
|
enum RouterError {
|
||||||
|
|
Loading…
Reference in New Issue