fix: add merge-lora arg for model id (#2788)

This commit is contained in:
drbh 2024-12-01 23:52:02 -05:00 committed by GitHub
parent a35d1e6fe5
commit 2c74c55637
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 7 additions and 0 deletions

View File

@ -1193,6 +1193,7 @@ fn download_convert_model(
huggingface_hub_cache: Option<&str>,
weights_cache_override: Option<&str>,
running: Arc<AtomicBool>,
merge_lora: bool,
) -> Result<(), LauncherError> {
// Enter download tracing span
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
@ -1207,6 +1208,10 @@ fn download_convert_model(
"--json-output".to_string(),
];
if merge_lora {
download_args.push("--merge-lora".to_string());
}
// Model optional revision
if let Some(revision) = &revision {
download_args.push("--revision".to_string());
@ -1842,6 +1847,7 @@ fn main() -> Result<(), LauncherError> {
args.huggingface_hub_cache.as_deref(),
args.weights_cache_override.as_deref(),
running.clone(),
true, // if its only a lora model - we should merge the lora adapters
)?;
// Download and convert lora adapters if any
@ -1875,6 +1881,7 @@ fn main() -> Result<(), LauncherError> {
args.huggingface_hub_cache.as_deref(),
args.weights_cache_override.as_deref(),
running.clone(),
false, // avoid merging lora adapters if using multi-lora
)?;
} else {
return Err(LauncherError::ArgumentValidation(format!(