fix(router): add auth token to get model info (#207)
This commit is contained in:
parent
6837b2eb77
commit
252f42c1e6
|
@ -414,6 +414,14 @@ fn main() -> ExitCode {
|
|||
argv.push(origin);
|
||||
}
|
||||
|
||||
// Copy current process env
|
||||
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
|
||||
|
||||
// Parse Inference API token
|
||||
if let Ok(api_token) = env::var("HF_API_TOKEN") {
|
||||
env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
|
||||
};
|
||||
|
||||
let mut webserver = match Popen::create(
|
||||
&argv,
|
||||
PopenConfig {
|
||||
|
@ -421,6 +429,7 @@ fn main() -> ExitCode {
|
|||
stderr: Redirection::Pipe,
|
||||
// Needed for the shutdown procedure
|
||||
setpgid: true,
|
||||
env: Some(env),
|
||||
..Default::default()
|
||||
},
|
||||
) {
|
||||
|
|
|
@ -90,6 +90,9 @@ fn main() -> Result<(), std::io::Error> {
|
|||
)
|
||||
});
|
||||
|
||||
// Parse Huggingface hub token
|
||||
let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();
|
||||
|
||||
// Tokenizer instance
|
||||
// This will only be used to validate payloads
|
||||
let local_path = Path::new(&tokenizer_name);
|
||||
|
@ -102,6 +105,7 @@ fn main() -> Result<(), std::io::Error> {
|
|||
// We need to download it outside of the Tokio runtime
|
||||
let params = FromPretrainedParameters {
|
||||
revision: revision.clone(),
|
||||
auth_token: authorization_token.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
|
||||
|
@ -129,7 +133,7 @@ fn main() -> Result<(), std::io::Error> {
|
|||
sha: None,
|
||||
pipeline_tag: None,
|
||||
},
|
||||
false => get_model_info(&tokenizer_name, &revision).await,
|
||||
false => get_model_info(&tokenizer_name, &revision, authorization_token).await,
|
||||
};
|
||||
|
||||
// if pipeline-tag == text-generation we default to return_full_text = true
|
||||
|
@ -233,14 +237,21 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
|
|||
}
|
||||
|
||||
/// get model info from the Huggingface Hub
|
||||
pub async fn get_model_info(model_id: &str, revision: &str) -> ModelInfo {
|
||||
let model_info = reqwest::get(format!(
|
||||
pub async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> ModelInfo {
|
||||
let client = reqwest::Client::new();
|
||||
let mut builder = client.get(format!(
|
||||
"https://huggingface.co/api/models/{model_id}/revision/{revision}"
|
||||
))
|
||||
.await
|
||||
.expect("Could not connect to hf.co")
|
||||
.text()
|
||||
.await
|
||||
.expect("error when retrieving model info from hf.co");
|
||||
));
|
||||
if let Some(token) = token {
|
||||
builder = builder.bearer_auth(token);
|
||||
}
|
||||
|
||||
let model_info = builder
|
||||
.send()
|
||||
.await
|
||||
.expect("Could not connect to hf.co")
|
||||
.text()
|
||||
.await
|
||||
.expect("error when retrieving model info from hf.co");
|
||||
serde_json::from_str(&model_info).expect("unable to parse model info")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue