diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 41211d8..9f8d215 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -542,7 +542,7 @@ fn download_convert_model( // Copy current process env let mut env: Vec<(OsString, OsString)> = env::vars_os().collect(); - // If huggingface_hub_cache is set, pass it to the shard + // If huggingface_hub_cache is set, pass it to the download process // Useful when running inside a docker container if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache { env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); @@ -560,6 +560,15 @@ fn download_convert_model( env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) }; + // If args.weights_cache_override is some, pass it to the download process + // Useful when running inside a HuggingFace Inference Endpoint + if let Some(weights_cache_override) = &args.weights_cache_override { + env.push(( + "WEIGHTS_CACHE_OVERRIDE".into(), + weights_cache_override.into(), + )); + }; + // Start process tracing::info!("Starting download process."); let mut download_process = match Popen::create(