feat: support lora revisions and qkv_proj weights (#2482)

* feat: support lora revisions and qkv_proj weights

* fix: add qkv_proj weights to weight test
Author: drbh, 2024-09-02 13:09:06 -04:00, committed by GitHub
parent 47d7e34458
commit 6cb42f49ae
4 changed files with 100 additions and 11 deletions
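
The adapter specification format now accepts an optional revision: each comma-separated entry is `id`, `id=path`, or `id=path@revision`. A quick sketch using the parser added in this commit (output mirrors the new tests below):

from text_generation_server.utils.adapter import parse_lora_adapters

adapters = parse_lora_adapters("adapter1,adapter2=path/to/adapter2@dev")
# -> [AdapterInfo(id="adapter1", path=None, revision=None),
#     AdapterInfo(id="adapter2", path="path/to/adapter2", revision="dev")]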

File 1 of 4: tests for the adapter utilities

@@ -1,6 +1,54 @@
 import pytest
 from unittest.mock import Mock
-from text_generation_server.utils.adapter import get_attn_weights, get_mlp_weights
+from text_generation_server.utils.adapter import (
+    get_attn_weights,
+    get_mlp_weights,
+    parse_lora_adapters,
+    AdapterInfo,
+)
+
+
+def test_parse_lora_adapters_empty():
+    assert parse_lora_adapters(None) == []
+    assert parse_lora_adapters("") == []
+
+
+def test_parse_lora_adapters_single():
+    result = parse_lora_adapters("adapter1")
+    assert result == [AdapterInfo(id="adapter1", path=None, revision=None)]
+
+
+def test_parse_lora_adapters_with_path():
+    result = parse_lora_adapters("adapter1=path/to/adapter1")
+    assert result == [
+        AdapterInfo(id="adapter1", path="path/to/adapter1", revision=None)
+    ]
+
+
+def test_parse_lora_adapters_with_path_and_revision():
+    result = parse_lora_adapters("adapter1=path/to/adapter1@main")
+    assert result == [
+        AdapterInfo(id="adapter1", path="path/to/adapter1", revision="main")
+    ]
+
+
+def test_parse_lora_adapters_multiple():
+    result = parse_lora_adapters(
+        "adapter1,adapter2=path/to/adapter2,adapter3=path/to/adapter3@dev"
+    )
+    assert result == [
+        AdapterInfo(id="adapter1", path=None, revision=None),
+        AdapterInfo(id="adapter2", path="path/to/adapter2", revision=None),
+        AdapterInfo(id="adapter3", path="path/to/adapter3", revision="dev"),
+    ]
+
+
+def test_parse_lora_adapters_invalid_format():
+    try:
+        parse_lora_adapters("adapter1,invalid=format=test,adapter3")
+        assert False, "Should have raised ValueError"
+    except ValueError as e:
+        assert str(e) == "Invalid LoRA adapter format: invalid=format=test"


 def test_get_attn_weights():
@@ -22,6 +70,10 @@ def test_get_attn_weights():
             "model.layers.2.self_attn.k_proj",
             mock_layer.self_attn.query_key_value,
         ),
+        (2, "qkv_proj"): (
+            "model.layers.2.self_attn.qkv_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
         (2, "v_proj"): (
             "model.layers.2.self_attn.v_proj",
             mock_layer.self_attn.query_key_value,
@@ -115,6 +167,10 @@ def test_get_attn_weights_llama_compatibility():
             "model.layers.2.self_attn.k_proj",
             mock_layer.self_attn.query_key_value,
         ),
+        (2, "qkv_proj"): (
+            "model.layers.2.self_attn.qkv_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
         (2, "v_proj"): (
             "model.layers.2.self_attn.v_proj",
             mock_layer.self_attn.query_key_value,
@@ -155,6 +211,10 @@ def test_get_attn_weights_gemma_compatibility():
             "model.layers.2.self_attn.k_proj",
             mock_layer.self_attn.query_key_value,
         ),
+        (2, "qkv_proj"): (
+            "model.layers.2.self_attn.qkv_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
         (2, "v_proj"): (
             "model.layers.2.self_attn.v_proj",
             mock_layer.self_attn.query_key_value,

File 2 of 4: model loading (get_model_with_lora_adapters)

@@ -1259,6 +1259,7 @@ def get_model_with_lora_adapters(
         "gate_proj",
         "up_proj",
         "down_proj",
+        "qkv_proj",
     ]

     for layer_name in adapter_layers:
@@ -1286,7 +1287,7 @@ def get_model_with_lora_adapters(
             if len(unused_weight_names) > 0:
                 logger.warning(
-                    f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
+                    f"{','.join([a.id for a in lora_adapters])} unused adapter weights: {unused_weight_names}"
                 )

             if adapter_tokenizer is not None:

File 3 of 4: attention layer loading (load_attention)

@@ -66,15 +66,15 @@ def load_attention(config, prefix: str, weights, layer_id):
     prefixes = None

     if config.model_type == "phi3":
-        prefix = f"{prefix}.qkv_proj"
         base_layer = TensorParallelColumnLinear.load_qkv(
             config,
-            prefix=prefix,
+            prefix=f"{prefix}.qkv_proj",
             weights=weights,
             bias=bias,
             num_heads=config.num_attention_heads,
             num_key_value_heads=config.num_key_value_heads,
         )
+        prefixes = ["qkv_proj"]
     elif config.model_type == "baichuan":
         prefix = f"{prefix}.W_pack"
         base_layer = TensorParallelColumnLinear.load_qkv(
@@ -85,6 +85,7 @@ def load_attention(config, prefix: str, weights, layer_id):
             num_heads=config.num_attention_heads,
             num_key_value_heads=config.num_key_value_heads,
         )
+        prefixes = [prefix]
     else:
         prefixes = ["q_proj", "k_proj", "v_proj"]
         sizes = [
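
For phi3 and baichuan, the fused projection now reports a LoRA prefix, so adapter weights saved against it can be routed to the single fused base layer. A minimal sketch of the naming this enables (hypothetical helper, not the actual TGI routing code):

def adapter_targets(prefix: str, prefixes: list) -> list:
    # e.g. the phi3 branch above sets prefixes = ["qkv_proj"]
    return [f"{prefix}.{p}" for p in prefixes]


assert adapter_targets("model.layers.2.self_attn", ["qkv_proj"]) == [
    "model.layers.2.self_attn.qkv_proj"
]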

File 4 of 4: adapter utilities (text_generation_server/utils/adapter.py)

@@ -3,6 +3,7 @@
 # License: Apache License Version 2.0, January 2004

 import warnings
+import re
 from dataclasses import dataclass
 from functools import lru_cache
 from typing import TYPE_CHECKING, Set, Tuple, Optional, List
@@ -27,6 +28,7 @@ BASE_MODEL_ADAPTER_ID = "__base_model__"
 class AdapterInfo:
     id: str
     path: Optional[str]
+    revision: Optional[str] = None


 @dataclass
@@ -51,11 +53,16 @@ def parse_lora_adapters(lora_adapters: Optional[str]) -> List[AdapterInfo]:

     adapter_list = []
     for adapter in lora_adapters.split(","):
-        parts = adapter.strip().split("=")
-        if len(parts) == 1:
-            adapter_list.append(AdapterInfo(id=parts[0], path=None))
-        elif len(parts) == 2:
-            adapter_list.append(AdapterInfo(id=parts[0], path=parts[1]))
+        adapter = adapter.strip()
+        if adapter.count("=") > 1 or adapter.count("@") > 1:
+            raise ValueError(f"Invalid LoRA adapter format: {adapter}")
+        match = re.match(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$", adapter)
+
+        if match:
+            adapter_id, path, revision = match.groups()
+            adapter_list.append(
+                AdapterInfo(id=adapter_id, path=path, revision=revision)
+            )
         else:
             raise ValueError(f"Invalid LoRA adapter format: {adapter}")
     return adapter_list
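
For reference, the new regex captures the adapter id, an optional =path, and an optional @revision; a small illustration of its groups (a sketch using the standard re module, matching the tests in file 1):

import re

PATTERN = r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$"

assert re.match(PATTERN, "adapter1").groups() == ("adapter1", None, None)
assert re.match(PATTERN, "adapter1=path/to/adapter1@main").groups() == (
    "adapter1",
    "path/to/adapter1",
    "main",
)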
@@ -73,6 +80,7 @@ def load_and_merge_adapters(
         adapter_info = next(iter(adapter_parameters.adapter_info))
         return load_module_map(
             model_id,
+            adapter_info.revision,
             adapter_info.id,
             adapter_info.path,
             weight_names,
@@ -80,7 +88,13 @@ def load_and_merge_adapters(
         )

     adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index)
-    return _load_and_merge(model_id, adapter_params, weight_names, trust_remote_code)
+    return _load_and_merge(
+        model_id,
+        adapter_params.revision,
+        adapter_params,
+        weight_names,
+        trust_remote_code,
+    )


 @dataclass
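
One practical consequence of passing the revision explicitly: _load_and_merge and load_module_map are wrapped in functools.lru_cache, so the revision becomes part of the cache key, and two revisions of the same adapter no longer share a cached module map. A generic sketch of that behaviour (not TGI code):

from functools import lru_cache

@lru_cache(maxsize=128)
def load(adapter_id: str, revision: str) -> str:
    # stand-in for the real loader; each distinct (adapter_id, revision)
    # pair gets its own cache entry
    return f"{adapter_id}@{revision}"

load("adapter1", "main")
load("adapter1", "dev")  # cache miss: different revision, separate entry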
@@ -95,6 +109,7 @@ class AdapterParametersContainer:
 @lru_cache(maxsize=32)
 def _load_and_merge(
     model_id: str,
+    revision: str,
     adapter_params: AdapterParametersContainer,
     weight_names: Tuple[str],
     trust_remote_code: bool = False,
@@ -171,12 +186,12 @@ def check_architectures(
 @lru_cache(maxsize=128)
 def load_module_map(
     model_id: str,
+    revision: str,
     adapter_id: str,
     adapter_path: Optional[str],
     weight_names: Tuple[str],
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
-    revision = "main"
     adapter_config = LoraConfig.load(adapter_path or adapter_id, None)
@@ -191,6 +206,12 @@ def load_module_map(
         )
     )

+    # throw an error if no adapter weights are found
+    if not adapter_filenames:
+        raise FileNotFoundError(
+            f"No adapter weights found for adapter '{adapter_id}' and revision '{revision}'."
+        )
+
     try:
         adapter_tokenizer = AutoTokenizer.from_pretrained(
             adapter_config.config_path,
@@ -221,6 +242,12 @@ def get_attn_weights(i, layer):
         value = (f"model.layers.{i}.self_attn.{k}_proj", qkv)
         weights[key] = value

+    # also add the qkv_proj weight for the adapter
+    weights[(i, "qkv_proj")] = (
+        f"model.layers.{i}.self_attn.qkv_proj",
+        qkv,
+    )
+
     weights[(i, "o_proj")] = (
         f"model.layers.{i}.self_attn.o_proj",
         layer.self_attn.o_proj,
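
A quick usage sketch mirroring the updated tests: an adapter module named ...self_attn.qkv_proj (as phi3-style LoRAs produce) now resolves to the same fused query_key_value layer as q_proj/k_proj/v_proj, instead of ending up in the unused-adapter-weights warning.

from unittest.mock import Mock

from text_generation_server.utils.adapter import get_attn_weights

layer = Mock()
weights = get_attn_weights(2, layer)

# the fused projection is addressable alongside the split ones
qkv = layer.self_attn.query_key_value
assert weights[(2, "qkv_proj")] == ("model.layers.2.self_attn.qkv_proj", qkv)
assert weights[(2, "q_proj")] == ("model.layers.2.self_attn.q_proj", qkv)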