feat(router): explicit warning if revision is not set (#608)
This commit is contained in:
parent
b7327205a6
commit
982ce3227b
|
@ -760,16 +760,6 @@ fn spawn_shards(
|
|||
status_sender: mpsc::Sender<ShardStatus>,
|
||||
running: Arc<AtomicBool>,
|
||||
) -> Result<(), LauncherError> {
|
||||
if args.trust_remote_code {
|
||||
tracing::warn!(
|
||||
"`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.",
|
||||
args.model_id
|
||||
);
|
||||
if args.revision.is_none() {
|
||||
tracing::warn!("Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.");
|
||||
}
|
||||
}
|
||||
|
||||
// Start shard processes
|
||||
for rank in 0..num_shard {
|
||||
let model_id = args.model_id.clone();
|
||||
|
@ -1025,6 +1015,12 @@ fn main() -> Result<(), LauncherError> {
|
|||
"`validation_workers` must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
if args.trust_remote_code {
|
||||
tracing::warn!(
|
||||
"`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.",
|
||||
args.model_id
|
||||
);
|
||||
}
|
||||
|
||||
let num_shard = find_num_shards(args.sharded, args.num_shard)?;
|
||||
if num_shard > 1 {
|
||||
|
|
|
@ -49,8 +49,8 @@ struct Args {
|
|||
master_shard_uds_path: String,
|
||||
#[clap(default_value = "bigscience/bloom", long, env)]
|
||||
tokenizer_name: String,
|
||||
#[clap(default_value = "main", long, env)]
|
||||
revision: String,
|
||||
#[clap(long, env)]
|
||||
revision: Option<String>,
|
||||
#[clap(default_value = "2", long, env)]
|
||||
validation_workers: usize,
|
||||
#[clap(long, env)]
|
||||
|
@ -147,7 +147,7 @@ fn main() -> Result<(), RouterError> {
|
|||
// Download and instantiate tokenizer
|
||||
// We need to download it outside of the Tokio runtime
|
||||
let params = FromPretrainedParameters {
|
||||
revision: revision.clone(),
|
||||
revision: revision.clone().unwrap_or("main".to_string()),
|
||||
auth_token: authorization_token.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
|
@ -175,7 +175,7 @@ fn main() -> Result<(), RouterError> {
|
|||
sha: None,
|
||||
pipeline_tag: None,
|
||||
},
|
||||
false => get_model_info(&tokenizer_name, &revision, authorization_token)
|
||||
false => get_model_info(&tokenizer_name, revision, authorization_token)
|
||||
.await
|
||||
.unwrap_or_else(|| {
|
||||
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
||||
|
@ -316,9 +316,18 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
|
|||
/// Get model info from the Hugging Face Hub
|
||||
pub async fn get_model_info(
|
||||
model_id: &str,
|
||||
revision: &str,
|
||||
revision: Option<String>,
|
||||
token: Option<String>,
|
||||
) -> Option<HubModelInfo> {
|
||||
let revision = match revision {
|
||||
None => {
|
||||
tracing::warn!("`--revision` is not set");
|
||||
tracing::warn!("We strongly advise to set it to a known supported commit.");
|
||||
"main".to_string()
|
||||
}
|
||||
Some(revision) => revision,
|
||||
};
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
// Poor man's urlencode
|
||||
let revision = revision.replace('/', "%2F");
|
||||
|
@ -331,10 +340,19 @@ pub async fn get_model_info(
|
|||
let response = builder.send().await.ok()?;
|
||||
|
||||
if response.status().is_success() {
|
||||
return serde_json::from_str(&response.text().await.ok()?).ok();
|
||||
let hub_model_info: HubModelInfo =
|
||||
serde_json::from_str(&response.text().await.ok()?).ok()?;
|
||||
if let Some(sha) = &hub_model_info.sha {
|
||||
tracing::info!(
|
||||
"Serving revision {sha} of model {}",
|
||||
hub_model_info.model_id
|
||||
);
|
||||
}
|
||||
Some(hub_model_info)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum RouterError {
|
||||
|
|
Loading…
Reference in New Issue