fix: refactor adapter weight loading and mapping (#2193)
* fix: refactor adapter weight loading and mapping
* feat: enable lora load from directory
* fix: adjust launcher for local lora adapters
* feat: improve weight loading and add tests
* fix: improve logging and rebase syntax issue
* fix: improve adapter merge comments and remove unused conditional
* fix: improve get_model_with_lora_adapters naming
* fix: comment typo
parent 93d2b9fe9c
commit 5d85a958c9
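A quick illustration of the adapter specification formats this change accepts (a sketch based on the `parse_lora_adapters` helper and the launcher check further down in this diff; the adapter names and paths here are made up for illustration):

# Entries in LORA_ADAPTERS are comma-separated; a plain value is treated as a Hub
# adapter id to download, while "name=/some/dir" points at a local directory and
# skips the download step in the launcher.
from text_generation_server.utils.adapter import parse_lora_adapters

adapters = parse_lora_adapters("predibase/dbpedia, my_local=/data/adapters/my_local")
# -> [AdapterInfo(id="predibase/dbpedia", path=None),
#     AdapterInfo(id="my_local", path="/data/adapters/my_local")]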
@@ -1607,6 +1607,10 @@ fn main() -> Result<(), LauncherError> {
    // Download and convert lora adapters if any
    if let Some(lora_adapters) = &args.lora_adapters {
        for adapter in lora_adapters.split(',') {
            // skip download if a path is provided
            if adapter.contains('=') {
                continue;
            }
            download_convert_model(
                adapter,
                None,
@@ -0,0 +1,187 @@
import pytest
from unittest.mock import Mock
from text_generation_server.utils.adapter import get_attn_weights, get_mlp_weights


def test_get_attn_weights():
    # create a mock layer
    mock_layer = Mock()
    mock_layer.self_attn.query_key_value = Mock()
    mock_layer.self_attn.o_proj = Mock()

    # call the function
    result = get_attn_weights(2, mock_layer)

    # assert the result
    expected = {
        (2, "q_proj"): (
            "model.layers.2.self_attn.q_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "k_proj"): (
            "model.layers.2.self_attn.k_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "v_proj"): (
            "model.layers.2.self_attn.v_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
    }
    assert result == expected


def test_get_mlp_weights_with_gate_up_proj():
    # create a mock layer with gate_up_proj
    mock_layer = Mock()
    mock_layer.mlp.gate_up_proj = Mock()
    mock_layer.mlp.down_proj = Mock()

    # call the function
    result = get_mlp_weights(3, mock_layer)

    # assert the result
    expected = {
        (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
        (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
        (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
    }
    assert result == expected


def test_get_mlp_weights_without_gate_up_proj():
    # create a mock layer without gate_up_proj
    mock_layer = Mock()
    mock_layer.mlp = Mock(spec=[])

    # call the function
    result = get_mlp_weights(1, mock_layer)

    # assert the result
    assert result == {}


@pytest.mark.parametrize("layer_index", [0, 1, 5])
def test_get_attn_weights_different_layers(layer_index):
    mock_layer = Mock()
    mock_layer.self_attn.query_key_value = Mock()
    mock_layer.self_attn.o_proj = Mock()

    result = get_attn_weights(layer_index, mock_layer)

    for k in ["q", "k", "v"]:
        assert (layer_index, f"{k}_proj") in result
        assert (
            result[(layer_index, f"{k}_proj")][0]
            == f"model.layers.{layer_index}.self_attn.{k}_proj"
        )

    assert (layer_index, "o_proj") in result
    assert (
        result[(layer_index, "o_proj")][0]
        == f"model.layers.{layer_index}.self_attn.o_proj"
    )


@pytest.mark.parametrize("layer_index", [0, 1, 5])
def test_get_mlp_weights_different_layers(layer_index):
    mock_layer = Mock()
    mock_layer.mlp.gate_up_proj = Mock()
    mock_layer.mlp.down_proj = Mock()

    result = get_mlp_weights(layer_index, mock_layer)

    for k in ["gate", "up", "down"]:
        assert (layer_index, f"{k}_proj") in result
        assert (
            result[(layer_index, f"{k}_proj")][0]
            == f"model.layers.{layer_index}.mlp.{k}_proj"
        )


def test_get_attn_weights_llama_compatibility():
    mock_layer = Mock()
    mock_layer.self_attn.query_key_value = Mock()
    mock_layer.self_attn.o_proj = Mock()

    result = get_attn_weights(2, mock_layer)

    expected = {
        (2, "q_proj"): (
            "model.layers.2.self_attn.q_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "k_proj"): (
            "model.layers.2.self_attn.k_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "v_proj"): (
            "model.layers.2.self_attn.v_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
    }
    assert result == expected


def test_get_mlp_weights_llama_compatibility():
    mock_layer = Mock()
    mock_layer.mlp.gate_up_proj = Mock()
    mock_layer.mlp.down_proj = Mock()

    result = get_mlp_weights(3, mock_layer)

    expected = {
        (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
        (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
        (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
    }
    assert result == expected


def test_get_attn_weights_gemma_compatibility():
    mock_layer = Mock()
    mock_layer.self_attn.query_key_value = Mock()
    mock_layer.self_attn.o_proj = Mock()

    result = get_attn_weights(2, mock_layer)

    expected = {
        (2, "q_proj"): (
            "model.layers.2.self_attn.q_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "k_proj"): (
            "model.layers.2.self_attn.k_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "v_proj"): (
            "model.layers.2.self_attn.v_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
    }
    assert result == expected


def test_get_mlp_weights_gemma_compatibility():
    mock_layer = Mock()
    mock_layer.mlp.gate_proj = Mock()
    mock_layer.mlp.up_proj = Mock()
    mock_layer.mlp.down_proj = Mock()

    # ensure that the mock_layer.mlp.gate_up_proj attribute does not exist.
    # This is necessary because the use of `Mock` automatically creates any
    # attributes that are accessed, even if they don't exist in the actual
    # implementation. If `gate_up_proj` were created, `get_mlp_weights` might
    # follow the wrong execution path and return an incorrect result.
    del mock_layer.mlp.gate_up_proj

    result = get_mlp_weights(3, mock_layer)

    expected = {
        (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj),
        (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj),
        (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
    }
    assert result == expected
@@ -4,7 +4,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict, Set, Tuple

import torch

@@ -31,14 +31,3 @@ class AdapterConfig(ABC):
        weight_names: Tuple[str],
    ) -> Tuple[ModuleMap, Set[str]]:
        pass

    @abstractmethod
    def load_batched_adapter_weights(
        self,
        model: "Model",
        module_map: ModuleMap,
        layer_type: str,
        unused_weight_names: Set[str],
        dynamic: bool,
    ) -> Optional[AdapterWeights]:
        pass
@@ -102,22 +102,6 @@ class LoraConfig(AdapterConfig):
            adapter_weight_names.add(lora_b_name)
        return module_map, adapter_weight_names

    def load_batched_adapter_weights(
        self,
        model: "Model",
        module_map: Dict[str, Dict],
        layer_type: str,
        unused_weight_names: Set[str],
        dynamic: bool,
    ) -> Optional[AdapterWeights]:
        return LoraWeights.load(
            self,
            model,
            module_map,
            layer_type,
            unused_weight_names,
        )

    @classmethod
    def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
        hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
@@ -192,22 +176,38 @@ class LoraWeights(AdapterWeights):
    def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
        return [BatchLoraWeights]

    # prepare pre-loaded lora weights for use in the model.
    #
    # this method processes and organizes lora weights for a specific layer type across all layers:
    # - uses `config` (LoraConfig) to apply lora-specific settings like scaling factor.
    # - retrieves weights from `module_map` based on the `layer_type`.
    # - processes `nlayers` number of layers.
    # - converts weights to the specified `dtype`.
    # - shards weights across `world_size` number of processes using the `process_group`.
    # - maps weights to specific layers using `target_to_layer`.
    # - tracks `unused_weight_names` to identify any unused weights.
    #
    # the method handles weight transposition, scaling, and padding to ensure compatibility
    # with SGMV or BGMV operations.
    @classmethod
    def load(
    def prepare_weights(
        cls,
        config: LoraConfig,
        model: "Model",
        module_map: Dict[str, Dict],
        layer_type: str,
        unused_weight_names: Set[str],
        nlayers: int,
        dtype: torch.dtype,
        world_size: int,
        process_group: ProcessGroup,
        target_to_layer: Dict[str, Tuple[str, torch.Tensor]],
    ) -> Optional[AdapterWeights]:
        nlayers = model.get_num_layers_for_type(layer_type)
        lora_a_list = [None] * nlayers
        lora_b_list = [None] * nlayers

        for layer_id in range(nlayers):
            key = (layer_id, layer_type)
            weight_name, layer = model.target_to_layer[key]
            weight_name, layer = target_to_layer[key]
            base_weight = layer.base_layer.linear.weight
            base_device = base_weight.device

@@ -216,10 +216,10 @@ class LoraWeights(AdapterWeights):
                return None

            lora_a, lora_a_name = module_map[weight_name]["lora_A"]
            lora_a = lora_a.to(base_device, model.dtype)
            lora_a = lora_a.to(base_device, dtype)

            lora_b, lora_b_name = module_map[weight_name]["lora_B"]
            lora_b = lora_b.to(base_device, model.dtype)
            lora_b = lora_b.to(base_device, dtype)

            scale = get_scaling_factor(
                config.lora_alpha,
@@ -236,12 +236,8 @@ class LoraWeights(AdapterWeights):
            lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale

        # pad lora ranks to be compatible with sgmv
        lora_a_list = [
            pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list
        ]
        lora_b_list = [
            pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list
        ]
        lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]
        lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]

        if lora_a_list:
            # update rank if it was padded
@@ -252,8 +248,8 @@ class LoraWeights(AdapterWeights):
            *shard_lora_weights(
                weights_a=lora_a_list,
                weights_b=lora_b_list,
                split_dim=0 if model.is_row_parallel(layer_type) else 1,
                process_group=model.process_group,
                split_dim=0 if layer_type in {"o_proj", "down_proj", "lm_head"} else 1,
                process_group=process_group,
            ),
            config,
        )
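For context on the new `prepare_weights` signature: its only call site in this refactor is `get_model_with_lora_adapters` (shown later in this diff). A condensed sketch of that call, with the model-derived arguments spelled out; variable names mirror the later hunk:

adapter_weights = LoraWeights.prepare_weights(
    config=adapter_config,            # LoraConfig read from the adapter repo or directory
    module_map=module_map,            # adapter tensors grouped per target weight name
    layer_type="q_proj",              # one entry from the adapter_layers list
    unused_weight_names=unused_weight_names,
    nlayers=len(model.model.model.layers),
    dtype=model.dtype,
    world_size=model.world_size,
    process_group=model.process_group,
    target_to_layer=build_layer_weight_lookup(model.model),
)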
@@ -293,10 +289,6 @@ class BatchLoraWeights(BatchAdapterWeights):
            for rank_data in self.rank_data.values()
        )

    @classmethod
    def key(cls) -> str:
        return "lora"

    @classmethod
    def load(
        self,
@@ -42,10 +42,6 @@ class BatchAdapterWeights(ABC):
    def has_adapter(self, adapter_index: int) -> bool:
        pass

    @abstractclassmethod
    def key(cls) -> str:
        pass

    @abstractclassmethod
    def load(
        cls,
@@ -71,13 +67,6 @@ class LayerAdapterWeights:
            return
        del self.adapter_weights[adapter_idx]

    @property
    def max_speculative_tokens(self) -> int:
        return max(
            adapter_weights.speculative_tokens
            for adapter_weights in self.adapter_weights.values()
        )

    def is_empty(self) -> bool:
        return len(self.adapter_weights) == 0

@@ -101,7 +90,7 @@ class LayerAdapterWeights:
                adapter_weights, meta, prefill, prefill_head_indices
            )
            if batched_weights is not None:
                batch_data[batch_type.key()] = batched_weights
                batch_data = batched_weights
        return batch_data

@@ -133,8 +122,7 @@ class AdapterBatchData:
    def ranks(self) -> Set[int]:
        # TODO(travis): refactor to be less coupled to lora implementation
        ranks = set()
        for layer_data in self.data.values():
            lora_data = layer_data.get("lora")
        for lora_data in self.data.values():
            if lora_data is None:
                continue

@@ -4,9 +4,10 @@ import typer

from pathlib import Path
from loguru import logger
from typing import Optional
from typing import Optional, List, Dict
from enum import Enum
from huggingface_hub import hf_hub_download
from text_generation_server.utils.adapter import parse_lora_adapters

from text_generation_server.utils.log import log_master

@@ -80,28 +81,19 @@ def serve(
    if otlp_endpoint is not None:
        setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)

    lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)

    # split on comma and strip whitespace
    lora_adapter_ids = (
        [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
    )

    if len(lora_adapter_ids) > 0:
        log_master(
            logger.warning,
            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected.",
        )
    lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS"))

    # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled
    # and warn the user
    if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None:
        log_master(
            logger.warning,
            f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs.",
        )
        global CUDA_GRAPHS
        CUDA_GRAPHS = None
    if lora_adapters:
        logger.warning("LoRA adapters enabled (experimental feature).")

        if "CUDA_GRAPHS" in os.environ:
            logger.warning(
                "LoRA adapters incompatible with CUDA Graphs. Disabling CUDA Graphs."
            )
            global CUDA_GRAPHS
            CUDA_GRAPHS = None

    # Downgrade enum into str for easier management later on
    quantize = None if quantize is None else quantize.value
@@ -117,7 +109,7 @@ def serve(
    )
    server.serve(
        model_id,
        lora_adapter_ids,
        lora_adapters,
        revision,
        sharded,
        quantize,
@@ -43,10 +43,7 @@ class LoraLinear(nn.Module):
    ) -> torch.Tensor:
        if adapter_data is None:
            return result
        data = adapter_data.data.get(layer_type)
        data: Optional["BatchLoraWeights"] = (
            data.get("lora") if data is not None else None
        )
        data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)

        if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
            # In tensor-parallel configurations, each GPU processes a specific segment of the output.
@@ -6,7 +6,7 @@ from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional, List
from typing import Optional, List, Dict
from pathlib import Path

from text_generation_server.utils.speculate import get_speculate, set_speculate
@@ -33,6 +33,16 @@ from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)


from text_generation_server.utils.adapter import (
    AdapterParameters,
    build_layer_weight_lookup,
    load_and_merge_adapters,
    AdapterInfo,
)
from text_generation_server.adapters.lora import LoraWeights


from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_master

@@ -50,7 +60,7 @@ __all__ = [
    "Model",
    "CausalLM",
    "Seq2SeqLM",
    "get_model",
    "get_model_with_lora_adapters",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
@@ -1115,3 +1125,116 @@ def get_model(
        )

    raise ValueError(f"Unsupported model type {model_type}")


# get_model_with_lora_adapters wraps the internal get_model function and adds support for loading adapters
# this provides a post model loading hook to load adapters into the model after the model has been loaded
def get_model_with_lora_adapters(
    model_id: str,
    lora_adapters: Optional[List[AdapterInfo]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
    adapter_to_index: Dict[str, int],
):
    lora_adapter_ids = [adapter.id for adapter in lora_adapters]
    model = get_model(
        model_id,
        lora_adapter_ids,
        revision,
        sharded,
        quantize,
        speculate,
        dtype,
        trust_remote_code,
        max_input_tokens,
    )

    if len(lora_adapters) > 0:
        target_to_layer = build_layer_weight_lookup(model.model)

        for index, adapter in enumerate(lora_adapters):
            # The AdapterParameters object allows for merging multiple adapters into a single adapter.
            # At the moment, we only support loading a single adapter into the model, but we keep the
            # AdapterParameters object for easier extension in the future.
            adapter_parameters = AdapterParameters(
                adapter_info=[adapter],
                # when merging multiple adapters we can weight them differently
                # if this is not set, all adapters will be weighted equally
                # see: text_generation_server.utils.merges.strategies for impl
                weights=None,
                merge_strategy=0,
                density=1.0,
                majority_sign_method=0,
            )

            adapter_index = index + 1
            adapter_to_index[adapter.id] = adapter_index

            logger.info(
                f"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}"
            )
            weight_names = tuple([v[0] for v in target_to_layer.values()])
            (
                module_map,
                adapter_config,
                adapter_weight_names,
                adapter_tokenizer,
            ) = load_and_merge_adapters(
                model.model_id,
                adapter_parameters,
                adapter_index,
                weight_names,
                False,
            )

            unused_weight_names = adapter_weight_names.copy()

            adapter_layers = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
            ]

            for layer_name in adapter_layers:
                nlayers = (
                    1 if layer_name == "lm_head" else len(model.model.model.layers)
                )
                adapter_weights = LoraWeights.prepare_weights(
                    config=adapter_config,
                    module_map=module_map,
                    layer_type=layer_name,
                    unused_weight_names=unused_weight_names,
                    nlayers=nlayers,
                    dtype=model.dtype,
                    world_size=model.world_size,
                    process_group=model.process_group,
                    target_to_layer=target_to_layer,
                )

                if adapter_weights is None:
                    continue

                model.layer_to_adapter_weights[layer_name].add_adapter(
                    adapter_index, adapter_weights
                )

            if len(unused_weight_names) > 0:
                logger.warning(
                    f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
                )

            if adapter_tokenizer is not None:
                model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)

            model.loaded_adapters.add(adapter_index)

    return model
@@ -1,7 +1,6 @@
import math
import os
import time
import itertools
import torch
import torch.distributed

@@ -1700,72 +1699,3 @@ class FlashCausalLM(Model):
        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)

    @property
    def supports_adapter_loading(self) -> bool:
        return True

    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
        layer_weights = {}

        prefix = "model.layers"

        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
        # that have a language_model inside of the larger model.
        if hasattr(self.model, "language_model"):
            _model = self.model.language_model
        elif hasattr(self.model, "text_model"):
            _model = self.model.text_model
        else:
            _model = self.model

        for i, layer in enumerate(_model.model.layers):
            layer_weights[(i, "q_proj")] = (
                f"{prefix}.{i}.self_attn.q_proj",
                layer.self_attn.query_key_value,
            )
            layer_weights[(i, "k_proj")] = (
                f"{prefix}.{i}.self_attn.k_proj",
                layer.self_attn.query_key_value,
            )
            layer_weights[(i, "v_proj")] = (
                f"{prefix}.{i}.self_attn.v_proj",
                layer.self_attn.query_key_value,
            )
            layer_weights[(i, "o_proj")] = (
                f"{prefix}.{i}.self_attn.o_proj",
                layer.self_attn.o_proj,
            )

            # TODO: this is a hack to avoid the gate_proj for
            # FlashStarcoder2 that doesnt have these layers
            if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"):
                layer_weights[(i, "gate_proj")] = (
                    f"{prefix}.{i}.mlp.gate_proj",
                    layer.mlp.gate_up_proj,
                )
                layer_weights[(i, "up_proj")] = (
                    f"{prefix}.{i}.mlp.up_proj",
                    layer.mlp.gate_up_proj,
                )
                layer_weights[(i, "down_proj")] = (
                    f"{prefix}.{i}.mlp.down_proj",
                    layer.mlp.down_proj,
                )

        layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
        return layer_weights

    @property
    def adapter_layers(self) -> List[str]:
        return ADAPTER_LAYERS

    @property
    def default_traced_adapter_layers(self) -> List[str]:
        return ["q_proj", "v_proj"]

    def get_num_layers_for_type(self, layer_type: str) -> int:
        return 1 if layer_type == "lm_head" else len(self.model.model.layers)

    def is_row_parallel(self, layer_type: str) -> bool:
        return layer_type in ROW_PARALLEL
@@ -1,85 +0,0 @@
import torch
from typing import Optional, Tuple, Dict, List

from text_generation_server.models import FlashCausalLM


ADAPTER_LAYERS = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}


class FlashMistral(FlashCausalLM):
    @property
    def supports_adapter_loading(self) -> bool:
        return True

    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
        layer_weights = {}

        prefix = "model.layers"

        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
        # that have a language_model inside of the larger model.
        if hasattr(self.model, "text_model"):
            _model = self.model.text_model
        else:
            _model = self.model

        for i, layer in enumerate(_model.model.layers):
            layer_weights[(i, "q_proj")] = (
                f"{prefix}.{i}.self_attn.q_proj",
                layer.self_attn.query_key_value,
            )
            layer_weights[(i, "k_proj")] = (
                f"{prefix}.{i}.self_attn.k_proj",
                layer.self_attn.query_key_value,
            )
            layer_weights[(i, "v_proj")] = (
                f"{prefix}.{i}.self_attn.v_proj",
                layer.self_attn.query_key_value,
            )
            layer_weights[(i, "o_proj")] = (
                f"{prefix}.{i}.self_attn.o_proj",
                layer.self_attn.o_proj,
            )

            # TODO: this is a hack to avoid the gate_proj for
            # FlashStarcoder2 that doesnt have these layers
            if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"):
                layer_weights[(i, "gate_proj")] = (
                    f"{prefix}.{i}.mlp.gate_proj",
                    layer.mlp.gate_up_proj,
                )
                layer_weights[(i, "up_proj")] = (
                    f"{prefix}.{i}.mlp.up_proj",
                    layer.mlp.gate_up_proj,
                )
                layer_weights[(i, "down_proj")] = (
                    f"{prefix}.{i}.mlp.down_proj",
                    layer.mlp.down_proj,
                )

        layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
        return layer_weights

    @property
    def adapter_layers(self) -> List[str]:
        return ADAPTER_LAYERS

    @property
    def default_traced_adapter_layers(self) -> List[str]:
        return ["q_proj", "v_proj"]

    def get_num_layers_for_type(self, layer_type: str) -> int:
        return 1 if layer_type == "lm_head" else len(self.model.model.layers)

    def is_row_parallel(self, layer_type: str) -> bool:
        return layer_type in ROW_PARALLEL
@@ -4,20 +4,12 @@ import torch
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict
from collections import defaultdict
from transformers import PreTrainedTokenizerBase, PretrainedConfig
from transformers import PreTrainedTokenizerBase

from text_generation_server.models.types import Batch, Generation
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.pb.generate_pb2 import InfoResponse
from text_generation_server.adapters.weights import LayerAdapterWeights
from text_generation_server.utils.adapter import (
    load_and_merge_adapters,
    AdapterParameters,
    AdapterSource,
)
from text_generation_server.utils.log import log_master
from loguru import logger

BASE_MODEL_ADAPTER_ID = "__base_model__"

@@ -61,7 +53,6 @@ class Model(ABC):
        self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
            LayerAdapterWeights
        )
        self.target_to_layer = None
        self.loaded_adapters = set()
        self.static_adapter_id = adapter_id

@@ -142,140 +133,3 @@ class Model(ABC):
            raise RuntimeError(
                f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
            )

    @property
    def supports_adapter_loading(self) -> bool:
        return False

    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
        return {}

    @property
    def adapter_layers(self) -> List[str]:
        return []

    @property
    def default_traced_adapter_layers(self) -> List[str]:
        return []

    def get_num_layers_for_type(self, layer_type: str) -> int:
        return 0

    def is_row_parallel(self, layer_type: str) -> bool:
        return False

    @property
    def max_speculative_tokens(self) -> int:
        return max(
            [
                weights.max_speculative_tokens
                for weights in self.layer_to_adapter_weights.values()
            ],
            default=0,
        )

    def load_adapter(
        self,
        adapter_parameters: AdapterParameters,
        adapter_source: AdapterSource,
        adapter_index: int,
        api_token: str,
        dynamic: bool = True,
    ):
        """Loads adapter weights from disk / host memory on the GPU.

        adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded
        into model. Otherwise, the adapter weights are applied during the forward
        pass and stored separately from the base model parameters.
        """
        if self.target_to_layer is None:
            self.target_to_layer = self.adapter_target_to_layer()
        if adapter_index in self.loaded_adapters:
            # Adapter already loaded
            return

        if not self.supports_adapter_loading:
            raise ValueError("This model does not support adapter loading.")

        if dynamic and not self.dynamic_adapter_loading_enabled:
            raise ValueError(
                f"This model was initialized with the adapter {self.static_adapter_id} "
                f"and therefore does not support dynamic adapter loading. "
                f"Please initialize a new model instance from the base model in "
                f"order to use the dynamic adapter loading feature."
            )

        log_master(
            logger.info,
            f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}",
        )
        weight_names = tuple([v[0] for v in self.target_to_layer.values()])
        (
            module_map,
            adapter_config,
            adapter_weight_names,
            adapter_tokenizer,
        ) = load_and_merge_adapters(
            self.model_id,
            adapter_parameters,
            adapter_source,
            adapter_index,
            weight_names,
            api_token,
            False,
        )

        unused_weight_names = adapter_weight_names.copy()
        for layer_name in self.adapter_layers:
            adapter_weights = adapter_config.load_batched_adapter_weights(
                self,
                module_map,
                layer_name,
                unused_weight_names,
                dynamic,
            )

            if adapter_weights is None:
                continue

            layer_weights = self.layer_to_adapter_weights[layer_name]
            layer_weights.add_adapter(adapter_index, adapter_weights)

        if len(unused_weight_names) > 0:
            log_master(
                logger.warning,
                f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}",
            )

        if adapter_tokenizer is not None:
            self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)

        self.loaded_adapters.add(adapter_index)

    def offload_adapter(
        self,
        adapter_parameters: AdapterParameters,
        adapter_source: AdapterSource,
        adapter_index: int,
    ):
        """Offloads the adapter weights from GPU to CPU or disk."""
        if adapter_index not in self.loaded_adapters:
            # Adapter already offloaded
            return

        if not self.supports_adapter_loading:
            raise ValueError("This model does not support adapter loading.")

        if not self.dynamic_adapter_loading_enabled:
            raise ValueError(
                f"This model was initialized with the adapter {self.static_adapter_id} "
                f"and therefore does not support dynamic adapter loading. "
                f"Please initialize a new model instance from the base model in "
                f"order to use the dynamic adapter loading feature."
            )

        for layer_name in self.adapter_layers:
            if layer_name in self.layer_to_adapter_weights:
                self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index)

        self.loaded_adapters.remove(adapter_index)
@@ -9,11 +9,12 @@ from loguru import logger

from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Dict

from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model
from text_generation_server.models import Model, get_model_with_lora_adapters
from text_generation_server.utils.adapter import AdapterInfo

try:
    from text_generation_server.models.pali_gemma import PaliGemmaBatch
@@ -30,9 +31,6 @@ except (ImportError, NotImplementedError):
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.globals import set_model_id, set_adapter_to_index
from text_generation_server.utils.adapter import (
    AdapterParameters,
)


class SignalHandler:
@@ -195,7 +193,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

def serve(
    model_id: str,
    lora_adapter_ids: Optional[List[str]],
    lora_adapters: Optional[List[AdapterInfo]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
@@ -207,7 +205,7 @@ def serve(
):
    async def serve_inner(
        model_id: str,
        lora_adapter_ids: Optional[List[str]],
        lora_adapters: Optional[List[AdapterInfo]],
        revision: Optional[str],
        sharded: bool = False,
        quantize: Optional[str] = None,
@@ -228,9 +226,9 @@ def serve(
        server_urls = [local_url]

        try:
            model = get_model(
            model = get_model_with_lora_adapters(
                model_id,
                lora_adapter_ids,
                lora_adapters,
                revision,
                sharded,
                quantize,
@@ -238,29 +236,9 @@ def serve(
                dtype,
                trust_remote_code,
                max_input_tokens,
                adapter_to_index,
            )

            if len(lora_adapter_ids) > 0:
                for index, adapter_id in enumerate(lora_adapter_ids):
                    # TODO: improve non merged adapter loading and long term
                    # improve adapter loading as a whole
                    adapter_parameters = AdapterParameters(
                        adapter_ids=[adapter_id],
                        weights=None,  # will be set to 1
                        merge_strategy=0,
                        density=1.0,
                        majority_sign_method=0,
                    )
                    adapter_index = index + 1
                    adapter_to_index[adapter_id] = adapter_index
                    model.load_adapter(
                        adapter_parameters,
                        None,  # adapter_source
                        adapter_index,
                        None,  # api_token
                        False,  # dynamic
                    )

        except Exception:
            logger.exception("Error when initializing model")
            raise
@@ -297,7 +275,7 @@ def serve(
    asyncio.run(
        serve_inner(
            model_id,
            lora_adapter_ids,
            lora_adapters,
            revision,
            sharded,
            quantize,
@@ -5,12 +5,11 @@
import warnings
from dataclasses import dataclass
from functools import lru_cache
from typing import TYPE_CHECKING, Set, Tuple
from typing import TYPE_CHECKING, Set, Tuple, Optional, List

from safetensors.torch import load_file
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer

from text_generation_server.pb import generate_pb2
from text_generation_server.utils.merges.strategies import merge_adapters

from text_generation_server.utils import hub
@@ -24,9 +23,15 @@ if TYPE_CHECKING:
BASE_MODEL_ADAPTER_ID = "__base_model__"


@dataclass
class AdapterInfo:
    id: str
    path: Optional[str]


@dataclass
class AdapterParameters:
    adapter_ids: Tuple[str]
    adapter_info: Tuple[AdapterInfo]
    weights: Tuple[float]
    merge_strategy: NotImplemented
    density: float
@@ -40,37 +45,47 @@ class AdapterSource:
    revision: str


def parse_lora_adapters(lora_adapters: Optional[str]) -> List[AdapterInfo]:
    if not lora_adapters:
        return []

    adapter_list = []
    for adapter in lora_adapters.split(","):
        parts = adapter.strip().split("=")
        if len(parts) == 1:
            adapter_list.append(AdapterInfo(id=parts[0], path=None))
        elif len(parts) == 2:
            adapter_list.append(AdapterInfo(id=parts[0], path=parts[1]))
        else:
            raise ValueError(f"Invalid LoRA adapter format: {adapter}")
    return adapter_list


def load_and_merge_adapters(
    model_id: str,
    adapter_parameters: AdapterParameters,
    adapter_source: str,
    adapter_index: int,
    weight_names: Tuple[str],
    api_token: str,
    trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
    if len(adapter_parameters.adapter_ids) == 1:

    if len(adapter_parameters.adapter_info) == 1:
        adapter_info = next(iter(adapter_parameters.adapter_info))
        return load_module_map(
            model_id,
            adapter_parameters.adapter_ids[0],
            adapter_source,
            adapter_info.id,
            adapter_info.path,
            weight_names,
            api_token,
            trust_remote_code,
        )

    adapter_params = AdapterParametersContainer(
        adapter_parameters, adapter_source, adapter_index
    )
    return _load_and_merge(
        model_id, adapter_params, weight_names, api_token, trust_remote_code
    )
    adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index)
    return _load_and_merge(model_id, adapter_params, weight_names, trust_remote_code)


@dataclass
class AdapterParametersContainer:
    adapter_parameters: AdapterParameters
    adapter_source: str
    adapter_index: int

    def __hash__(self) -> int:
@@ -82,7 +97,6 @@ def _load_and_merge(
    model_id: str,
    adapter_params: AdapterParametersContainer,
    weight_names: Tuple[str],
    api_token: str,
    trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
    params = adapter_params.adapter_parameters
@@ -90,17 +104,16 @@ def _load_and_merge(
    adapters_to_merge = []
    merged_weight_names = set()
    tokenizer = None
    for adapter_id in params.adapter_ids:
        if adapter_id == BASE_MODEL_ADAPTER_ID:
    for adapter in params.adapter_info:
        if adapter.id == BASE_MODEL_ADAPTER_ID:
            raise ValueError("Base model adapter cannot be merged.")

        module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
            load_module_map(
                model_id,
                adapter_id,
                adapter_params.adapter_source,
                adapter.id,
                adapter.path,
                weight_names,
                api_token,
                trust_remote_code,
            )
        )
@@ -159,25 +172,28 @@ def check_architectures(
def load_module_map(
    model_id: str,
    adapter_id: str,
    adapter_source: str,
    adapter_path: Optional[str],
    weight_names: Tuple[str],
    api_token: str,
    trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
    revision = "main"

    adapter_config = LoraConfig.load(adapter_id, api_token)
    if adapter_config.base_model_name_or_path != model_id:
    adapter_config = LoraConfig.load(adapter_path or adapter_id, None)

    if not adapter_path and adapter_config.base_model_name_or_path != model_id:
        check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)

    adapter_filenames = hub._cached_adapter_weight_files(
        adapter_id, revision=revision, extension=".safetensors"
    adapter_filenames = (
        hub._adapter_weight_files_from_dir(adapter_path, extension=".safetensors")
        if adapter_path
        else hub._cached_adapter_weight_files(
            adapter_id, revision=revision, extension=".safetensors"
        )
    )

    try:
        adapter_tokenizer = AutoTokenizer.from_pretrained(
            adapter_config.config_path,
            token=api_token,
            trust_remote_code=trust_remote_code,
        )
    except Exception:
@@ -194,3 +210,87 @@ def load_module_map(
        adapter_weights, weight_names
    )
    return module_map, adapter_config, adapter_weight_names, adapter_tokenizer


def get_attn_weights(i, layer):
    qkv = layer.self_attn.query_key_value
    weights = {}

    for k in ["q", "k", "v"]:
        key = (i, f"{k}_proj")
        value = (f"model.layers.{i}.self_attn.{k}_proj", qkv)
        weights[key] = value

    weights[(i, "o_proj")] = (
        f"model.layers.{i}.self_attn.o_proj",
        layer.self_attn.o_proj,
    )

    return weights


def get_mlp_weights(i, layer):
    weights = {}
    if hasattr(layer, "mlp"):
        mlp = layer.mlp
        if hasattr(mlp, "gate_up_proj"):
            # handle combined gate_up_proj (e.g., for some LLaMA variants)
            weights.update(
                {
                    (i, "gate_proj"): (
                        f"model.layers.{i}.mlp.gate_proj",
                        mlp.gate_up_proj,
                    ),
                    (i, "up_proj"): (f"model.layers.{i}.mlp.up_proj", mlp.gate_up_proj),
                }
            )
        else:
            # handle separate gate_proj, up_proj, and down_proj (e.g., for Gemma)
            if hasattr(mlp, "gate_proj"):
                weights[(i, "gate_proj")] = (
                    f"model.layers.{i}.mlp.gate_proj",
                    mlp.gate_proj,
                )
            if hasattr(mlp, "up_proj"):
                weights[(i, "up_proj")] = (f"model.layers.{i}.mlp.up_proj", mlp.up_proj)

        if hasattr(mlp, "down_proj"):
            weights[(i, "down_proj")] = (
                f"model.layers.{i}.mlp.down_proj",
                mlp.down_proj,
            )

    return weights


# build_layer_weight_lookup creates a mapping of model layers to their corresponding
# weight tensors and paths. It builds a dictionary that maps layer identifiers to tuples
# containing the weight tensor path and the actual layer object. This mapping is needed
# for the lora adapter to know which weights to update when applying the adapter.
def build_layer_weight_lookup(model):
    if hasattr(model, "language_model"):
        m = model.language_model.model
    elif hasattr(model, "text_model"):
        m = model.text_model.model
    else:
        m = model.model

    layer_weights = {}

    for i, layer in enumerate(m.layers):
        attn_weights = get_attn_weights(i, layer)
        mlp_weights = get_mlp_weights(i, layer)

        layer_weights.update(attn_weights)
        layer_weights.update(mlp_weights)

    lm_head = None
    if hasattr(m, "lm_head"):
        lm_head = m.lm_head
    elif hasattr(model, "lm_head"):
        lm_head = model.lm_head

    if lm_head:
        layer_weights[(0, "lm_head")] = ("lm_head", lm_head)

    return layer_weights
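To make the shape of this lookup concrete, a hypothetical excerpt of what build_layer_weight_lookup returns for layer 0 of a LLaMA-style model (module objects elided), together with how get_model_with_lora_adapters derives the adapter weight names from it:

# target_to_layer maps (layer_index, proj_name) -> (weight path in the checkpoint, layer module)
# e.g. {
#     (0, "q_proj"): ("model.layers.0.self_attn.q_proj", <fused qkv module>),
#     (0, "o_proj"): ("model.layers.0.self_attn.o_proj", <o_proj module>),
#     ...
#     (0, "lm_head"): ("lm_head", <lm_head module>),
# }
target_to_layer = build_layer_weight_lookup(model.model)
weight_names = tuple(v[0] for v in target_to_layer.values())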