2024-06-25 12:46:27 -06:00
|
|
|
# Origin: https://github.com/predibase/lorax
|
|
|
|
# Path: lorax/server/lorax_server/utils/adapter.py
|
|
|
|
# License: Apache License Version 2.0, January 2004
|
|
|
|
|
|
|
|
import warnings
|
2024-09-02 11:09:06 -06:00
|
|
|
import re
|
2024-06-25 12:46:27 -06:00
|
|
|
from dataclasses import dataclass
|
|
|
|
from functools import lru_cache
|
2024-07-24 13:32:14 -06:00
|
|
|
from typing import TYPE_CHECKING, Set, Tuple, Optional, List
|
2024-06-25 12:46:27 -06:00
|
|
|
|
|
|
|
from safetensors.torch import load_file
|
|
|
|
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
|
|
|
|
|
|
|
|
from text_generation_server.utils.merges.strategies import merge_adapters
|
|
|
|
|
|
|
|
from text_generation_server.utils import hub
|
|
|
|
from text_generation_server.adapters.lora import LoraConfig
|
|
|
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
|
|
|
|
|
|
|
|
|
|
|
|
BASE_MODEL_ADAPTER_ID = "__base_model__"
|
|
|
|
|
|
|
|
|
2024-07-24 13:32:14 -06:00
|
|
|
@dataclass
|
|
|
|
class AdapterInfo:
|
|
|
|
id: str
|
|
|
|
path: Optional[str]
|
2024-09-02 11:09:06 -06:00
|
|
|
revision: Optional[str] = None
|
2024-07-24 13:32:14 -06:00
|
|
|
|
|
|
|
|
2024-06-25 12:46:27 -06:00
|
|
|
@dataclass
|
|
|
|
class AdapterParameters:
|
2024-07-24 13:32:14 -06:00
|
|
|
adapter_info: Tuple[AdapterInfo]
|
2024-06-25 12:46:27 -06:00
|
|
|
weights: Tuple[float]
|
|
|
|
merge_strategy: NotImplemented
|
|
|
|
density: float
|
|
|
|
majority_sign_method: NotImplemented
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
class AdapterSource:
|
|
|
|
adapter_id: str
|
|
|
|
model_id: str
|
|
|
|
revision: str
|
|
|
|
|
|
|
|
|
2024-07-24 13:32:14 -06:00
|
|
|
def parse_lora_adapters(lora_adapters: Optional[str]) -> List[AdapterInfo]:
|
|
|
|
if not lora_adapters:
|
|
|
|
return []
|
|
|
|
|
|
|
|
adapter_list = []
|
|
|
|
for adapter in lora_adapters.split(","):
|
2024-09-02 11:09:06 -06:00
|
|
|
adapter = adapter.strip()
|
|
|
|
if adapter.count("=") > 1 or adapter.count("@") > 1:
|
|
|
|
raise ValueError(f"Invalid LoRA adapter format: {adapter}")
|
|
|
|
match = re.match(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$", adapter)
|
|
|
|
|
|
|
|
if match:
|
|
|
|
adapter_id, path, revision = match.groups()
|
|
|
|
adapter_list.append(
|
|
|
|
AdapterInfo(id=adapter_id, path=path, revision=revision)
|
|
|
|
)
|
2024-07-24 13:32:14 -06:00
|
|
|
else:
|
|
|
|
raise ValueError(f"Invalid LoRA adapter format: {adapter}")
|
|
|
|
return adapter_list
|
|
|
|
|
|
|
|
|
2024-06-25 12:46:27 -06:00
|
|
|
def load_and_merge_adapters(
|
|
|
|
model_id: str,
|
|
|
|
adapter_parameters: AdapterParameters,
|
|
|
|
adapter_index: int,
|
|
|
|
weight_names: Tuple[str],
|
|
|
|
trust_remote_code: bool = False,
|
|
|
|
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
|
2024-07-24 13:32:14 -06:00
|
|
|
if len(adapter_parameters.adapter_info) == 1:
|
2024-09-12 09:04:52 -06:00
|
|
|
adapter = next(iter(adapter_parameters.adapter_info))
|
2024-06-25 12:46:27 -06:00
|
|
|
return load_module_map(
|
|
|
|
model_id,
|
2024-09-12 09:04:52 -06:00
|
|
|
adapter.revision,
|
|
|
|
adapter.id,
|
|
|
|
adapter.path,
|
2024-06-25 12:46:27 -06:00
|
|
|
weight_names,
|
|
|
|
trust_remote_code,
|
|
|
|
)
|
|
|
|
|
2024-07-24 13:32:14 -06:00
|
|
|
adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index)
|
2024-09-02 11:09:06 -06:00
|
|
|
return _load_and_merge(
|
|
|
|
model_id,
|
|
|
|
adapter_params,
|
|
|
|
weight_names,
|
|
|
|
trust_remote_code,
|
|
|
|
)
|
2024-06-25 12:46:27 -06:00
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
class AdapterParametersContainer:
|
|
|
|
adapter_parameters: AdapterParameters
|
|
|
|
adapter_index: int
|
|
|
|
|
|
|
|
def __hash__(self) -> int:
|
|
|
|
return self.adapter_index
|
|
|
|
|
|
|
|
|
|
|
|
@lru_cache(maxsize=32)
|
|
|
|
def _load_and_merge(
|
|
|
|
model_id: str,
|
|
|
|
adapter_params: AdapterParametersContainer,
|
|
|
|
weight_names: Tuple[str],
|
|
|
|
trust_remote_code: bool = False,
|
|
|
|
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
|
|
|
|
params = adapter_params.adapter_parameters
|
|
|
|
|
|
|
|
adapters_to_merge = []
|
|
|
|
merged_weight_names = set()
|
|
|
|
tokenizer = None
|
2024-07-24 13:32:14 -06:00
|
|
|
for adapter in params.adapter_info:
|
|
|
|
if adapter.id == BASE_MODEL_ADAPTER_ID:
|
2024-06-25 12:46:27 -06:00
|
|
|
raise ValueError("Base model adapter cannot be merged.")
|
|
|
|
|
|
|
|
module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
|
|
|
|
load_module_map(
|
|
|
|
model_id,
|
2024-09-12 09:04:52 -06:00
|
|
|
adapter.revision,
|
2024-07-24 13:32:14 -06:00
|
|
|
adapter.id,
|
|
|
|
adapter.path,
|
2024-06-25 12:46:27 -06:00
|
|
|
weight_names,
|
|
|
|
trust_remote_code,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
adapters_to_merge.append((module_map, adapter_config))
|
|
|
|
merged_weight_names = merged_weight_names.union(adapter_weight_names)
|
|
|
|
if tokenizer is None:
|
|
|
|
tokenizer = adapter_tokenizer
|
|
|
|
|
|
|
|
if len(adapters_to_merge) == 0:
|
|
|
|
raise ValueError("No adapters to merge.")
|
|
|
|
|
|
|
|
module_map, adapter_config = merge_adapters(adapters_to_merge, params)
|
|
|
|
return module_map, adapter_config, merged_weight_names, tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
def check_architectures(
|
|
|
|
model_id: str,
|
|
|
|
adapter_id: str,
|
|
|
|
adapter_config: "AdapterConfig",
|
|
|
|
trust_remote_code: bool = False,
|
|
|
|
):
|
|
|
|
try:
|
|
|
|
if not adapter_config.base_model_name_or_path:
|
|
|
|
# Avoid execution latency caused by the network connection retrying for AutoConfig.from_pretrained(None)
|
|
|
|
return
|
|
|
|
|
|
|
|
expected_config = AutoConfig.from_pretrained(
|
|
|
|
model_id, trust_remote_code=trust_remote_code
|
|
|
|
)
|
|
|
|
model_config = AutoConfig.from_pretrained(
|
|
|
|
adapter_config.base_model_name_or_path, trust_remote_code=trust_remote_code
|
|
|
|
)
|
|
|
|
except Exception as e:
|
|
|
|
warnings.warn(
|
|
|
|
f"Unable to check architecture compatibility for adapter '{adapter_id}' "
|
|
|
|
f"against model '{model_id}'. Assuming they are compatible. Error: {e}"
|
|
|
|
)
|
|
|
|
return
|
|
|
|
|
|
|
|
if model_config.architectures == expected_config.architectures:
|
|
|
|
warnings.warn(
|
|
|
|
f"Adapter '{adapter_id}' was not trained on base model '{model_id}'. "
|
|
|
|
f"If you encounter issues, use --model-id '{adapter_config.base_model_name_or_path}' instead."
|
|
|
|
)
|
|
|
|
else:
|
|
|
|
# TODO(travis): revisit this when we support clasification heads which will not use CausalLM
|
|
|
|
raise ValueError(
|
|
|
|
f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. "
|
|
|
|
f"Architectures differ: {model_config.architectures} != {expected_config.architectures}. "
|
|
|
|
f"Use --model-id '{adapter_config.base_model_name_or_path}' instead."
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
@lru_cache(maxsize=128)
|
|
|
|
def load_module_map(
|
|
|
|
model_id: str,
|
2024-09-02 11:09:06 -06:00
|
|
|
revision: str,
|
2024-06-25 12:46:27 -06:00
|
|
|
adapter_id: str,
|
2024-07-24 13:32:14 -06:00
|
|
|
adapter_path: Optional[str],
|
2024-06-25 12:46:27 -06:00
|
|
|
weight_names: Tuple[str],
|
|
|
|
trust_remote_code: bool = False,
|
|
|
|
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
|
2024-07-24 13:32:14 -06:00
|
|
|
adapter_config = LoraConfig.load(adapter_path or adapter_id, None)
|
|
|
|
|
|
|
|
if not adapter_path and adapter_config.base_model_name_or_path != model_id:
|
2024-06-25 12:46:27 -06:00
|
|
|
check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)
|
|
|
|
|
2024-07-24 13:32:14 -06:00
|
|
|
adapter_filenames = (
|
2024-09-24 03:19:24 -06:00
|
|
|
hub._weight_files_from_dir(adapter_path, extension=".safetensors")
|
2024-07-24 13:32:14 -06:00
|
|
|
if adapter_path
|
2024-09-24 03:19:24 -06:00
|
|
|
else hub._cached_weight_files(
|
2024-07-24 13:32:14 -06:00
|
|
|
adapter_id, revision=revision, extension=".safetensors"
|
|
|
|
)
|
2024-06-25 12:46:27 -06:00
|
|
|
)
|
|
|
|
|
2024-09-02 11:09:06 -06:00
|
|
|
# throw an error if no adapter weights are found
|
|
|
|
if not adapter_filenames:
|
|
|
|
raise FileNotFoundError(
|
|
|
|
f"No adapter weights found for adapter '{adapter_id}' and revision '{revision}'."
|
|
|
|
)
|
|
|
|
|
2024-06-25 12:46:27 -06:00
|
|
|
try:
|
|
|
|
adapter_tokenizer = AutoTokenizer.from_pretrained(
|
|
|
|
adapter_config.config_path,
|
|
|
|
trust_remote_code=trust_remote_code,
|
|
|
|
)
|
|
|
|
except Exception:
|
|
|
|
# Adapter does not have a tokenizer, so fallback to base model tokenizer
|
|
|
|
adapter_tokenizer = None
|
|
|
|
|
|
|
|
# load adapter weights from all shards (should have relatively small memory footprint)
|
|
|
|
adapter_weights = {}
|
|
|
|
for filename in adapter_filenames:
|
|
|
|
adapter_weights.update(load_file(filename))
|
|
|
|
|
|
|
|
# map the model weights to the relevant adapter weights (LoRA A and B matrices)
|
|
|
|
module_map, adapter_weight_names = adapter_config.map_weights_for_model(
|
|
|
|
adapter_weights, weight_names
|
|
|
|
)
|
|
|
|
return module_map, adapter_config, adapter_weight_names, adapter_tokenizer
|
2024-07-24 13:32:14 -06:00
|
|
|
|
|
|
|
|
|
|
|
def get_attn_weights(i, layer):
|
|
|
|
qkv = layer.self_attn.query_key_value
|
|
|
|
weights = {}
|
|
|
|
|
|
|
|
for k in ["q", "k", "v"]:
|
|
|
|
key = (i, f"{k}_proj")
|
|
|
|
value = (f"model.layers.{i}.self_attn.{k}_proj", qkv)
|
|
|
|
weights[key] = value
|
|
|
|
|
2024-09-02 11:09:06 -06:00
|
|
|
# also add the qkv_proj weight for the adapter
|
|
|
|
weights[(i, "qkv_proj")] = (
|
|
|
|
f"model.layers.{i}.self_attn.qkv_proj",
|
|
|
|
qkv,
|
|
|
|
)
|
|
|
|
|
2024-07-24 13:32:14 -06:00
|
|
|
weights[(i, "o_proj")] = (
|
|
|
|
f"model.layers.{i}.self_attn.o_proj",
|
|
|
|
layer.self_attn.o_proj,
|
|
|
|
)
|
|
|
|
|
|
|
|
return weights
|
|
|
|
|
|
|
|
|
|
|
|
def get_mlp_weights(i, layer):
|
|
|
|
weights = {}
|
|
|
|
if hasattr(layer, "mlp"):
|
|
|
|
mlp = layer.mlp
|
|
|
|
if hasattr(mlp, "gate_up_proj"):
|
|
|
|
# handle combined gate_up_proj (e.g., for some LLaMA variants)
|
|
|
|
weights.update(
|
|
|
|
{
|
|
|
|
(i, "gate_proj"): (
|
|
|
|
f"model.layers.{i}.mlp.gate_proj",
|
|
|
|
mlp.gate_up_proj,
|
|
|
|
),
|
|
|
|
(i, "up_proj"): (f"model.layers.{i}.mlp.up_proj", mlp.gate_up_proj),
|
|
|
|
}
|
|
|
|
)
|
|
|
|
else:
|
|
|
|
# handle separate gate_proj, up_proj, and down_proj (e.g., for Gemma)
|
|
|
|
if hasattr(mlp, "gate_proj"):
|
|
|
|
weights[(i, "gate_proj")] = (
|
|
|
|
f"model.layers.{i}.mlp.gate_proj",
|
|
|
|
mlp.gate_proj,
|
|
|
|
)
|
|
|
|
if hasattr(mlp, "up_proj"):
|
|
|
|
weights[(i, "up_proj")] = (f"model.layers.{i}.mlp.up_proj", mlp.up_proj)
|
|
|
|
|
|
|
|
if hasattr(mlp, "down_proj"):
|
|
|
|
weights[(i, "down_proj")] = (
|
|
|
|
f"model.layers.{i}.mlp.down_proj",
|
|
|
|
mlp.down_proj,
|
|
|
|
)
|
|
|
|
|
|
|
|
return weights
|
|
|
|
|
|
|
|
|
|
|
|
# build_layer_weight_lookup creates a mapping of model layers to their corresponding
|
|
|
|
# weight tensors and paths. It builds a dictionary that maps layer identifiers to tuples
|
|
|
|
# containing the weight tensor path and the actual layer object. This mapping is needed
|
|
|
|
# for the lora adapter to know which weights to update when applying the adapter.
|
|
|
|
def build_layer_weight_lookup(model):
|
|
|
|
if hasattr(model, "language_model"):
|
|
|
|
m = model.language_model.model
|
|
|
|
elif hasattr(model, "text_model"):
|
|
|
|
m = model.text_model.model
|
|
|
|
else:
|
|
|
|
m = model.model
|
|
|
|
|
|
|
|
layer_weights = {}
|
|
|
|
|
|
|
|
for i, layer in enumerate(m.layers):
|
|
|
|
attn_weights = get_attn_weights(i, layer)
|
|
|
|
mlp_weights = get_mlp_weights(i, layer)
|
|
|
|
|
|
|
|
layer_weights.update(attn_weights)
|
|
|
|
layer_weights.update(mlp_weights)
|
|
|
|
|
|
|
|
lm_head = None
|
|
|
|
if hasattr(m, "lm_head"):
|
|
|
|
lm_head = m.lm_head
|
|
|
|
elif hasattr(model, "lm_head"):
|
|
|
|
lm_head = model.lm_head
|
|
|
|
|
|
|
|
if lm_head:
|
|
|
|
layer_weights[(0, "lm_head")] = ("lm_head", lm_head)
|
|
|
|
|
|
|
|
return layer_weights
|