CI (2592): Allow LoRA adapter revision in server launcher (#2602)
allow revision for lora adapters from launcher Co-authored-by: Sida <sida@kulamind.com> Co-authored-by: teamclouday <teamclouday@gmail.com>
This commit is contained in:
parent
0204946d26
commit
2335459556
|
@ -4247,6 +4247,7 @@ dependencies = [
|
|||
"nix 0.28.0",
|
||||
"once_cell",
|
||||
"pyo3",
|
||||
"regex",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
|
|
@ -36,6 +36,12 @@ To use LoRA in TGI, when starting the server, you can specify the list of LoRA m
|
|||
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
|
||||
```
|
||||
|
||||
To specify model revision, use `adapter_id@revision`, as follows:
|
||||
|
||||
```bash
|
||||
LORA_ADAPTERS=predibase/customer_support@main,predibase/dbpedia@rev2
|
||||
```
|
||||
|
||||
To use a locally stored lora adapter, use `adapter-name=/path/to/adapter`, as seen below. When you want to use this adapter, set `"parameters": {"adapter_id": "adapter-name"}"`
|
||||
|
||||
```bash
|
||||
|
|
|
@ -18,6 +18,7 @@ serde_json = "1.0.107"
|
|||
thiserror = "1.0.59"
|
||||
tracing = "0.1.37"
|
||||
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||
regex = "1.11.0"
|
||||
|
||||
[dev-dependencies]
|
||||
float_eq = "1.0.1"
|
||||
|
|
|
@ -5,6 +5,7 @@ use hf_hub::{
|
|||
};
|
||||
use nix::sys::signal::{self, Signal};
|
||||
use nix::unistd::Pid;
|
||||
use regex::Regex;
|
||||
use serde::Deserialize;
|
||||
use std::env;
|
||||
use std::ffi::OsString;
|
||||
|
@ -1808,14 +1809,37 @@ fn main() -> Result<(), LauncherError> {
|
|||
if adapter.contains('=') {
|
||||
continue;
|
||||
}
|
||||
download_convert_model(
|
||||
adapter,
|
||||
None,
|
||||
args.trust_remote_code,
|
||||
args.huggingface_hub_cache.as_deref(),
|
||||
args.weights_cache_override.as_deref(),
|
||||
running.clone(),
|
||||
)?;
|
||||
|
||||
let adapter = adapter.trim();
|
||||
|
||||
// check if adapter has more than 1 '@'
|
||||
if adapter.matches('@').count() > 1 {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"Invalid LoRA adapter format: {}",
|
||||
adapter
|
||||
)));
|
||||
}
|
||||
|
||||
// capture adapter_id, path, revision in format of adapter_id=path@revision
|
||||
let re = Regex::new(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$").unwrap();
|
||||
if let Some(caps) = re.captures(adapter) {
|
||||
let adapter_id = caps.get(1).map_or("", |m| m.as_str());
|
||||
let revision = caps.get(3).map(|m| m.as_str());
|
||||
|
||||
download_convert_model(
|
||||
adapter_id,
|
||||
revision,
|
||||
args.trust_remote_code,
|
||||
args.huggingface_hub_cache.as_deref(),
|
||||
args.weights_cache_override.as_deref(),
|
||||
running.clone(),
|
||||
)?;
|
||||
} else {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"Invalid LoRA adapter format: {}",
|
||||
adapter
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue