feat: support lora revisions and qkv_proj weights (#2482)
* feat: support lora revisions and qkv_proj weights
* fix: add qkv_proj weights to weight test
parent 47d7e34458
commit 6cb42f49ae
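For context, the adapter specification accepted by parse_lora_adapters is a comma-separated list of entries of the form id, id=path, or id=path@revision (the launcher/env wiring is not part of this diff). A minimal sketch using the same strings exercised by the tests below:

from text_generation_server.utils.adapter import parse_lora_adapters

# Each entry may carry an optional path and an optional revision.
adapters = parse_lora_adapters("adapter1,adapter2=path/to/adapter2@main")
for adapter in adapters:
    print(adapter.id, adapter.path, adapter.revision)
# adapter1 None None
# adapter2 path/to/adapter2 main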
@@ -1,6 +1,54 @@
 import pytest
 from unittest.mock import Mock
-from text_generation_server.utils.adapter import get_attn_weights, get_mlp_weights
+from text_generation_server.utils.adapter import (
+    get_attn_weights,
+    get_mlp_weights,
+    parse_lora_adapters,
+    AdapterInfo,
+)
+
+
+def test_parse_lora_adapters_empty():
+    assert parse_lora_adapters(None) == []
+    assert parse_lora_adapters("") == []
+
+
+def test_parse_lora_adapters_single():
+    result = parse_lora_adapters("adapter1")
+    assert result == [AdapterInfo(id="adapter1", path=None, revision=None)]
+
+
+def test_parse_lora_adapters_with_path():
+    result = parse_lora_adapters("adapter1=path/to/adapter1")
+    assert result == [
+        AdapterInfo(id="adapter1", path="path/to/adapter1", revision=None)
+    ]
+
+
+def test_parse_lora_adapters_with_path_and_revision():
+    result = parse_lora_adapters("adapter1=path/to/adapter1@main")
+    assert result == [
+        AdapterInfo(id="adapter1", path="path/to/adapter1", revision="main")
+    ]
+
+
+def test_parse_lora_adapters_multiple():
+    result = parse_lora_adapters(
+        "adapter1,adapter2=path/to/adapter2,adapter3=path/to/adapter3@dev"
+    )
+    assert result == [
+        AdapterInfo(id="adapter1", path=None, revision=None),
+        AdapterInfo(id="adapter2", path="path/to/adapter2", revision=None),
+        AdapterInfo(id="adapter3", path="path/to/adapter3", revision="dev"),
+    ]
+
+
+def test_parse_lora_adapters_invalid_format():
+    try:
+        parse_lora_adapters("adapter1,invalid=format=test,adapter3")
+        assert False, "Should have raised ValueError"
+    except ValueError as e:
+        assert str(e) == "Invalid LoRA adapter format: invalid=format=test"
+
+
 def test_get_attn_weights():
@@ -22,6 +70,10 @@ def test_get_attn_weights():
             "model.layers.2.self_attn.k_proj",
             mock_layer.self_attn.query_key_value,
         ),
+        (2, "qkv_proj"): (
+            "model.layers.2.self_attn.qkv_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
         (2, "v_proj"): (
             "model.layers.2.self_attn.v_proj",
             mock_layer.self_attn.query_key_value,
@@ -115,6 +167,10 @@ def test_get_attn_weights_llama_compatibility():
             "model.layers.2.self_attn.k_proj",
             mock_layer.self_attn.query_key_value,
         ),
+        (2, "qkv_proj"): (
+            "model.layers.2.self_attn.qkv_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
         (2, "v_proj"): (
             "model.layers.2.self_attn.v_proj",
             mock_layer.self_attn.query_key_value,
@@ -155,6 +211,10 @@ def test_get_attn_weights_gemma_compatibility():
             "model.layers.2.self_attn.k_proj",
             mock_layer.self_attn.query_key_value,
         ),
+        (2, "qkv_proj"): (
+            "model.layers.2.self_attn.qkv_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
         (2, "v_proj"): (
             "model.layers.2.self_attn.v_proj",
             mock_layer.self_attn.query_key_value,
@@ -1259,6 +1259,7 @@ def get_model_with_lora_adapters(
                 "gate_proj",
                 "up_proj",
                 "down_proj",
+                "qkv_proj",
             ]

             for layer_name in adapter_layers:
@@ -1286,7 +1287,7 @@ def get_model_with_lora_adapters(

             if len(unused_weight_names) > 0:
                 logger.warning(
-                    f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
+                    f"{','.join([a.id for a in lora_adapters])} unused adapter weights: {unused_weight_names}"
                 )

             if adapter_tokenizer is not None:
@@ -66,15 +66,15 @@ def load_attention(config, prefix: str, weights, layer_id):
     prefixes = None

     if config.model_type == "phi3":
-        prefix = f"{prefix}.qkv_proj"
         base_layer = TensorParallelColumnLinear.load_qkv(
             config,
-            prefix=prefix,
+            prefix=f"{prefix}.qkv_proj",
             weights=weights,
             bias=bias,
             num_heads=config.num_attention_heads,
             num_key_value_heads=config.num_key_value_heads,
         )
+        prefixes = ["qkv_proj"]
     elif config.model_type == "baichuan":
         prefix = f"{prefix}.W_pack"
         base_layer = TensorParallelColumnLinear.load_qkv(
@@ -85,6 +85,7 @@ def load_attention(config, prefix: str, weights, layer_id):
             num_heads=config.num_attention_heads,
             num_key_value_heads=config.num_key_value_heads,
         )
+        prefixes = [prefix]
     else:
         prefixes = ["q_proj", "k_proj", "v_proj"]
         sizes = [
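In short, load_attention now records which prefixes carry the attention weights for each model family: phi3 reports its fused projection as qkv_proj, baichuan keeps its fully qualified W_pack prefix, and everything else keeps separate q/k/v projections. A condensed, standalone sketch of that branch logic (the helper name below is made up for illustration; it is not the actual load_attention function):

from typing import List


def attention_lora_prefixes(model_type: str, prefix: str) -> List[str]:
    # Mirrors the branches in load_attention above: fused-QKV models expose a
    # single prefix, all other model types keep separate q/k/v projections.
    if model_type == "phi3":
        return ["qkv_proj"]
    if model_type == "baichuan":
        return [f"{prefix}.W_pack"]  # baichuan keeps the fully qualified prefix
    return ["q_proj", "k_proj", "v_proj"]


print(attention_lora_prefixes("phi3", "model.layers.0.self_attn"))   # ['qkv_proj']
print(attention_lora_prefixes("llama", "model.layers.0.self_attn"))  # ['q_proj', 'k_proj', 'v_proj']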
@@ -3,6 +3,7 @@
 # License: Apache License Version 2.0, January 2004

 import warnings
+import re
 from dataclasses import dataclass
 from functools import lru_cache
 from typing import TYPE_CHECKING, Set, Tuple, Optional, List
@@ -27,6 +28,7 @@ BASE_MODEL_ADAPTER_ID = "__base_model__"
 class AdapterInfo:
     id: str
     path: Optional[str]
+    revision: Optional[str] = None


 @dataclass
@@ -51,11 +53,16 @@ def parse_lora_adapters(lora_adapters: Optional[str]) -> List[AdapterInfo]:

     adapter_list = []
     for adapter in lora_adapters.split(","):
-        parts = adapter.strip().split("=")
-        if len(parts) == 1:
-            adapter_list.append(AdapterInfo(id=parts[0], path=None))
-        elif len(parts) == 2:
-            adapter_list.append(AdapterInfo(id=parts[0], path=parts[1]))
+        adapter = adapter.strip()
+        if adapter.count("=") > 1 or adapter.count("@") > 1:
+            raise ValueError(f"Invalid LoRA adapter format: {adapter}")
+        match = re.match(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$", adapter)
+
+        if match:
+            adapter_id, path, revision = match.groups()
+            adapter_list.append(
+                AdapterInfo(id=adapter_id, path=path, revision=revision)
+            )
         else:
             raise ValueError(f"Invalid LoRA adapter format: {adapter}")
     return adapter_list
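The regex above does the heavy lifting: group 1 is the adapter id, group 2 the optional path, group 3 the optional revision, while the count() guard rejects entries with more than one "=" or "@". A small sketch of the groups it produces (ADAPTER_RE is just a local name for illustration):

import re

ADAPTER_RE = re.compile(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$")

# Unmatched optional groups come back as None.
for spec in ("adapter1", "adapter1=path/to/adapter1", "adapter1=path/to/adapter1@main"):
    adapter_id, path, revision = ADAPTER_RE.match(spec).groups()
    print(adapter_id, path, revision)
# adapter1 None None
# adapter1 path/to/adapter1 None
# adapter1 path/to/adapter1 main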
@@ -73,6 +80,7 @@ def load_and_merge_adapters(
         adapter_info = next(iter(adapter_parameters.adapter_info))
         return load_module_map(
             model_id,
+            adapter_info.revision,
             adapter_info.id,
             adapter_info.path,
             weight_names,
@@ -80,7 +88,13 @@ def load_and_merge_adapters(
         )

     adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index)
-    return _load_and_merge(model_id, adapter_params, weight_names, trust_remote_code)
+    return _load_and_merge(
+        model_id,
+        adapter_params.revision,
+        adapter_params,
+        weight_names,
+        trust_remote_code,
+    )


 @dataclass
@@ -95,6 +109,7 @@ class AdapterParametersContainer:
 @lru_cache(maxsize=32)
 def _load_and_merge(
     model_id: str,
+    revision: str,
     adapter_params: AdapterParametersContainer,
     weight_names: Tuple[str],
     trust_remote_code: bool = False,
@@ -171,12 +186,12 @@ def check_architectures(
 @lru_cache(maxsize=128)
 def load_module_map(
     model_id: str,
+    revision: str,
     adapter_id: str,
     adapter_path: Optional[str],
     weight_names: Tuple[str],
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
-    revision = "main"

     adapter_config = LoraConfig.load(adapter_path or adapter_id, None)
@@ -191,6 +206,12 @@ def load_module_map(
         )
     )

+    # throw an error if no adapter weights are found
+    if not adapter_filenames:
+        raise FileNotFoundError(
+            f"No adapter weights found for adapter '{adapter_id}' and revision '{revision}'."
+        )
+
     try:
         adapter_tokenizer = AutoTokenizer.from_pretrained(
             adapter_config.config_path,
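With the revision now threaded through instead of being hardcoded to "main", and with the explicit FileNotFoundError guard, callers can fail fast on a misconfigured adapter. A hedged sketch of the calling side (the model id, adapter id, and weight names are placeholders, and running this for real requires the server package plus the adapter files):

from text_generation_server.utils.adapter import load_module_map

try:
    module_map, adapter_config, unused_weight_names, adapter_tokenizer = load_module_map(
        "base-model-id",                           # placeholder base model id
        "main",                                    # revision is now an explicit argument
        "my-lora-adapter",                         # placeholder adapter id
        None,                                      # no local path: resolve from the hub
        ("q_proj", "k_proj", "v_proj", "o_proj"),  # placeholder weight names (tuple, lru_cache-friendly)
    )
except FileNotFoundError as err:
    # Raised when no adapter weight files exist for the given adapter id / revision.
    print(f"Adapter not loaded: {err}")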
@@ -221,6 +242,12 @@ def get_attn_weights(i, layer):
         value = (f"model.layers.{i}.self_attn.{k}_proj", qkv)
         weights[key] = value

+    # also add the qkv_proj weight for the adapter
+    weights[(i, "qkv_proj")] = (
+        f"model.layers.{i}.self_attn.qkv_proj",
+        qkv,
+    )
+
     weights[(i, "o_proj")] = (
         f"model.layers.{i}.self_attn.o_proj",
         layer.self_attn.o_proj,
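Finally, a short sketch in the spirit of the updated tests shows what the extra mapping buys: get_attn_weights now exposes a (layer, "qkv_proj") key that points at the same fused query_key_value module as the q/k/v entries, so adapters whose weights are named qkv_proj (phi3-style checkpoints) resolve to a target:

from unittest.mock import Mock

from text_generation_server.utils.adapter import get_attn_weights

mock_layer = Mock()
weights = get_attn_weights(2, mock_layer)

# The new fused entry sits alongside q/k/v and reuses the same underlying module.
assert weights[(2, "qkv_proj")] == (
    "model.layers.2.self_attn.qkv_proj",
    mock_layer.self_attn.query_key_value,
)
assert weights[(2, "q_proj")][1] is weights[(2, "qkv_proj")][1]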