CI (2592): Allow LoRA adapter revision in server launcher (#2602)

allow revision for lora adapters from launcher

Co-authored-by: Sida <sida@kulamind.com>
Co-authored-by: teamclouday <teamclouday@gmail.com>
This commit is contained in:
drbh 2024-10-02 10:51:04 -04:00 committed by GitHub
parent 0204946d26
commit 2335459556
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 40 additions and 8 deletions

1
Cargo.lock generated
View File

@ -4247,6 +4247,7 @@ dependencies = [
"nix 0.28.0",
"once_cell",
"pyo3",
"regex",
"reqwest",
"serde",
"serde_json",

View File

@ -36,6 +36,12 @@ To use LoRA in TGI, when starting the server, you can specify the list of LoRA m
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
```
To specify a model revision for an adapter, use `adapter_id@revision`, as follows:
```bash
LORA_ADAPTERS=predibase/customer_support@main,predibase/dbpedia@rev2
```
To use a locally stored LoRA adapter, use `adapter-name=/path/to/adapter`, as seen below. When you want to use this adapter, set `"parameters": {"adapter_id": "adapter-name"}`
```bash

View File

@ -18,6 +18,7 @@ serde_json = "1.0.107"
thiserror = "1.0.59"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
regex = "1.11.0"
[dev-dependencies]
float_eq = "1.0.1"

View File

@ -5,6 +5,7 @@ use hf_hub::{
};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use regex::Regex;
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
@ -1808,14 +1809,37 @@ fn main() -> Result<(), LauncherError> {
if adapter.contains('=') {
continue;
}
let adapter = adapter.trim();
// check if adapter has more than 1 '@'
if adapter.matches('@').count() > 1 {
return Err(LauncherError::ArgumentValidation(format!(
"Invalid LoRA adapter format: {}",
adapter
)));
}
// capture adapter_id, path, revision in format of adapter_id=path@revision
let re = Regex::new(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$").unwrap();
if let Some(caps) = re.captures(adapter) {
let adapter_id = caps.get(1).map_or("", |m| m.as_str());
let revision = caps.get(3).map(|m| m.as_str());
download_convert_model(
adapter,
None,
adapter_id,
revision,
args.trust_remote_code,
args.huggingface_hub_cache.as_deref(),
args.weights_cache_override.as_deref(),
running.clone(),
)?;
} else {
return Err(LauncherError::ArgumentValidation(format!(
"Invalid LoRA adapter format: {}",
adapter
)));
}
}
}