fix: refactor adapter weight loading and mapping (#2193)

* fix: refactor adapter weight loading and mapping

* feat: enable lora load from directory

* fix: adjust launcher for local lora adapters

* feat: improve weight loading and add tests

* fix: improve logging and rebase syntax issue

* fix: improve adapter merge comments and remove unused conditional

* fix: improve get_model_with_lora_adapters naming

* fix: comment typo
drbh 2024-07-24 15:32:14 -04:00 committed by GitHub
parent 93d2b9fe9c
commit 5d85a958c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 498 additions and 449 deletions

View File

@@ -1607,6 +1607,10 @@ fn main() -> Result<(), LauncherError> {
    // Download and convert lora adapters if any
    if let Some(lora_adapters) = &args.lora_adapters {
        for adapter in lora_adapters.split(',') {
+            // skip download if a path is provided
+            if adapter.contains('=') {
+                continue;
+            }
            download_convert_model(
                adapter,
                None,

View File

@@ -0,0 +1,187 @@
+import pytest
+from unittest.mock import Mock
+from text_generation_server.utils.adapter import get_attn_weights, get_mlp_weights
+
+
+def test_get_attn_weights():
+    # create a mock layer
+    mock_layer = Mock()
+    mock_layer.self_attn.query_key_value = Mock()
+    mock_layer.self_attn.o_proj = Mock()
+
+    # call the function
+    result = get_attn_weights(2, mock_layer)
+
+    # assert the result
+    expected = {
+        (2, "q_proj"): (
+            "model.layers.2.self_attn.q_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
+        (2, "k_proj"): (
+            "model.layers.2.self_attn.k_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
+        (2, "v_proj"): (
+            "model.layers.2.self_attn.v_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
+        (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
+    }
+    assert result == expected
+
+
+def test_get_mlp_weights_with_gate_up_proj():
+    # create a mock layer with gate_up_proj
+    mock_layer = Mock()
+    mock_layer.mlp.gate_up_proj = Mock()
+    mock_layer.mlp.down_proj = Mock()
+
+    # call the function
+    result = get_mlp_weights(3, mock_layer)
+
+    # assert the result
+    expected = {
+        (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
+        (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
+        (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
+    }
+    assert result == expected
+
+
+def test_get_mlp_weights_without_gate_up_proj():
+    # create a mock layer without gate_up_proj
+    mock_layer = Mock()
+    mock_layer.mlp = Mock(spec=[])
+
+    # call the function
+    result = get_mlp_weights(1, mock_layer)
+
+    # assert the result
+    assert result == {}
+
+
+@pytest.mark.parametrize("layer_index", [0, 1, 5])
+def test_get_attn_weights_different_layers(layer_index):
+    mock_layer = Mock()
+    mock_layer.self_attn.query_key_value = Mock()
+    mock_layer.self_attn.o_proj = Mock()
+
+    result = get_attn_weights(layer_index, mock_layer)
+
+    for k in ["q", "k", "v"]:
+        assert (layer_index, f"{k}_proj") in result
+        assert (
+            result[(layer_index, f"{k}_proj")][0]
+            == f"model.layers.{layer_index}.self_attn.{k}_proj"
+        )
+
+    assert (layer_index, "o_proj") in result
+    assert (
+        result[(layer_index, "o_proj")][0]
+        == f"model.layers.{layer_index}.self_attn.o_proj"
+    )
+
+
+@pytest.mark.parametrize("layer_index", [0, 1, 5])
+def test_get_mlp_weights_different_layers(layer_index):
+    mock_layer = Mock()
+    mock_layer.mlp.gate_up_proj = Mock()
+    mock_layer.mlp.down_proj = Mock()
+
+    result = get_mlp_weights(layer_index, mock_layer)
+
+    for k in ["gate", "up", "down"]:
+        assert (layer_index, f"{k}_proj") in result
+        assert (
+            result[(layer_index, f"{k}_proj")][0]
+            == f"model.layers.{layer_index}.mlp.{k}_proj"
+        )
+
+
+def test_get_attn_weights_llama_compatibility():
+    mock_layer = Mock()
+    mock_layer.self_attn.query_key_value = Mock()
+    mock_layer.self_attn.o_proj = Mock()
+
+    result = get_attn_weights(2, mock_layer)
+
+    expected = {
+        (2, "q_proj"): (
+            "model.layers.2.self_attn.q_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
+        (2, "k_proj"): (
+            "model.layers.2.self_attn.k_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
+        (2, "v_proj"): (
+            "model.layers.2.self_attn.v_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
+        (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
+    }
+    assert result == expected
+
+
+def test_get_mlp_weights_llama_compatibility():
+    mock_layer = Mock()
+    mock_layer.mlp.gate_up_proj = Mock()
+    mock_layer.mlp.down_proj = Mock()
+
+    result = get_mlp_weights(3, mock_layer)
+
+    expected = {
+        (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
+        (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
+        (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
+    }
+    assert result == expected
+
+
+def test_get_attn_weights_gemma_compatibility():
+    mock_layer = Mock()
+    mock_layer.self_attn.query_key_value = Mock()
+    mock_layer.self_attn.o_proj = Mock()
+
+    result = get_attn_weights(2, mock_layer)
+
+    expected = {
+        (2, "q_proj"): (
+            "model.layers.2.self_attn.q_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
+        (2, "k_proj"): (
+            "model.layers.2.self_attn.k_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
+        (2, "v_proj"): (
+            "model.layers.2.self_attn.v_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
+        (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
+    }
+    assert result == expected
+
+
+def test_get_mlp_weights_gemma_compatibility():
+    mock_layer = Mock()
+    mock_layer.mlp.gate_proj = Mock()
+    mock_layer.mlp.up_proj = Mock()
+    mock_layer.mlp.down_proj = Mock()
+
+    # ensure that the mock_layer.mlp.gate_up_proj attribute does not exist.
+    # This is necessary because the use of `Mock` automatically creates any
+    # attributes that are accessed, even if they don't exist in the actual
+    # implementation. If `gate_up_proj` were created, `get_mlp_weights` might
+    # follow the wrong execution path and return an incorrect result.
+    del mock_layer.mlp.gate_up_proj
+
+    result = get_mlp_weights(3, mock_layer)
+
+    expected = {
+        (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj),
+        (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj),
+        (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
+    }
+    assert result == expected
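
A note on the mocking pattern above, as a small standalone sketch (not part of the diff): a bare Mock fabricates any attribute the moment it is accessed, so the hasattr checks in get_mlp_weights would always pass unless the attribute is removed with del or the mock is restricted with spec. This is plain unittest.mock behavior:

from unittest.mock import Mock

loose = Mock()
print(hasattr(loose.mlp, "gate_up_proj"))  # True: the attribute is created on first access

restricted = Mock(spec=[])
print(hasattr(restricted, "gate_up_proj"))  # False: spec=[] permits no attributes

pruned = Mock()
pruned.mlp.gate_up_proj  # auto-created here
del pruned.mlp.gate_up_proj
print(hasattr(pruned.mlp, "gate_up_proj"))  # False: later access now raises AttributeError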

View File

@@ -4,7 +4,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Set, Tuple

import torch

@@ -31,14 +31,3 @@ class AdapterConfig(ABC):
        weight_names: Tuple[str],
    ) -> Tuple[ModuleMap, Set[str]]:
        pass
-
-    @abstractmethod
-    def load_batched_adapter_weights(
-        self,
-        model: "Model",
-        module_map: ModuleMap,
-        layer_type: str,
-        unused_weight_names: Set[str],
-        dynamic: bool,
-    ) -> Optional[AdapterWeights]:
-        pass

View File

@@ -102,22 +102,6 @@ class LoraConfig(AdapterConfig):
            adapter_weight_names.add(lora_b_name)
        return module_map, adapter_weight_names

-    def load_batched_adapter_weights(
-        self,
-        model: "Model",
-        module_map: Dict[str, Dict],
-        layer_type: str,
-        unused_weight_names: Set[str],
-        dynamic: bool,
-    ) -> Optional[AdapterWeights]:
-        return LoraWeights.load(
-            self,
-            model,
-            module_map,
-            layer_type,
-            unused_weight_names,
-        )
-
    @classmethod
    def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
        hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)

@@ -192,22 +176,38 @@ class LoraWeights(AdapterWeights):
    def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
        return [BatchLoraWeights]

+    # prepare pre-loaded lora weights for use in the model.
+    #
+    # this method processes and organizes lora weights for a specific layer type across all layers:
+    # - uses `config` (LoraConfig) to apply lora-specific settings like scaling factor.
+    # - retrieves weights from `module_map` based on the `layer_type`.
+    # - processes `nlayers` number of layers.
+    # - converts weights to the specified `dtype`.
+    # - shards weights across `world_size` number of processes using the `process_group`.
+    # - maps weights to specific layers using `target_to_layer`.
+    # - tracks `unused_weight_names` to identify any unused weights.
+    #
+    # the method handles weight transposition, scaling, and padding to ensure compatibility
+    # with SGMV or BGMV operations.
    @classmethod
-    def load(
+    def prepare_weights(
        cls,
        config: LoraConfig,
-        model: "Model",
        module_map: Dict[str, Dict],
        layer_type: str,
        unused_weight_names: Set[str],
+        nlayers: int,
+        dtype: torch.dtype,
+        world_size: int,
+        process_group: ProcessGroup,
+        target_to_layer: Dict[str, Tuple[str, torch.Tensor]],
    ) -> Optional[AdapterWeights]:
-        nlayers = model.get_num_layers_for_type(layer_type)
        lora_a_list = [None] * nlayers
        lora_b_list = [None] * nlayers

        for layer_id in range(nlayers):
            key = (layer_id, layer_type)
-            weight_name, layer = model.target_to_layer[key]
+            weight_name, layer = target_to_layer[key]

            base_weight = layer.base_layer.linear.weight
            base_device = base_weight.device

@@ -216,10 +216,10 @@ class LoraWeights(AdapterWeights):
                return None

            lora_a, lora_a_name = module_map[weight_name]["lora_A"]
-            lora_a = lora_a.to(base_device, model.dtype)
+            lora_a = lora_a.to(base_device, dtype)

            lora_b, lora_b_name = module_map[weight_name]["lora_B"]
-            lora_b = lora_b.to(base_device, model.dtype)
+            lora_b = lora_b.to(base_device, dtype)

            scale = get_scaling_factor(
                config.lora_alpha,

@@ -236,12 +236,8 @@ class LoraWeights(AdapterWeights):
            lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale

        # pad lora ranks to be compatible with sgmv
-        lora_a_list = [
-            pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list
-        ]
-        lora_b_list = [
-            pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list
-        ]
+        lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]
+        lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]

        if lora_a_list:
            # update rank if it was padded

@@ -252,8 +248,8 @@ class LoraWeights(AdapterWeights):
            *shard_lora_weights(
                weights_a=lora_a_list,
                weights_b=lora_b_list,
-                split_dim=0 if model.is_row_parallel(layer_type) else 1,
-                process_group=model.process_group,
+                split_dim=0 if layer_type in {"o_proj", "down_proj", "lm_head"} else 1,
+                process_group=process_group,
            ),
            config,
        )

@@ -293,10 +289,6 @@ class BatchLoraWeights(BatchAdapterWeights):
            for rank_data in self.rank_data.values()
        )

-    @classmethod
-    def key(cls) -> str:
-        return "lora"
-
    @classmethod
    def load(
        self,

View File

@@ -42,10 +42,6 @@ class BatchAdapterWeights(ABC):
    def has_adapter(self, adapter_index: int) -> bool:
        pass

-    @abstractclassmethod
-    def key(cls) -> str:
-        pass
-
    @abstractclassmethod
    def load(
        cls,

@@ -71,13 +67,6 @@ class LayerAdapterWeights:
            return
        del self.adapter_weights[adapter_idx]

-    @property
-    def max_speculative_tokens(self) -> int:
-        return max(
-            adapter_weights.speculative_tokens
-            for adapter_weights in self.adapter_weights.values()
-        )
-
    def is_empty(self) -> bool:
        return len(self.adapter_weights) == 0

@@ -101,7 +90,7 @@ class LayerAdapterWeights:
                adapter_weights, meta, prefill, prefill_head_indices
            )
            if batched_weights is not None:
-                batch_data[batch_type.key()] = batched_weights
+                batch_data = batched_weights
        return batch_data

@@ -133,8 +122,7 @@ class AdapterBatchData:
    def ranks(self) -> Set[int]:
        # TODO(travis): refactor to be less coupled to lora implementation
        ranks = set()
-        for layer_data in self.data.values():
-            lora_data = layer_data.get("lora")
+        for lora_data in self.data.values():
            if lora_data is None:
                continue
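
The weights.py change above (together with the layers/lora.py change below) is easier to see as a data-shape comparison: per-layer batch data used to be a nested dict keyed by the adapter kind ("lora") and now stores the BatchLoraWeights object directly. A minimal sketch, where batch_lora_weights is only a placeholder for a real BatchLoraWeights instance:

batch_lora_weights = object()  # placeholder for a BatchLoraWeights instance

# before: AdapterBatchData.data[layer_name][adapter_kind] -> weights
data_before = {"q_proj": {"lora": batch_lora_weights}}
data = data_before.get("q_proj")
lora_data = data.get("lora") if data is not None else None

# after: AdapterBatchData.data[layer_name] -> weights
data_after = {"q_proj": batch_lora_weights}
lora_data = data_after.get("q_proj")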

View File

@@ -4,9 +4,10 @@ import typer
from pathlib import Path
from loguru import logger
-from typing import Optional
+from typing import Optional, List, Dict
from enum import Enum
from huggingface_hub import hf_hub_download
+from text_generation_server.utils.adapter import parse_lora_adapters
from text_generation_server.utils.log import log_master

@@ -80,28 +81,19 @@ def serve(
    if otlp_endpoint is not None:
        setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)

-    lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
-
-    # split on comma and strip whitespace
-    lora_adapter_ids = (
-        [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
-    )
-
-    if len(lora_adapter_ids) > 0:
-        log_master(
-            logger.warning,
-            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected.",
-        )
+    lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS"))

    # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled
    # and warn the user
-    if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None:
-        log_master(
-            logger.warning,
-            f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs.",
-        )
-        global CUDA_GRAPHS
-        CUDA_GRAPHS = None
+    if lora_adapters:
+        logger.warning("LoRA adapters enabled (experimental feature).")
+
+        if "CUDA_GRAPHS" in os.environ:
+            logger.warning(
+                "LoRA adapters incompatible with CUDA Graphs. Disabling CUDA Graphs."
+            )
+            global CUDA_GRAPHS
+            CUDA_GRAPHS = None

    # Downgrade enum into str for easier management later on
    quantize = None if quantize is None else quantize.value

@@ -117,7 +109,7 @@ def serve(
    )
    server.serve(
        model_id,
-        lora_adapter_ids,
+        lora_adapters,
        revision,
        sharded,
        quantize,

View File

@@ -43,10 +43,7 @@ class LoraLinear(nn.Module):
    ) -> torch.Tensor:
        if adapter_data is None:
            return result
-        data = adapter_data.data.get(layer_type)
-        data: Optional["BatchLoraWeights"] = (
-            data.get("lora") if data is not None else None
-        )
+        data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)

        if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
            # In tensor-parallel configurations, each GPU processes a specific segment of the output.

View File

@@ -6,7 +6,7 @@ from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
-from typing import Optional, List
+from typing import Optional, List, Dict
from pathlib import Path

from text_generation_server.utils.speculate import get_speculate, set_speculate

@@ -33,6 +33,16 @@ from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)

+from text_generation_server.utils.adapter import (
+    AdapterParameters,
+    build_layer_weight_lookup,
+    load_and_merge_adapters,
+    AdapterInfo,
+)
+from text_generation_server.adapters.lora import LoraWeights
+
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_master

@@ -50,7 +60,7 @@ __all__ = [
    "Model",
    "CausalLM",
    "Seq2SeqLM",
-    "get_model",
+    "get_model_with_lora_adapters",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

@@ -1115,3 +1125,116 @@ def get_model(
        )

    raise ValueError(f"Unsupported model type {model_type}")
+
+
+# get_model_with_lora_adapters wraps the internal get_model function and adds support for loading adapters
+# this provides a post model loading hook to load adapters into the model after the model has been loaded
+def get_model_with_lora_adapters(
+    model_id: str,
+    lora_adapters: Optional[List[AdapterInfo]],
+    revision: Optional[str],
+    sharded: bool,
+    quantize: Optional[str],
+    speculate: Optional[int],
+    dtype: Optional[str],
+    trust_remote_code: bool,
+    max_input_tokens: int,
+    adapter_to_index: Dict[str, int],
+):
+    lora_adapter_ids = [adapter.id for adapter in lora_adapters]
+    model = get_model(
+        model_id,
+        lora_adapter_ids,
+        revision,
+        sharded,
+        quantize,
+        speculate,
+        dtype,
+        trust_remote_code,
+        max_input_tokens,
+    )
+
+    if len(lora_adapters) > 0:
+        target_to_layer = build_layer_weight_lookup(model.model)
+
+        for index, adapter in enumerate(lora_adapters):
+            # The AdapterParameters object allows for merging multiple adapters into a single adapter.
+            # At the moment, we only support loading a single adapter into the model, but we keep the
+            # AdapterParameters object for easier extension in the future.
+            adapter_parameters = AdapterParameters(
+                adapter_info=[adapter],
+                # when merging multiple adapters we can weight them differently
+                # if this is not set, all adapters will be weighted equally
+                # see: text_generation_server.utils.merges.strategies for impl
+                weights=None,
+                merge_strategy=0,
+                density=1.0,
+                majority_sign_method=0,
+            )
+
+            adapter_index = index + 1
+            adapter_to_index[adapter.id] = adapter_index
+
+            logger.info(
+                f"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}"
+            )
+            weight_names = tuple([v[0] for v in target_to_layer.values()])
+            (
+                module_map,
+                adapter_config,
+                adapter_weight_names,
+                adapter_tokenizer,
+            ) = load_and_merge_adapters(
+                model.model_id,
+                adapter_parameters,
+                adapter_index,
+                weight_names,
+                False,
+            )
+
+            unused_weight_names = adapter_weight_names.copy()
+
+            adapter_layers = [
+                "q_proj",
+                "k_proj",
+                "v_proj",
+                "o_proj",
+                "gate_proj",
+                "up_proj",
+                "down_proj",
+            ]
+
+            for layer_name in adapter_layers:
+                nlayers = (
+                    1 if layer_name == "lm_head" else len(model.model.model.layers)
+                )
+                adapter_weights = LoraWeights.prepare_weights(
+                    config=adapter_config,
+                    module_map=module_map,
+                    layer_type=layer_name,
+                    unused_weight_names=unused_weight_names,
+                    nlayers=nlayers,
+                    dtype=model.dtype,
+                    world_size=model.world_size,
+                    process_group=model.process_group,
+                    target_to_layer=target_to_layer,
+                )
+
+                if adapter_weights is None:
+                    continue
+
+                model.layer_to_adapter_weights[layer_name].add_adapter(
+                    adapter_index, adapter_weights
+                )
+
+            if len(unused_weight_names) > 0:
+                logger.warning(
+                    f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
+                )
+
+            if adapter_tokenizer is not None:
+                model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)
+
+            model.loaded_adapters.add(adapter_index)
+
+    return model

View File

@@ -1,7 +1,6 @@
import math
import os
import time
-import itertools

import torch
import torch.distributed

@@ -1700,72 +1699,3 @@ class FlashCausalLM(Model):
        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)
-
-    @property
-    def supports_adapter_loading(self) -> bool:
-        return True
-
-    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
-        layer_weights = {}
-
-        prefix = "model.layers"
-
-        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
-        # that have a language_model inside of the larger model.
-        if hasattr(self.model, "language_model"):
-            _model = self.model.language_model
-        elif hasattr(self.model, "text_model"):
-            _model = self.model.text_model
-        else:
-            _model = self.model
-
-        for i, layer in enumerate(_model.model.layers):
-            layer_weights[(i, "q_proj")] = (
-                f"{prefix}.{i}.self_attn.q_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "k_proj")] = (
-                f"{prefix}.{i}.self_attn.k_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "v_proj")] = (
-                f"{prefix}.{i}.self_attn.v_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "o_proj")] = (
-                f"{prefix}.{i}.self_attn.o_proj",
-                layer.self_attn.o_proj,
-            )
-
-            # TODO: this is a hack to avoid the gate_proj for
-            # FlashStarcoder2 that doesnt have these layers
-            if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"):
-                layer_weights[(i, "gate_proj")] = (
-                    f"{prefix}.{i}.mlp.gate_proj",
-                    layer.mlp.gate_up_proj,
-                )
-                layer_weights[(i, "up_proj")] = (
-                    f"{prefix}.{i}.mlp.up_proj",
-                    layer.mlp.gate_up_proj,
-                )
-                layer_weights[(i, "down_proj")] = (
-                    f"{prefix}.{i}.mlp.down_proj",
-                    layer.mlp.down_proj,
-                )
-
-        layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
-        return layer_weights
-
-    @property
-    def adapter_layers(self) -> List[str]:
-        return ADAPTER_LAYERS
-
-    @property
-    def default_traced_adapter_layers(self) -> List[str]:
-        return ["q_proj", "v_proj"]
-
-    def get_num_layers_for_type(self, layer_type: str) -> int:
-        return 1 if layer_type == "lm_head" else len(self.model.model.layers)
-
-    def is_row_parallel(self, layer_type: str) -> bool:
-        return layer_type in ROW_PARALLEL

View File

@ -1,85 +0,0 @@
import torch
from typing import Optional, Tuple, Dict, List
from text_generation_server.models import FlashCausalLM
ADAPTER_LAYERS = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
class FlashMistral(FlashCausalLM):
@property
def supports_adapter_loading(self) -> bool:
return True
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
layer_weights = {}
prefix = "model.layers"
# This accounts for VLMs (e.g. LlavaNext, Idefics2)
# that have a language_model inside of the larger model.
if hasattr(self.model, "text_model"):
_model = self.model.text_model
else:
_model = self.model
for i, layer in enumerate(_model.model.layers):
layer_weights[(i, "q_proj")] = (
f"{prefix}.{i}.self_attn.q_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "k_proj")] = (
f"{prefix}.{i}.self_attn.k_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "v_proj")] = (
f"{prefix}.{i}.self_attn.v_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "o_proj")] = (
f"{prefix}.{i}.self_attn.o_proj",
layer.self_attn.o_proj,
)
# TODO: this is a hack to avoid the gate_proj for
# FlashStarcoder2 that doesnt have these layers
if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"):
layer_weights[(i, "gate_proj")] = (
f"{prefix}.{i}.mlp.gate_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "up_proj")] = (
f"{prefix}.{i}.mlp.up_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "down_proj")] = (
f"{prefix}.{i}.mlp.down_proj",
layer.mlp.down_proj,
)
layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
return layer_weights
@property
def adapter_layers(self) -> List[str]:
return ADAPTER_LAYERS
@property
def default_traced_adapter_layers(self) -> List[str]:
return ["q_proj", "v_proj"]
def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == "lm_head" else len(self.model.model.layers)
def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL

View File

@@ -4,20 +4,12 @@ import torch
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict
from collections import defaultdict
-from transformers import PreTrainedTokenizerBase, PretrainedConfig
+from transformers import PreTrainedTokenizerBase

from text_generation_server.models.types import Batch, Generation
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.pb.generate_pb2 import InfoResponse
from text_generation_server.adapters.weights import LayerAdapterWeights
-from text_generation_server.utils.adapter import (
-    load_and_merge_adapters,
-    AdapterParameters,
-    AdapterSource,
-)
-from text_generation_server.utils.log import log_master
-from loguru import logger

BASE_MODEL_ADAPTER_ID = "__base_model__"

@@ -61,7 +53,6 @@ class Model(ABC):
        self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
            LayerAdapterWeights
        )
-        self.target_to_layer = None
        self.loaded_adapters = set()
        self.static_adapter_id = adapter_id

@@ -142,140 +133,3 @@ class Model(ABC):
            raise RuntimeError(
                f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
            )
-
-    @property
-    def supports_adapter_loading(self) -> bool:
-        return False
-
-    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
-        return {}
-
-    @property
-    def adapter_layers(self) -> List[str]:
-        return []
-
-    @property
-    def default_traced_adapter_layers(self) -> List[str]:
-        return []
-
-    def get_num_layers_for_type(self, layer_type: str) -> int:
-        return 0
-
-    def is_row_parallel(self, layer_type: str) -> bool:
-        return False
-
-    @property
-    def max_speculative_tokens(self) -> int:
-        return max(
-            [
-                weights.max_speculative_tokens
-                for weights in self.layer_to_adapter_weights.values()
-            ],
-            default=0,
-        )
-
-    def load_adapter(
-        self,
-        adapter_parameters: AdapterParameters,
-        adapter_source: AdapterSource,
-        adapter_index: int,
-        api_token: str,
-        dynamic: bool = True,
-    ):
-        """Loads adapter weights from disk / host memory on the GPU.
-
-        adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded
-        into model. Otherwise, the adapter weights are applied during the forward
-        pass and stored separately from the base model parameters.
-        """
-        if self.target_to_layer is None:
-            self.target_to_layer = self.adapter_target_to_layer()
-        if adapter_index in self.loaded_adapters:
-            # Adapter already loaded
-            return
-
-        if not self.supports_adapter_loading:
-            raise ValueError("This model does not support adapter loading.")
-
-        if dynamic and not self.dynamic_adapter_loading_enabled:
-            raise ValueError(
-                f"This model was initialized with the adapter {self.static_adapter_id} "
-                f"and therefore does not support dynamic adapter loading. "
-                f"Please initialize a new model instance from the base model in "
-                f"order to use the dynamic adapter loading feature."
-            )
-
-        log_master(
-            logger.info,
-            f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}",
-        )
-        weight_names = tuple([v[0] for v in self.target_to_layer.values()])
-        (
-            module_map,
-            adapter_config,
-            adapter_weight_names,
-            adapter_tokenizer,
-        ) = load_and_merge_adapters(
-            self.model_id,
-            adapter_parameters,
-            adapter_source,
-            adapter_index,
-            weight_names,
-            api_token,
-            False,
-        )
-
-        unused_weight_names = adapter_weight_names.copy()
-
-        for layer_name in self.adapter_layers:
-            adapter_weights = adapter_config.load_batched_adapter_weights(
-                self,
-                module_map,
-                layer_name,
-                unused_weight_names,
-                dynamic,
-            )
-
-            if adapter_weights is None:
-                continue
-
-            layer_weights = self.layer_to_adapter_weights[layer_name]
-            layer_weights.add_adapter(adapter_index, adapter_weights)
-
-        if len(unused_weight_names) > 0:
-            log_master(
-                logger.warning,
-                f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}",
-            )
-
-        if adapter_tokenizer is not None:
-            self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)
-
-        self.loaded_adapters.add(adapter_index)
-
-    def offload_adapter(
-        self,
-        adapter_parameters: AdapterParameters,
-        adapter_source: AdapterSource,
-        adapter_index: int,
-    ):
-        """Offloads the adapter weights from GPU to CPU or disk."""
-        if adapter_index not in self.loaded_adapters:
-            # Adapter already offloaded
-            return
-
-        if not self.supports_adapter_loading:
-            raise ValueError("This model does not support adapter loading.")
-
-        if not self.dynamic_adapter_loading_enabled:
-            raise ValueError(
-                f"This model was initialized with the adapter {self.static_adapter_id} "
-                f"and therefore does not support dynamic adapter loading. "
-                f"Please initialize a new model instance from the base model in "
-                f"order to use the dynamic adapter loading feature."
-            )
-
-        for layer_name in self.adapter_layers:
-            if layer_name in self.layer_to_adapter_weights:
-                self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index)
-
-        self.loaded_adapters.remove(adapter_index)

View File

@@ -9,11 +9,12 @@ from loguru import logger
from grpc_reflection.v1alpha import reflection
from pathlib import Path
-from typing import List, Optional
+from typing import List, Optional, Dict

from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
-from text_generation_server.models import Model, get_model
+from text_generation_server.models import Model, get_model_with_lora_adapters
+from text_generation_server.utils.adapter import AdapterInfo

try:
    from text_generation_server.models.pali_gemma import PaliGemmaBatch

@@ -30,9 +31,6 @@ except (ImportError, NotImplementedError):
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.globals import set_model_id, set_adapter_to_index
-from text_generation_server.utils.adapter import (
-    AdapterParameters,
-)


class SignalHandler:

@@ -195,7 +193,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

def serve(
    model_id: str,
-    lora_adapter_ids: Optional[List[str]],
+    lora_adapters: Optional[List[AdapterInfo]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],

@@ -207,7 +205,7 @@ def serve(
):
    async def serve_inner(
        model_id: str,
-        lora_adapter_ids: Optional[List[str]],
+        lora_adapters: Optional[List[AdapterInfo]],
        revision: Optional[str],
        sharded: bool = False,
        quantize: Optional[str] = None,

@@ -228,9 +226,9 @@ def serve(
        server_urls = [local_url]

        try:
-            model = get_model(
+            model = get_model_with_lora_adapters(
                model_id,
-                lora_adapter_ids,
+                lora_adapters,
                revision,
                sharded,
                quantize,

@@ -238,29 +236,9 @@ def serve(
                dtype,
                trust_remote_code,
                max_input_tokens,
+                adapter_to_index,
            )

-            if len(lora_adapter_ids) > 0:
-                for index, adapter_id in enumerate(lora_adapter_ids):
-                    # TODO: improve non merged adapter loading and long term
-                    # improve adapter loading as a whole
-                    adapter_parameters = AdapterParameters(
-                        adapter_ids=[adapter_id],
-                        weights=None,  # will be set to 1
-                        merge_strategy=0,
-                        density=1.0,
-                        majority_sign_method=0,
-                    )
-                    adapter_index = index + 1
-                    adapter_to_index[adapter_id] = adapter_index
-                    model.load_adapter(
-                        adapter_parameters,
-                        None,  # adapter_source
-                        adapter_index,
-                        None,  # api_token
-                        False,  # dynamic
-                    )
-
        except Exception:
            logger.exception("Error when initializing model")
            raise

@@ -297,7 +275,7 @@ def serve(
    asyncio.run(
        serve_inner(
            model_id,
-            lora_adapter_ids,
+            lora_adapters,
            revision,
            sharded,
            quantize,

View File

@@ -5,12 +5,11 @@
import warnings
from dataclasses import dataclass
from functools import lru_cache
-from typing import TYPE_CHECKING, Set, Tuple
+from typing import TYPE_CHECKING, Set, Tuple, Optional, List

from safetensors.torch import load_file
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer

-from text_generation_server.pb import generate_pb2
from text_generation_server.utils.merges.strategies import merge_adapters
from text_generation_server.utils import hub

@@ -24,9 +23,15 @@ if TYPE_CHECKING:
BASE_MODEL_ADAPTER_ID = "__base_model__"


+@dataclass
+class AdapterInfo:
+    id: str
+    path: Optional[str]
+
+
@dataclass
class AdapterParameters:
-    adapter_ids: Tuple[str]
+    adapter_info: Tuple[AdapterInfo]
    weights: Tuple[float]
    merge_strategy: NotImplemented
    density: float

@@ -40,37 +45,47 @@ class AdapterSource:
    revision: str


+def parse_lora_adapters(lora_adapters: Optional[str]) -> List[AdapterInfo]:
+    if not lora_adapters:
+        return []
+
+    adapter_list = []
+    for adapter in lora_adapters.split(","):
+        parts = adapter.strip().split("=")
+        if len(parts) == 1:
+            adapter_list.append(AdapterInfo(id=parts[0], path=None))
+        elif len(parts) == 2:
+            adapter_list.append(AdapterInfo(id=parts[0], path=parts[1]))
+        else:
+            raise ValueError(f"Invalid LoRA adapter format: {adapter}")
+    return adapter_list
+
+
def load_and_merge_adapters(
    model_id: str,
    adapter_parameters: AdapterParameters,
-    adapter_source: str,
    adapter_index: int,
    weight_names: Tuple[str],
-    api_token: str,
    trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
-    if len(adapter_parameters.adapter_ids) == 1:
+    if len(adapter_parameters.adapter_info) == 1:
+        adapter_info = next(iter(adapter_parameters.adapter_info))
        return load_module_map(
            model_id,
-            adapter_parameters.adapter_ids[0],
-            adapter_source,
+            adapter_info.id,
+            adapter_info.path,
            weight_names,
-            api_token,
            trust_remote_code,
        )

-    adapter_params = AdapterParametersContainer(
-        adapter_parameters, adapter_source, adapter_index
-    )
-    return _load_and_merge(
-        model_id, adapter_params, weight_names, api_token, trust_remote_code
-    )
+    adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index)
+    return _load_and_merge(model_id, adapter_params, weight_names, trust_remote_code)


@dataclass
class AdapterParametersContainer:
    adapter_parameters: AdapterParameters
-    adapter_source: str
    adapter_index: int

    def __hash__(self) -> int:

@@ -82,7 +97,6 @@ def _load_and_merge(
    model_id: str,
    adapter_params: AdapterParametersContainer,
    weight_names: Tuple[str],
-    api_token: str,
    trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
    params = adapter_params.adapter_parameters

@@ -90,17 +104,16 @@ def _load_and_merge(
    adapters_to_merge = []
    merged_weight_names = set()
    tokenizer = None
-    for adapter_id in params.adapter_ids:
-        if adapter_id == BASE_MODEL_ADAPTER_ID:
+    for adapter in params.adapter_info:
+        if adapter.id == BASE_MODEL_ADAPTER_ID:
            raise ValueError("Base model adapter cannot be merged.")

        module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
            load_module_map(
                model_id,
-                adapter_id,
-                adapter_params.adapter_source,
+                adapter.id,
+                adapter.path,
                weight_names,
-                api_token,
                trust_remote_code,
            )
        )

@@ -159,25 +172,28 @@ def check_architectures(
def load_module_map(
    model_id: str,
    adapter_id: str,
-    adapter_source: str,
+    adapter_path: Optional[str],
    weight_names: Tuple[str],
-    api_token: str,
    trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
    revision = "main"

-    adapter_config = LoraConfig.load(adapter_id, api_token)
-    if adapter_config.base_model_name_or_path != model_id:
+    adapter_config = LoraConfig.load(adapter_path or adapter_id, None)
+
+    if not adapter_path and adapter_config.base_model_name_or_path != model_id:
        check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)

-    adapter_filenames = hub._cached_adapter_weight_files(
-        adapter_id, revision=revision, extension=".safetensors"
+    adapter_filenames = (
+        hub._adapter_weight_files_from_dir(adapter_path, extension=".safetensors")
+        if adapter_path
+        else hub._cached_adapter_weight_files(
+            adapter_id, revision=revision, extension=".safetensors"
+        )
    )

    try:
        adapter_tokenizer = AutoTokenizer.from_pretrained(
            adapter_config.config_path,
-            token=api_token,
            trust_remote_code=trust_remote_code,
        )
    except Exception:

@@ -194,3 +210,87 @@ def load_module_map(
        adapter_weights, weight_names
    )
    return module_map, adapter_config, adapter_weight_names, adapter_tokenizer
+
+
+def get_attn_weights(i, layer):
+    qkv = layer.self_attn.query_key_value
+    weights = {}
+
+    for k in ["q", "k", "v"]:
+        key = (i, f"{k}_proj")
+        value = (f"model.layers.{i}.self_attn.{k}_proj", qkv)
+        weights[key] = value
+
+    weights[(i, "o_proj")] = (
+        f"model.layers.{i}.self_attn.o_proj",
+        layer.self_attn.o_proj,
+    )
+
+    return weights
+
+
+def get_mlp_weights(i, layer):
+    weights = {}
+    if hasattr(layer, "mlp"):
+        mlp = layer.mlp
+        if hasattr(mlp, "gate_up_proj"):
+            # handle combined gate_up_proj (e.g., for some LLaMA variants)
+            weights.update(
+                {
+                    (i, "gate_proj"): (
+                        f"model.layers.{i}.mlp.gate_proj",
+                        mlp.gate_up_proj,
+                    ),
+                    (i, "up_proj"): (f"model.layers.{i}.mlp.up_proj", mlp.gate_up_proj),
+                }
+            )
+        else:
+            # handle separate gate_proj, up_proj, and down_proj (e.g., for Gemma)
+            if hasattr(mlp, "gate_proj"):
+                weights[(i, "gate_proj")] = (
+                    f"model.layers.{i}.mlp.gate_proj",
+                    mlp.gate_proj,
+                )
+            if hasattr(mlp, "up_proj"):
+                weights[(i, "up_proj")] = (f"model.layers.{i}.mlp.up_proj", mlp.up_proj)
+
+        if hasattr(mlp, "down_proj"):
+            weights[(i, "down_proj")] = (
+                f"model.layers.{i}.mlp.down_proj",
+                mlp.down_proj,
+            )
+
+    return weights
+
+
+# build_layer_weight_lookup creates a mapping of model layers to their corresponding
+# weight tensors and paths. It builds a dictionary that maps layer identifiers to tuples
+# containing the weight tensor path and the actual layer object. This mapping is needed
+# for the lora adapter to know which weights to update when applying the adapter.
+def build_layer_weight_lookup(model):
+    if hasattr(model, "language_model"):
+        m = model.language_model.model
+    elif hasattr(model, "text_model"):
+        m = model.text_model.model
+    else:
+        m = model.model
+
+    layer_weights = {}
+
+    for i, layer in enumerate(m.layers):
+        attn_weights = get_attn_weights(i, layer)
+        mlp_weights = get_mlp_weights(i, layer)
+
+        layer_weights.update(attn_weights)
+        layer_weights.update(mlp_weights)
+
+    lm_head = None
+    if hasattr(m, "lm_head"):
+        lm_head = m.lm_head
+    elif hasattr(model, "lm_head"):
+        lm_head = model.lm_head
+
+    if lm_head:
+        layer_weights[(0, "lm_head")] = ("lm_head", lm_head)
+
+    return layer_weights
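
For reference, the LORA_ADAPTERS value accepted by parse_lora_adapters above (and by the launcher's `=` check) is a comma-separated list of either adapter ids or id=path pairs. A quick sketch, with illustrative adapter names and a hypothetical local path:

from text_generation_server.utils.adapter import parse_lora_adapters

parse_lora_adapters("org/adapter-a")
# -> [AdapterInfo(id='org/adapter-a', path=None)]

parse_lora_adapters("org/adapter-a, my-adapter=/data/adapters/my-adapter")
# -> [AdapterInfo(id='org/adapter-a', path=None),
#     AdapterInfo(id='my-adapter', path='/data/adapters/my-adapter')]

parse_lora_adapters(None)
# -> []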