feat: support lora revisions and qkv_proj weights (#2482)

* feat: support lora revisions and qkv_proj weights
* fix: add qkv_proj weights to weight test
parent 47d7e34458
commit 6cb42f49ae
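The gist of the change: a LoRA adapter spec can now carry an optional revision after an "@" (in addition to the optional local path after "="), and the fused qkv_proj projection used by Phi-3-style models is exposed as a LoRA target. A quick sketch of the new parsing behaviour, mirroring the tests added below (import path taken from the test file; adapter names are placeholders):

    from text_generation_server.utils.adapter import parse_lora_adapters

    # accepted forms: "id", "id=path", "id=path@revision"
    adapters = parse_lora_adapters("adapter1,adapter2=path/to/adapter2@main")
    for a in adapters:
        print(a.id, a.path, a.revision)
    # adapter1 None None
    # adapter2 path/to/adapter2 main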
@@ -1,6 +1,54 @@
 import pytest
 from unittest.mock import Mock
-from text_generation_server.utils.adapter import get_attn_weights, get_mlp_weights
+from text_generation_server.utils.adapter import (
+    get_attn_weights,
+    get_mlp_weights,
+    parse_lora_adapters,
+    AdapterInfo,
+)
+
+
+def test_parse_lora_adapters_empty():
+    assert parse_lora_adapters(None) == []
+    assert parse_lora_adapters("") == []
+
+
+def test_parse_lora_adapters_single():
+    result = parse_lora_adapters("adapter1")
+    assert result == [AdapterInfo(id="adapter1", path=None, revision=None)]
+
+
+def test_parse_lora_adapters_with_path():
+    result = parse_lora_adapters("adapter1=path/to/adapter1")
+    assert result == [
+        AdapterInfo(id="adapter1", path="path/to/adapter1", revision=None)
+    ]
+
+
+def test_parse_lora_adapters_with_path_and_revision():
+    result = parse_lora_adapters("adapter1=path/to/adapter1@main")
+    assert result == [
+        AdapterInfo(id="adapter1", path="path/to/adapter1", revision="main")
+    ]
+
+
+def test_parse_lora_adapters_multiple():
+    result = parse_lora_adapters(
+        "adapter1,adapter2=path/to/adapter2,adapter3=path/to/adapter3@dev"
+    )
+    assert result == [
+        AdapterInfo(id="adapter1", path=None, revision=None),
+        AdapterInfo(id="adapter2", path="path/to/adapter2", revision=None),
+        AdapterInfo(id="adapter3", path="path/to/adapter3", revision="dev"),
+    ]
+
+
+def test_parse_lora_adapters_invalid_format():
+    try:
+        parse_lora_adapters("adapter1,invalid=format=test,adapter3")
+        assert False, "Should have raised ValueError"
+    except ValueError as e:
+        assert str(e) == "Invalid LoRA adapter format: invalid=format=test"
 
 
 def test_get_attn_weights():
@@ -22,6 +70,10 @@ def test_get_attn_weights():
             "model.layers.2.self_attn.k_proj",
             mock_layer.self_attn.query_key_value,
         ),
+        (2, "qkv_proj"): (
+            "model.layers.2.self_attn.qkv_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
         (2, "v_proj"): (
             "model.layers.2.self_attn.v_proj",
             mock_layer.self_attn.query_key_value,
@@ -115,6 +167,10 @@ def test_get_attn_weights_llama_compatibility():
             "model.layers.2.self_attn.k_proj",
             mock_layer.self_attn.query_key_value,
         ),
+        (2, "qkv_proj"): (
+            "model.layers.2.self_attn.qkv_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
         (2, "v_proj"): (
             "model.layers.2.self_attn.v_proj",
             mock_layer.self_attn.query_key_value,
@@ -155,6 +211,10 @@ def test_get_attn_weights_gemma_compatibility():
             "model.layers.2.self_attn.k_proj",
             mock_layer.self_attn.query_key_value,
         ),
+        (2, "qkv_proj"): (
+            "model.layers.2.self_attn.qkv_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
         (2, "v_proj"): (
             "model.layers.2.self_attn.v_proj",
             mock_layer.self_attn.query_key_value,

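The additions above are plain pytest functions, so they can be run on their own; a minimal invocation sketch (the test file path is an assumption based on the repository layout):

    import pytest

    # path is assumed; point this at wherever the adapter tests live
    raise SystemExit(pytest.main(["-q", "server/tests/utils/test_adapter.py"]))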
@@ -1259,6 +1259,7 @@ def get_model_with_lora_adapters(
         "gate_proj",
         "up_proj",
         "down_proj",
+        "qkv_proj",
     ]
 
     for layer_name in adapter_layers:
@@ -1286,7 +1287,7 @@ def get_model_with_lora_adapters(
 
         if len(unused_weight_names) > 0:
             logger.warning(
-                f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
+                f"{','.join([a.id for a in lora_adapters])} unused adapter weights: {unused_weight_names}"
             )
 
         if adapter_tokenizer is not None:

@@ -66,15 +66,15 @@ def load_attention(config, prefix: str, weights, layer_id):
     prefixes = None
 
     if config.model_type == "phi3":
-        prefix = f"{prefix}.qkv_proj"
         base_layer = TensorParallelColumnLinear.load_qkv(
             config,
-            prefix=prefix,
+            prefix=f"{prefix}.qkv_proj",
             weights=weights,
             bias=bias,
             num_heads=config.num_attention_heads,
             num_key_value_heads=config.num_key_value_heads,
         )
+        prefixes = ["qkv_proj"]
     elif config.model_type == "baichuan":
         prefix = f"{prefix}.W_pack"
         base_layer = TensorParallelColumnLinear.load_qkv(
@@ -85,6 +85,7 @@ def load_attention(config, prefix: str, weights, layer_id):
             num_heads=config.num_attention_heads,
             num_key_value_heads=config.num_key_value_heads,
         )
+        prefixes = [prefix]
     else:
         prefixes = ["q_proj", "k_proj", "v_proj"]
         sizes = [

@@ -3,6 +3,7 @@
 # License: Apache License Version 2.0, January 2004
 
 import warnings
+import re
 from dataclasses import dataclass
 from functools import lru_cache
 from typing import TYPE_CHECKING, Set, Tuple, Optional, List
@@ -27,6 +28,7 @@ BASE_MODEL_ADAPTER_ID = "__base_model__"
 class AdapterInfo:
     id: str
     path: Optional[str]
+    revision: Optional[str] = None
 
 
 @dataclass
@@ -51,11 +53,16 @@ def parse_lora_adapters(lora_adapters: Optional[str]) -> List[AdapterInfo]:
 
     adapter_list = []
     for adapter in lora_adapters.split(","):
-        parts = adapter.strip().split("=")
-        if len(parts) == 1:
-            adapter_list.append(AdapterInfo(id=parts[0], path=None))
-        elif len(parts) == 2:
-            adapter_list.append(AdapterInfo(id=parts[0], path=parts[1]))
+        adapter = adapter.strip()
+        if adapter.count("=") > 1 or adapter.count("@") > 1:
+            raise ValueError(f"Invalid LoRA adapter format: {adapter}")
+        match = re.match(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$", adapter)
+
+        if match:
+            adapter_id, path, revision = match.groups()
+            adapter_list.append(
+                AdapterInfo(id=adapter_id, path=path, revision=revision)
+            )
         else:
             raise ValueError(f"Invalid LoRA adapter format: {adapter}")
     return adapter_list
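A single regex now does the splitting: an id made of anything except "=" or "@", an optional "=path", and an optional "@revision", with malformed specs (more than one "=" or "@") rejected before matching. A standalone check of the groups the pattern yields (pattern copied from the diff; inputs taken from the tests):

    import re

    PATTERN = r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$"

    for spec in ["adapter1", "adapter1=path/to/adapter1", "adapter1=path/to/adapter1@main"]:
        # groups() -> (id, path or None, revision or None)
        print(re.match(PATTERN, spec).groups())
    # ('adapter1', None, None)
    # ('adapter1', 'path/to/adapter1', None)
    # ('adapter1', 'path/to/adapter1', 'main')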
@@ -73,6 +80,7 @@ def load_and_merge_adapters(
         adapter_info = next(iter(adapter_parameters.adapter_info))
         return load_module_map(
             model_id,
+            adapter_info.revision,
             adapter_info.id,
             adapter_info.path,
             weight_names,
@@ -80,7 +88,13 @@ def load_and_merge_adapters(
         )
 
     adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index)
-    return _load_and_merge(model_id, adapter_params, weight_names, trust_remote_code)
+    return _load_and_merge(
+        model_id,
+        adapter_params.revision,
+        adapter_params,
+        weight_names,
+        trust_remote_code,
+    )
 
 
 @dataclass
@@ -95,6 +109,7 @@ class AdapterParametersContainer:
 @lru_cache(maxsize=32)
 def _load_and_merge(
     model_id: str,
+    revision: str,
     adapter_params: AdapterParametersContainer,
     weight_names: Tuple[str],
     trust_remote_code: bool = False,
@@ -171,12 +186,12 @@ def check_architectures(
 @lru_cache(maxsize=128)
 def load_module_map(
     model_id: str,
+    revision: str,
     adapter_id: str,
     adapter_path: Optional[str],
     weight_names: Tuple[str],
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
-    revision = "main"
 
     adapter_config = LoraConfig.load(adapter_path or adapter_id, None)
 
@@ -191,6 +206,12 @@ def load_module_map(
         )
     )
 
+    # throw an error if no adapter weights are found
+    if not adapter_filenames:
+        raise FileNotFoundError(
+            f"No adapter weights found for adapter '{adapter_id}' and revision '{revision}'."
+        )
+
     try:
         adapter_tokenizer = AutoTokenizer.from_pretrained(
             adapter_config.config_path,
@@ -221,6 +242,12 @@ def get_attn_weights(i, layer):
         value = (f"model.layers.{i}.self_attn.{k}_proj", qkv)
         weights[key] = value
 
+    # also add the qkv_proj weight for the adapter
+    weights[(i, "qkv_proj")] = (
+        f"model.layers.{i}.self_attn.qkv_proj",
+        qkv,
+    )
+
     weights[(i, "o_proj")] = (
         f"model.layers.{i}.self_attn.o_proj",
         layer.self_attn.o_proj,
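With the block above, get_attn_weights registers the fused attention layer under the "qkv_proj" name as well as the split projection names, all pointing at the same query_key_value module, so adapters targeting either naming resolve to the same base weight. A minimal check in the style of the updated tests (Mock stands in for a real decoder layer):

    from unittest.mock import Mock

    from text_generation_server.utils.adapter import get_attn_weights

    layer = Mock()
    weights = get_attn_weights(2, layer)
    # the split and fused keys map to the same underlying module
    assert weights[(2, "k_proj")][1] is layer.self_attn.query_key_value
    assert weights[(2, "qkv_proj")][1] is layer.self_attn.query_key_value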