fix(router): add auth token to get model info (#207)

This commit is contained in:
OlivierDehaene 2023-04-19 20:06:06 +02:00 committed by GitHub
parent 6837b2eb77
commit 252f42c1e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 9 deletions

View File

@ -414,6 +414,14 @@ fn main() -> ExitCode {
argv.push(origin);
}
// Copy current process env
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
// Forward the Inference API token (HF_API_TOKEN) to the webserver as HUGGING_FACE_HUB_TOKEN
if let Ok(api_token) = env::var("HF_API_TOKEN") {
env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
};
let mut webserver = match Popen::create(
&argv,
PopenConfig {
@ -421,6 +429,7 @@ fn main() -> ExitCode {
stderr: Redirection::Pipe,
// Needed for the shutdown procedure
setpgid: true,
env: Some(env),
..Default::default()
},
) {

View File

@ -90,6 +90,9 @@ fn main() -> Result<(), std::io::Error> {
)
});
// Read the Hugging Face Hub token from the environment, if set
let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();
// Tokenizer instance
// This will only be used to validate payloads
let local_path = Path::new(&tokenizer_name);
@ -102,6 +105,7 @@ fn main() -> Result<(), std::io::Error> {
// We need to download it outside of the Tokio runtime
let params = FromPretrainedParameters {
revision: revision.clone(),
auth_token: authorization_token.clone(),
..Default::default()
};
Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
@ -129,7 +133,7 @@ fn main() -> Result<(), std::io::Error> {
sha: None,
pipeline_tag: None,
},
false => get_model_info(&tokenizer_name, &revision).await,
false => get_model_info(&tokenizer_name, &revision, authorization_token).await,
};
// if pipeline-tag == text-generation we default to return_full_text = true
@ -233,14 +237,21 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
}
/// Get model info from the Hugging Face Hub
pub async fn get_model_info(model_id: &str, revision: &str) -> ModelInfo {
let model_info = reqwest::get(format!(
pub async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> ModelInfo {
let client = reqwest::Client::new();
let mut builder = client.get(format!(
"https://huggingface.co/api/models/{model_id}/revision/{revision}"
))
.await
.expect("Could not connect to hf.co")
.text()
.await
.expect("error when retrieving model info from hf.co");
));
if let Some(token) = token {
builder = builder.bearer_auth(token);
}
let model_info = builder
.send()
.await
.expect("Could not connect to hf.co")
.text()
.await
.expect("error when retrieving model info from hf.co");
serde_json::from_str(&model_info).expect("unable to parse model info")
}