feat: support lora revisions and qkv_proj weights (#2482)

* feat: support lora revisions and qkv_proj weights

* fix: add qkv_proj weights to weight test
This commit is contained in:
drbh 2024-09-02 13:09:06 -04:00 committed by GitHub
parent 47d7e34458
commit 6cb42f49ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 100 additions and 11 deletions

View File

@ -1,6 +1,54 @@
import pytest
from unittest.mock import Mock
from text_generation_server.utils.adapter import get_attn_weights, get_mlp_weights
from text_generation_server.utils.adapter import (
get_attn_weights,
get_mlp_weights,
parse_lora_adapters,
AdapterInfo,
)
def test_parse_lora_adapters_empty():
assert parse_lora_adapters(None) == []
assert parse_lora_adapters("") == []
def test_parse_lora_adapters_single():
result = parse_lora_adapters("adapter1")
assert result == [AdapterInfo(id="adapter1", path=None, revision=None)]
def test_parse_lora_adapters_with_path():
result = parse_lora_adapters("adapter1=path/to/adapter1")
assert result == [
AdapterInfo(id="adapter1", path="path/to/adapter1", revision=None)
]
def test_parse_lora_adapters_with_path_and_revision():
result = parse_lora_adapters("adapter1=path/to/adapter1@main")
assert result == [
AdapterInfo(id="adapter1", path="path/to/adapter1", revision="main")
]
def test_parse_lora_adapters_multiple():
result = parse_lora_adapters(
"adapter1,adapter2=path/to/adapter2,adapter3=path/to/adapter3@dev"
)
assert result == [
AdapterInfo(id="adapter1", path=None, revision=None),
AdapterInfo(id="adapter2", path="path/to/adapter2", revision=None),
AdapterInfo(id="adapter3", path="path/to/adapter3", revision="dev"),
]
def test_parse_lora_adapters_invalid_format():
try:
parse_lora_adapters("adapter1,invalid=format=test,adapter3")
assert False, "Should have raised ValueError"
except ValueError as e:
assert str(e) == "Invalid LoRA adapter format: invalid=format=test"
def test_get_attn_weights():
@ -22,6 +70,10 @@ def test_get_attn_weights():
"model.layers.2.self_attn.k_proj",
mock_layer.self_attn.query_key_value,
),
(2, "qkv_proj"): (
"model.layers.2.self_attn.qkv_proj",
mock_layer.self_attn.query_key_value,
),
(2, "v_proj"): (
"model.layers.2.self_attn.v_proj",
mock_layer.self_attn.query_key_value,
@ -115,6 +167,10 @@ def test_get_attn_weights_llama_compatibility():
"model.layers.2.self_attn.k_proj",
mock_layer.self_attn.query_key_value,
),
(2, "qkv_proj"): (
"model.layers.2.self_attn.qkv_proj",
mock_layer.self_attn.query_key_value,
),
(2, "v_proj"): (
"model.layers.2.self_attn.v_proj",
mock_layer.self_attn.query_key_value,
@ -155,6 +211,10 @@ def test_get_attn_weights_gemma_compatibility():
"model.layers.2.self_attn.k_proj",
mock_layer.self_attn.query_key_value,
),
(2, "qkv_proj"): (
"model.layers.2.self_attn.qkv_proj",
mock_layer.self_attn.query_key_value,
),
(2, "v_proj"): (
"model.layers.2.self_attn.v_proj",
mock_layer.self_attn.query_key_value,

View File

@ -1259,6 +1259,7 @@ def get_model_with_lora_adapters(
"gate_proj",
"up_proj",
"down_proj",
"qkv_proj",
]
for layer_name in adapter_layers:
@ -1286,7 +1287,7 @@ def get_model_with_lora_adapters(
if len(unused_weight_names) > 0:
logger.warning(
f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
f"{','.join([a.id for a in lora_adapters])} unused adapter weights: {unused_weight_names}"
)
if adapter_tokenizer is not None:

View File

@ -66,15 +66,15 @@ def load_attention(config, prefix: str, weights, layer_id):
prefixes = None
if config.model_type == "phi3":
prefix = f"{prefix}.qkv_proj"
base_layer = TensorParallelColumnLinear.load_qkv(
config,
prefix=prefix,
prefix=f"{prefix}.qkv_proj",
weights=weights,
bias=bias,
num_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
)
prefixes = ["qkv_proj"]
elif config.model_type == "baichuan":
prefix = f"{prefix}.W_pack"
base_layer = TensorParallelColumnLinear.load_qkv(
@ -85,6 +85,7 @@ def load_attention(config, prefix: str, weights, layer_id):
num_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
)
prefixes = [prefix]
else:
prefixes = ["q_proj", "k_proj", "v_proj"]
sizes = [

View File

@ -3,6 +3,7 @@
# License: Apache License Version 2.0, January 2004
import warnings
import re
from dataclasses import dataclass
from functools import lru_cache
from typing import TYPE_CHECKING, Set, Tuple, Optional, List
@ -27,6 +28,7 @@ BASE_MODEL_ADAPTER_ID = "__base_model__"
class AdapterInfo:
id: str
path: Optional[str]
revision: Optional[str] = None
@dataclass
@ -51,11 +53,16 @@ def parse_lora_adapters(lora_adapters: Optional[str]) -> List[AdapterInfo]:
adapter_list = []
for adapter in lora_adapters.split(","):
parts = adapter.strip().split("=")
if len(parts) == 1:
adapter_list.append(AdapterInfo(id=parts[0], path=None))
elif len(parts) == 2:
adapter_list.append(AdapterInfo(id=parts[0], path=parts[1]))
adapter = adapter.strip()
if adapter.count("=") > 1 or adapter.count("@") > 1:
raise ValueError(f"Invalid LoRA adapter format: {adapter}")
match = re.match(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$", adapter)
if match:
adapter_id, path, revision = match.groups()
adapter_list.append(
AdapterInfo(id=adapter_id, path=path, revision=revision)
)
else:
raise ValueError(f"Invalid LoRA adapter format: {adapter}")
return adapter_list
@ -73,6 +80,7 @@ def load_and_merge_adapters(
adapter_info = next(iter(adapter_parameters.adapter_info))
return load_module_map(
model_id,
adapter_info.revision,
adapter_info.id,
adapter_info.path,
weight_names,
@ -80,7 +88,13 @@ def load_and_merge_adapters(
)
adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index)
return _load_and_merge(model_id, adapter_params, weight_names, trust_remote_code)
return _load_and_merge(
model_id,
adapter_params.revision,
adapter_params,
weight_names,
trust_remote_code,
)
@dataclass
@ -95,6 +109,7 @@ class AdapterParametersContainer:
@lru_cache(maxsize=32)
def _load_and_merge(
model_id: str,
revision: str,
adapter_params: AdapterParametersContainer,
weight_names: Tuple[str],
trust_remote_code: bool = False,
@ -171,12 +186,12 @@ def check_architectures(
@lru_cache(maxsize=128)
def load_module_map(
model_id: str,
revision: str,
adapter_id: str,
adapter_path: Optional[str],
weight_names: Tuple[str],
trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
revision = "main"
adapter_config = LoraConfig.load(adapter_path or adapter_id, None)
@ -191,6 +206,12 @@ def load_module_map(
)
)
# throw an error if no adapter weights are found
if not adapter_filenames:
raise FileNotFoundError(
f"No adapter weights found for adapter '{adapter_id}' and revision '{revision}'."
)
try:
adapter_tokenizer = AutoTokenizer.from_pretrained(
adapter_config.config_path,
@ -221,6 +242,12 @@ def get_attn_weights(i, layer):
value = (f"model.layers.{i}.self_attn.{k}_proj", qkv)
weights[key] = value
# also add the qkv_proj weight for the adapter
weights[(i, "qkv_proj")] = (
f"model.layers.{i}.self_attn.qkv_proj",
qkv,
)
weights[(i, "o_proj")] = (
f"model.layers.{i}.self_attn.o_proj",
layer.self_attn.o_proj,