From 0d97a93c1e14d497e911f42db2da0b9eb032fe75 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 1 Jul 2024 06:58:49 -0400 Subject: [PATCH] feat: download lora adapter weights from launcher (#2140) --- launcher/src/main.rs | 46 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 37 insertions(+), 9 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 816fa5f3..d2ca38e5 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -898,13 +898,20 @@ enum LauncherError { WebserverCannotStart, } -fn download_convert_model(args: &Args, running: Arc) -> Result<(), LauncherError> { +fn download_convert_model( + model_id: &str, + revision: Option<&str>, + trust_remote_code: bool, + huggingface_hub_cache: Option<&str>, + weights_cache_override: Option<&str>, + running: Arc, +) -> Result<(), LauncherError> { // Enter download tracing span let _span = tracing::span!(tracing::Level::INFO, "download").entered(); let mut download_args = vec![ "download-weights".to_string(), - args.model_id.to_string(), + model_id.to_string(), "--extension".to_string(), ".safetensors".to_string(), "--logger-level".to_string(), @@ -913,13 +920,13 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L ]; // Model optional revision - if let Some(revision) = &args.revision { + if let Some(revision) = &revision { download_args.push("--revision".to_string()); download_args.push(revision.to_string()) } // Trust remote code for automatic peft fusion - if args.trust_remote_code { + if trust_remote_code { download_args.push("--trust-remote-code".to_string()); } @@ -934,7 +941,7 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L // If huggingface_hub_cache is set, pass it to the download process // Useful when running inside a docker container - if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache { + if let Some(ref huggingface_hub_cache) = huggingface_hub_cache { envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); }; @@ -952,7 +959,7 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L // If args.weights_cache_override is some, pass it to the download process // Useful when running inside a HuggingFace Inference Endpoint - if let Some(weights_cache_override) = &args.weights_cache_override { + if let Some(weights_cache_override) = &weights_cache_override { envs.push(( "WEIGHTS_CACHE_OVERRIDE".into(), weights_cache_override.into(), @@ -960,7 +967,7 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L }; // Start process - tracing::info!("Starting download process."); + tracing::info!("Starting check and download process for {model_id}"); let mut download_process = match Command::new("text-generation-server") .args(download_args) .env_clear() @@ -1002,7 +1009,7 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L loop { if let Some(status) = download_process.try_wait().unwrap() { if status.success() { - tracing::info!("Successfully downloaded weights."); + tracing::info!("Successfully downloaded weights for {model_id}"); break; } @@ -1557,7 +1564,28 @@ fn main() -> Result<(), LauncherError> { .expect("Error setting Ctrl-C handler"); // Download and convert model weights - download_convert_model(&args, running.clone())?; + download_convert_model( + &args.model_id, + args.revision.as_deref(), + args.trust_remote_code, + args.huggingface_hub_cache.as_deref(), + args.weights_cache_override.as_deref(), + running.clone(), + )?; + + // Download and convert lora adapters if any + if let Some(lora_adapters) = &args.lora_adapters { + for adapter in lora_adapters.split(',') { + download_convert_model( + adapter, + None, + args.trust_remote_code, + args.huggingface_hub_cache.as_deref(), + args.weights_cache_override.as_deref(), + running.clone(), + )?; + } + } if !running.load(Ordering::SeqCst) { // Launcher was asked to stop