feat: download lora adapter weights from launcher (#2140)
This commit is contained in:
parent
25f57e2e98
commit
0d97a93c1e
|
@ -898,13 +898,20 @@ enum LauncherError {
|
||||||
WebserverCannotStart,
|
WebserverCannotStart,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
|
fn download_convert_model(
|
||||||
|
model_id: &str,
|
||||||
|
revision: Option<&str>,
|
||||||
|
trust_remote_code: bool,
|
||||||
|
huggingface_hub_cache: Option<&str>,
|
||||||
|
weights_cache_override: Option<&str>,
|
||||||
|
running: Arc<AtomicBool>,
|
||||||
|
) -> Result<(), LauncherError> {
|
||||||
// Enter download tracing span
|
// Enter download tracing span
|
||||||
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
|
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
|
||||||
|
|
||||||
let mut download_args = vec![
|
let mut download_args = vec![
|
||||||
"download-weights".to_string(),
|
"download-weights".to_string(),
|
||||||
args.model_id.to_string(),
|
model_id.to_string(),
|
||||||
"--extension".to_string(),
|
"--extension".to_string(),
|
||||||
".safetensors".to_string(),
|
".safetensors".to_string(),
|
||||||
"--logger-level".to_string(),
|
"--logger-level".to_string(),
|
||||||
|
@ -913,13 +920,13 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
||||||
];
|
];
|
||||||
|
|
||||||
// Model optional revision
|
// Model optional revision
|
||||||
if let Some(revision) = &args.revision {
|
if let Some(revision) = &revision {
|
||||||
download_args.push("--revision".to_string());
|
download_args.push("--revision".to_string());
|
||||||
download_args.push(revision.to_string())
|
download_args.push(revision.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Trust remote code for automatic peft fusion
|
// Trust remote code for automatic peft fusion
|
||||||
if args.trust_remote_code {
|
if trust_remote_code {
|
||||||
download_args.push("--trust-remote-code".to_string());
|
download_args.push("--trust-remote-code".to_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -934,7 +941,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
||||||
|
|
||||||
// If huggingface_hub_cache is set, pass it to the download process
|
// If huggingface_hub_cache is set, pass it to the download process
|
||||||
// Useful when running inside a docker container
|
// Useful when running inside a docker container
|
||||||
if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
|
if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
|
||||||
envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
|
envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -952,7 +959,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
||||||
|
|
||||||
// If args.weights_cache_override is some, pass it to the download process
|
// If weights_cache_override is some, pass it to the download process
|
||||||
// Useful when running inside a HuggingFace Inference Endpoint
|
// Useful when running inside a HuggingFace Inference Endpoint
|
||||||
if let Some(weights_cache_override) = &args.weights_cache_override {
|
if let Some(weights_cache_override) = &weights_cache_override {
|
||||||
envs.push((
|
envs.push((
|
||||||
"WEIGHTS_CACHE_OVERRIDE".into(),
|
"WEIGHTS_CACHE_OVERRIDE".into(),
|
||||||
weights_cache_override.into(),
|
weights_cache_override.into(),
|
||||||
|
@ -960,7 +967,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
||||||
};
|
};
|
||||||
|
|
||||||
// Start process
|
// Start process
|
||||||
tracing::info!("Starting download process.");
|
tracing::info!("Starting check and download process for {model_id}");
|
||||||
let mut download_process = match Command::new("text-generation-server")
|
let mut download_process = match Command::new("text-generation-server")
|
||||||
.args(download_args)
|
.args(download_args)
|
||||||
.env_clear()
|
.env_clear()
|
||||||
|
@ -1002,7 +1009,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
||||||
loop {
|
loop {
|
||||||
if let Some(status) = download_process.try_wait().unwrap() {
|
if let Some(status) = download_process.try_wait().unwrap() {
|
||||||
if status.success() {
|
if status.success() {
|
||||||
tracing::info!("Successfully downloaded weights.");
|
tracing::info!("Successfully downloaded weights for {model_id}");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1557,7 +1564,28 @@ fn main() -> Result<(), LauncherError> {
|
||||||
.expect("Error setting Ctrl-C handler");
|
.expect("Error setting Ctrl-C handler");
|
||||||
|
|
||||||
// Download and convert model weights
|
// Download and convert model weights
|
||||||
download_convert_model(&args, running.clone())?;
|
download_convert_model(
|
||||||
|
&args.model_id,
|
||||||
|
args.revision.as_deref(),
|
||||||
|
args.trust_remote_code,
|
||||||
|
args.huggingface_hub_cache.as_deref(),
|
||||||
|
args.weights_cache_override.as_deref(),
|
||||||
|
running.clone(),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// Download and convert lora adapters if any
|
||||||
|
if let Some(lora_adapters) = &args.lora_adapters {
|
||||||
|
for adapter in lora_adapters.split(',') {
|
||||||
|
download_convert_model(
|
||||||
|
adapter,
|
||||||
|
None,
|
||||||
|
args.trust_remote_code,
|
||||||
|
args.huggingface_hub_cache.as_deref(),
|
||||||
|
args.weights_cache_override.as_deref(),
|
||||||
|
running.clone(),
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !running.load(Ordering::SeqCst) {
|
if !running.load(Ordering::SeqCst) {
|
||||||
// Launcher was asked to stop
|
// Launcher was asked to stop
|
||||||
|
|
Loading…
Reference in New Issue