diff --git a/Cargo.lock b/Cargo.lock index fa6afc29..27499cd4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4247,6 +4247,7 @@ dependencies = [ "nix 0.28.0", "once_cell", "pyo3", + "regex", "reqwest", "serde", "serde_json", diff --git a/docs/source/conceptual/lora.md b/docs/source/conceptual/lora.md index 0b7e3616..d1f4ce78 100644 --- a/docs/source/conceptual/lora.md +++ b/docs/source/conceptual/lora.md @@ -36,6 +36,12 @@ To use LoRA in TGI, when starting the server, you can specify the list of LoRA m LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia ``` +To specify model revision, use `adapter_id@revision`, as follows: + +```bash +LORA_ADAPTERS=predibase/customer_support@main,predibase/dbpedia@rev2 +``` + To use a locally stored lora adapter, use `adapter-name=/path/to/adapter`, as seen below. When you want to use this adapter, set `"parameters": {"adapter_id": "adapter-name"}"` ```bash diff --git a/launcher/Cargo.toml b/launcher/Cargo.toml index 033a9a04..fdc3c02c 100644 --- a/launcher/Cargo.toml +++ b/launcher/Cargo.toml @@ -18,6 +18,7 @@ serde_json = "1.0.107" thiserror = "1.0.59" tracing = "0.1.37" tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } +regex = "1.11.0" [dev-dependencies] float_eq = "1.0.1" diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 474a72d3..aba497d6 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -5,6 +5,7 @@ use hf_hub::{ }; use nix::sys::signal::{self, Signal}; use nix::unistd::Pid; +use regex::Regex; use serde::Deserialize; use std::env; use std::ffi::OsString; @@ -1808,14 +1809,37 @@ fn main() -> Result<(), LauncherError> { if adapter.contains('=') { continue; } - download_convert_model( - adapter, - None, - args.trust_remote_code, - args.huggingface_hub_cache.as_deref(), - args.weights_cache_override.as_deref(), - running.clone(), - )?; + + let adapter = adapter.trim(); + + // check if adapter has more than 1 '@' + if adapter.matches('@').count() > 1 { + return Err(LauncherError::ArgumentValidation(format!( + "Invalid LoRA adapter format: {}", + adapter + ))); + } + + // capture adapter_id, path, revision in format of adapter_id=path@revision + let re = Regex::new(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$").unwrap(); + if let Some(caps) = re.captures(adapter) { + let adapter_id = caps.get(1).map_or("", |m| m.as_str()); + let revision = caps.get(3).map(|m| m.as_str()); + + download_convert_model( + adapter_id, + revision, + args.trust_remote_code, + args.huggingface_hub_cache.as_deref(), + args.weights_cache_override.as_deref(), + running.clone(), + )?; + } else { + return Err(LauncherError::ArgumentValidation(format!( + "Invalid LoRA adapter format: {}", + adapter + ))); + } } }