Pass the Hugging Face Hub token through to the spawned webserver and use it
for authenticated Hub requests.

launcher/src/main.rs: copy the current process environment for the spawned
webserver process and, when HF_API_TOKEN is set, additionally expose its value
as HUGGING_FACE_HUB_TOKEN.

router/src/main.rs: read HUGGING_FACE_HUB_TOKEN and use it both as the
`auth_token` for Tokenizer::from_pretrained and as a bearer token for the
model-info request in get_model_info.

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 8dc6a798..fcac736d 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -414,6 +414,14 @@ fn main() -> ExitCode {
         argv.push(origin);
     }
 
+    // Copy current process env
+    let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
+
+    // Parse Inference API token
+    if let Ok(api_token) = env::var("HF_API_TOKEN") {
+        env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
+    };
+
     let mut webserver = match Popen::create(
         &argv,
         PopenConfig {
@@ -421,6 +429,7 @@ fn main() -> ExitCode {
             stderr: Redirection::Pipe,
             // Needed for the shutdown procedure
             setpgid: true,
+            env: Some(env),
             ..Default::default()
         },
     ) {
diff --git a/router/src/main.rs b/router/src/main.rs
index 5fda57be..31783bbe 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -90,6 +90,9 @@ fn main() -> Result<(), std::io::Error> {
         )
     });
 
+    // Parse Huggingface hub token
+    let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();
+
     // Tokenizer instance
     // This will only be used to validate payloads
     let local_path = Path::new(&tokenizer_name);
@@ -102,6 +105,7 @@ fn main() -> Result<(), std::io::Error> {
         // We need to download it outside of the Tokio runtime
         let params = FromPretrainedParameters {
             revision: revision.clone(),
+            auth_token: authorization_token.clone(),
             ..Default::default()
         };
         Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
@@ -129,7 +133,7 @@ fn main() -> Result<(), std::io::Error> {
                     sha: None,
                     pipeline_tag: None,
                 },
-                false => get_model_info(&tokenizer_name, &revision).await,
+                false => get_model_info(&tokenizer_name, &revision, authorization_token).await,
             };
 
             // if pipeline-tag == text-generation we default to return_full_text = true
@@ -233,14 +237,22 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
 }
 
 /// get model info from the Huggingface Hub
-pub async fn get_model_info(model_id: &str, revision: &str) -> ModelInfo {
-    let model_info = reqwest::get(format!(
+pub async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> ModelInfo {
+    let client = reqwest::Client::new();
+    let mut builder = client.get(format!(
         "https://huggingface.co/api/models/{model_id}/revision/{revision}"
-    ))
-    .await
-    .expect("Could not connect to hf.co")
-    .text()
-    .await
-    .expect("error when retrieving model info from hf.co");
+    ));
+    // Only send a bearer token when one was provided; otherwise the request stays anonymous
+    if let Some(token) = token {
+        builder = builder.bearer_auth(token);
+    }
+
+    let model_info = builder
+        .send()
+        .await
+        .expect("Could not connect to hf.co")
+        .text()
+        .await
+        .expect("error when retrieving model info from hf.co");
     serde_json::from_str(&model_info).expect("unable to parse model info")
 }