From 6cb42f49ae47a117e8f1bdfcdb5cbe42332dc360 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 2 Sep 2024 13:09:06 -0400 Subject: [PATCH] feat: support lora revisions and qkv_proj weights (#2482) * feat: support lora revisions and qkv_proj weights * fix: add qkv_proj weights to weight test --- server/tests/utils/test_adapter.py | 62 ++++++++++++++++++- .../text_generation_server/models/__init__.py | 3 +- .../custom_modeling/flash_llama_modeling.py | 5 +- .../text_generation_server/utils/adapter.py | 41 +++++++++--- 4 files changed, 100 insertions(+), 11 deletions(-) diff --git a/server/tests/utils/test_adapter.py b/server/tests/utils/test_adapter.py index cc1b076d..a27c1055 100644 --- a/server/tests/utils/test_adapter.py +++ b/server/tests/utils/test_adapter.py @@ -1,6 +1,54 @@ import pytest from unittest.mock import Mock -from text_generation_server.utils.adapter import get_attn_weights, get_mlp_weights +from text_generation_server.utils.adapter import ( + get_attn_weights, + get_mlp_weights, + parse_lora_adapters, + AdapterInfo, +) + + +def test_parse_lora_adapters_empty(): + assert parse_lora_adapters(None) == [] + assert parse_lora_adapters("") == [] + + +def test_parse_lora_adapters_single(): + result = parse_lora_adapters("adapter1") + assert result == [AdapterInfo(id="adapter1", path=None, revision=None)] + + +def test_parse_lora_adapters_with_path(): + result = parse_lora_adapters("adapter1=path/to/adapter1") + assert result == [ + AdapterInfo(id="adapter1", path="path/to/adapter1", revision=None) + ] + + +def test_parse_lora_adapters_with_path_and_revision(): + result = parse_lora_adapters("adapter1=path/to/adapter1@main") + assert result == [ + AdapterInfo(id="adapter1", path="path/to/adapter1", revision="main") + ] + + +def test_parse_lora_adapters_multiple(): + result = parse_lora_adapters( + "adapter1,adapter2=path/to/adapter2,adapter3=path/to/adapter3@dev" + ) + assert result == [ + AdapterInfo(id="adapter1", path=None, revision=None), + AdapterInfo(id="adapter2", path="path/to/adapter2", revision=None), + AdapterInfo(id="adapter3", path="path/to/adapter3", revision="dev"), + ] + + +def test_parse_lora_adapters_invalid_format(): + try: + parse_lora_adapters("adapter1,invalid=format=test,adapter3") + assert False, "Should have raised ValueError" + except ValueError as e: + assert str(e) == "Invalid LoRA adapter format: invalid=format=test" def test_get_attn_weights(): @@ -22,6 +70,10 @@ def test_get_attn_weights(): "model.layers.2.self_attn.k_proj", mock_layer.self_attn.query_key_value, ), + (2, "qkv_proj"): ( + "model.layers.2.self_attn.qkv_proj", + mock_layer.self_attn.query_key_value, + ), (2, "v_proj"): ( "model.layers.2.self_attn.v_proj", mock_layer.self_attn.query_key_value, @@ -115,6 +167,10 @@ def test_get_attn_weights_llama_compatibility(): "model.layers.2.self_attn.k_proj", mock_layer.self_attn.query_key_value, ), + (2, "qkv_proj"): ( + "model.layers.2.self_attn.qkv_proj", + mock_layer.self_attn.query_key_value, + ), (2, "v_proj"): ( "model.layers.2.self_attn.v_proj", mock_layer.self_attn.query_key_value, @@ -155,6 +211,10 @@ def test_get_attn_weights_gemma_compatibility(): "model.layers.2.self_attn.k_proj", mock_layer.self_attn.query_key_value, ), + (2, "qkv_proj"): ( + "model.layers.2.self_attn.qkv_proj", + mock_layer.self_attn.query_key_value, + ), (2, "v_proj"): ( "model.layers.2.self_attn.v_proj", mock_layer.self_attn.query_key_value, diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 52f332c1..fc530b38 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -1259,6 +1259,7 @@ def get_model_with_lora_adapters( "gate_proj", "up_proj", "down_proj", + "qkv_proj", ] for layer_name in adapter_layers: @@ -1286,7 +1287,7 @@ def get_model_with_lora_adapters( if len(unused_weight_names) > 0: logger.warning( - f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}" + f"{','.join([a.id for a in lora_adapters])} unused adapter weights: {unused_weight_names}" ) if adapter_tokenizer is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 5b228f9f..ae981c9a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -66,15 +66,15 @@ def load_attention(config, prefix: str, weights, layer_id): prefixes = None if config.model_type == "phi3": - prefix = f"{prefix}.qkv_proj" base_layer = TensorParallelColumnLinear.load_qkv( config, - prefix=prefix, + prefix=f"{prefix}.qkv_proj", weights=weights, bias=bias, num_heads=config.num_attention_heads, num_key_value_heads=config.num_key_value_heads, ) + prefixes = ["qkv_proj"] elif config.model_type == "baichuan": prefix = f"{prefix}.W_pack" base_layer = TensorParallelColumnLinear.load_qkv( @@ -85,6 +85,7 @@ def load_attention(config, prefix: str, weights, layer_id): num_heads=config.num_attention_heads, num_key_value_heads=config.num_key_value_heads, ) + prefixes = [prefix] else: prefixes = ["q_proj", "k_proj", "v_proj"] sizes = [ diff --git a/server/text_generation_server/utils/adapter.py b/server/text_generation_server/utils/adapter.py index 1db5f77b..b7fc89df 100644 --- a/server/text_generation_server/utils/adapter.py +++ b/server/text_generation_server/utils/adapter.py @@ -3,6 +3,7 @@ # License: Apache License Version 2.0, January 2004 import warnings +import re from dataclasses import dataclass from functools import lru_cache from typing import TYPE_CHECKING, Set, Tuple, Optional, List @@ -27,6 +28,7 @@ BASE_MODEL_ADAPTER_ID = "__base_model__" class AdapterInfo: id: str path: Optional[str] + revision: Optional[str] = None @dataclass @@ -51,11 +53,16 @@ def parse_lora_adapters(lora_adapters: Optional[str]) -> List[AdapterInfo]: adapter_list = [] for adapter in lora_adapters.split(","): - parts = adapter.strip().split("=") - if len(parts) == 1: - adapter_list.append(AdapterInfo(id=parts[0], path=None)) - elif len(parts) == 2: - adapter_list.append(AdapterInfo(id=parts[0], path=parts[1])) + adapter = adapter.strip() + if adapter.count("=") > 1 or adapter.count("@") > 1: + raise ValueError(f"Invalid LoRA adapter format: {adapter}") + match = re.match(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$", adapter) + + if match: + adapter_id, path, revision = match.groups() + adapter_list.append( + AdapterInfo(id=adapter_id, path=path, revision=revision) + ) else: raise ValueError(f"Invalid LoRA adapter format: {adapter}") return adapter_list @@ -73,6 +80,7 @@ def load_and_merge_adapters( adapter_info = next(iter(adapter_parameters.adapter_info)) return load_module_map( model_id, + adapter_info.revision, adapter_info.id, adapter_info.path, weight_names, @@ -80,7 +88,13 @@ def load_and_merge_adapters( ) adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index) - return _load_and_merge(model_id, adapter_params, weight_names, trust_remote_code) + return _load_and_merge( + model_id, + adapter_params.revision, + adapter_params, + weight_names, + trust_remote_code, + ) @dataclass @@ -95,6 +109,7 @@ class AdapterParametersContainer: @lru_cache(maxsize=32) def _load_and_merge( model_id: str, + revision: str, adapter_params: AdapterParametersContainer, weight_names: Tuple[str], trust_remote_code: bool = False, @@ -171,12 +186,12 @@ def check_architectures( @lru_cache(maxsize=128) def load_module_map( model_id: str, + revision: str, adapter_id: str, adapter_path: Optional[str], weight_names: Tuple[str], trust_remote_code: bool = False, ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: - revision = "main" adapter_config = LoraConfig.load(adapter_path or adapter_id, None) @@ -191,6 +206,12 @@ def load_module_map( ) ) + # throw an error if no adapter weights are found + if not adapter_filenames: + raise FileNotFoundError( + f"No adapter weights found for adapter '{adapter_id}' and revision '{revision}'." + ) + try: adapter_tokenizer = AutoTokenizer.from_pretrained( adapter_config.config_path, @@ -221,6 +242,12 @@ def get_attn_weights(i, layer): value = (f"model.layers.{i}.self_attn.{k}_proj", qkv) weights[key] = value + # also add the qkv_proj weight for the adapter + weights[(i, "qkv_proj")] = ( + f"model.layers.{i}.self_attn.qkv_proj", + qkv, + ) + weights[(i, "o_proj")] = ( f"model.layers.{i}.self_attn.o_proj", layer.self_attn.o_proj,