fix(router): add auth token to get model info (#207)

This commit is contained in:
OlivierDehaene 2023-04-19 20:06:06 +02:00 committed by GitHub
parent 6837b2eb77
commit 252f42c1e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 9 deletions

View File

@ -414,6 +414,14 @@ fn main() -> ExitCode {
argv.push(origin); argv.push(origin);
} }
// Copy current process env
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
// Parse Inference API token
if let Ok(api_token) = env::var("HF_API_TOKEN") {
env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
};
let mut webserver = match Popen::create( let mut webserver = match Popen::create(
&argv, &argv,
PopenConfig { PopenConfig {
@ -421,6 +429,7 @@ fn main() -> ExitCode {
stderr: Redirection::Pipe, stderr: Redirection::Pipe,
// Needed for the shutdown procedure // Needed for the shutdown procedure
setpgid: true, setpgid: true,
env: Some(env),
..Default::default() ..Default::default()
}, },
) { ) {

View File

@ -90,6 +90,9 @@ fn main() -> Result<(), std::io::Error> {
) )
}); });
// Parse Huggingface hub token
let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();
// Tokenizer instance // Tokenizer instance
// This will only be used to validate payloads // This will only be used to validate payloads
let local_path = Path::new(&tokenizer_name); let local_path = Path::new(&tokenizer_name);
@ -102,6 +105,7 @@ fn main() -> Result<(), std::io::Error> {
// We need to download it outside of the Tokio runtime // We need to download it outside of the Tokio runtime
let params = FromPretrainedParameters { let params = FromPretrainedParameters {
revision: revision.clone(), revision: revision.clone(),
auth_token: authorization_token.clone(),
..Default::default() ..Default::default()
}; };
Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok() Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
@ -129,7 +133,7 @@ fn main() -> Result<(), std::io::Error> {
sha: None, sha: None,
pipeline_tag: None, pipeline_tag: None,
}, },
false => get_model_info(&tokenizer_name, &revision).await, false => get_model_info(&tokenizer_name, &revision, authorization_token).await,
}; };
// if pipeline-tag == text-generation we default to return_full_text = true // if pipeline-tag == text-generation we default to return_full_text = true
@ -233,14 +237,21 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
} }
/// get model info from the Huggingface Hub /// get model info from the Huggingface Hub
pub async fn get_model_info(model_id: &str, revision: &str) -> ModelInfo { pub async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> ModelInfo {
let model_info = reqwest::get(format!( let client = reqwest::Client::new();
let mut builder = client.get(format!(
"https://huggingface.co/api/models/{model_id}/revision/{revision}" "https://huggingface.co/api/models/{model_id}/revision/{revision}"
)) ));
.await if let Some(token) = token {
.expect("Could not connect to hf.co") builder = builder.bearer_auth(token);
.text() }
.await
.expect("error when retrieving model info from hf.co"); let model_info = builder
.send()
.await
.expect("Could not connect to hf.co")
.text()
.await
.expect("error when retrieving model info from hf.co");
serde_json::from_str(&model_info).expect("unable to parse model info") serde_json::from_str(&model_info).expect("unable to parse model info")
} }