CI (2592): Allow LoRA adapter revision in server launcher (#2602)
allow revision for lora adapters from launcher Co-authored-by: Sida <sida@kulamind.com> Co-authored-by: teamclouday <teamclouday@gmail.com>
This commit is contained in:
parent
0204946d26
commit
2335459556
|
@ -4247,6 +4247,7 @@ dependencies = [
|
||||||
"nix 0.28.0",
|
"nix 0.28.0",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"pyo3",
|
"pyo3",
|
||||||
|
"regex",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
|
|
@ -36,6 +36,12 @@ To use LoRA in TGI, when starting the server, you can specify the list of LoRA m
|
||||||
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
|
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
|
||||||
```
|
```
|
||||||
|
|
||||||
|
To specify model revision, use `adapter_id@revision`, as follows:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
LORA_ADAPTERS=predibase/customer_support@main,predibase/dbpedia@rev2
|
||||||
|
```
|
||||||
|
|
||||||
To use a locally stored lora adapter, use `adapter-name=/path/to/adapter`, as seen below. When you want to use this adapter, set `"parameters": {"adapter_id": "adapter-name"}"`
|
To use a locally stored lora adapter, use `adapter-name=/path/to/adapter`, as seen below. When you want to use this adapter, set `"parameters": {"adapter_id": "adapter-name"}"`
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|
|
@ -18,6 +18,7 @@ serde_json = "1.0.107"
|
||||||
thiserror = "1.0.59"
|
thiserror = "1.0.59"
|
||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||||
|
regex = "1.11.0"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
float_eq = "1.0.1"
|
float_eq = "1.0.1"
|
||||||
|
|
|
@ -5,6 +5,7 @@ use hf_hub::{
|
||||||
};
|
};
|
||||||
use nix::sys::signal::{self, Signal};
|
use nix::sys::signal::{self, Signal};
|
||||||
use nix::unistd::Pid;
|
use nix::unistd::Pid;
|
||||||
|
use regex::Regex;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::env;
|
use std::env;
|
||||||
use std::ffi::OsString;
|
use std::ffi::OsString;
|
||||||
|
@ -1808,14 +1809,37 @@ fn main() -> Result<(), LauncherError> {
|
||||||
if adapter.contains('=') {
|
if adapter.contains('=') {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let adapter = adapter.trim();
|
||||||
|
|
||||||
|
// check if adapter has more than 1 '@'
|
||||||
|
if adapter.matches('@').count() > 1 {
|
||||||
|
return Err(LauncherError::ArgumentValidation(format!(
|
||||||
|
"Invalid LoRA adapter format: {}",
|
||||||
|
adapter
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// capture adapter_id, path, revision in format of adapter_id=path@revision
|
||||||
|
let re = Regex::new(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$").unwrap();
|
||||||
|
if let Some(caps) = re.captures(adapter) {
|
||||||
|
let adapter_id = caps.get(1).map_or("", |m| m.as_str());
|
||||||
|
let revision = caps.get(3).map(|m| m.as_str());
|
||||||
|
|
||||||
download_convert_model(
|
download_convert_model(
|
||||||
adapter,
|
adapter_id,
|
||||||
None,
|
revision,
|
||||||
args.trust_remote_code,
|
args.trust_remote_code,
|
||||||
args.huggingface_hub_cache.as_deref(),
|
args.huggingface_hub_cache.as_deref(),
|
||||||
args.weights_cache_override.as_deref(),
|
args.weights_cache_override.as_deref(),
|
||||||
running.clone(),
|
running.clone(),
|
||||||
)?;
|
)?;
|
||||||
|
} else {
|
||||||
|
return Err(LauncherError::ArgumentValidation(format!(
|
||||||
|
"Invalid LoRA adapter format: {}",
|
||||||
|
adapter
|
||||||
|
)));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue