fix: refactor adapter weight loading and mapping (#2193)
* fix: refactor adapter weight loading and mapping
* feat: enable lora load from directory
* fix: adjust launcher for local lora adapters
* feat: improve weight loading and add tests
* fix: improve logging and rebase syntax issue
* fix: improve adapter merge comments and remove unused conditional
* fix: improve get_model_with_lora_adapters naming
* fix: comment typo
This commit is contained in:
parent
93d2b9fe9c
commit
5d85a958c9
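
At a glance, this commit lets the launcher and server mix Hub adapter ids with local directories in LORA_ADAPTERS using an `id=path` form. A minimal Python sketch of the accepted values (the adapter names and paths below are placeholders, not taken from this commit):

# LORA_ADAPTERS entries are comma separated; an entry containing "=" is
# treated as "adapter_id=/local/path" and is not downloaded from the Hub.
lora_adapters = "some-org/some-adapter,my-local-adapter=/data/adapters/my-local-adapter"
for entry in lora_adapters.split(","):
    if "=" in entry:
        adapter_id, path = entry.strip().split("=")
        print(f"use local adapter {adapter_id!r} from {path}")
    else:
        print(f"download adapter {entry.strip()!r} from the Hub")
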
@@ -1607,6 +1607,10 @@ fn main() -> Result<(), LauncherError> {
     // Download and convert lora adapters if any
     if let Some(lora_adapters) = &args.lora_adapters {
         for adapter in lora_adapters.split(',') {
+            // skip download if a path is provided
+            if adapter.contains('=') {
+                continue;
+            }
             download_convert_model(
                 adapter,
                 None,

@@ -0,0 +1,187 @@
+import pytest
+from unittest.mock import Mock
+from text_generation_server.utils.adapter import get_attn_weights, get_mlp_weights
+
+
+def test_get_attn_weights():
+    # create a mock layer
+    mock_layer = Mock()
+    mock_layer.self_attn.query_key_value = Mock()
+    mock_layer.self_attn.o_proj = Mock()
+
+    # call the function
+    result = get_attn_weights(2, mock_layer)
+
+    # assert the result
+    expected = {
+        (2, "q_proj"): (
+            "model.layers.2.self_attn.q_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
+        (2, "k_proj"): (
+            "model.layers.2.self_attn.k_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
+        (2, "v_proj"): (
+            "model.layers.2.self_attn.v_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
+        (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
+    }
+    assert result == expected
+
+
+def test_get_mlp_weights_with_gate_up_proj():
+    # create a mock layer with gate_up_proj
+    mock_layer = Mock()
+    mock_layer.mlp.gate_up_proj = Mock()
+    mock_layer.mlp.down_proj = Mock()
+
+    # call the function
+    result = get_mlp_weights(3, mock_layer)
+
+    # assert the result
+    expected = {
+        (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
+        (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
+        (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
+    }
+    assert result == expected
+
+
+def test_get_mlp_weights_without_gate_up_proj():
+    # create a mock layer without gate_up_proj
+    mock_layer = Mock()
+    mock_layer.mlp = Mock(spec=[])
+
+    # call the function
+    result = get_mlp_weights(1, mock_layer)
+
+    # assert the result
+    assert result == {}
+
+
+@pytest.mark.parametrize("layer_index", [0, 1, 5])
+def test_get_attn_weights_different_layers(layer_index):
+    mock_layer = Mock()
+    mock_layer.self_attn.query_key_value = Mock()
+    mock_layer.self_attn.o_proj = Mock()
+
+    result = get_attn_weights(layer_index, mock_layer)
+
+    for k in ["q", "k", "v"]:
+        assert (layer_index, f"{k}_proj") in result
+        assert (
+            result[(layer_index, f"{k}_proj")][0]
+            == f"model.layers.{layer_index}.self_attn.{k}_proj"
+        )
+
+    assert (layer_index, "o_proj") in result
+    assert (
+        result[(layer_index, "o_proj")][0]
+        == f"model.layers.{layer_index}.self_attn.o_proj"
+    )
+
+
+@pytest.mark.parametrize("layer_index", [0, 1, 5])
+def test_get_mlp_weights_different_layers(layer_index):
+    mock_layer = Mock()
+    mock_layer.mlp.gate_up_proj = Mock()
+    mock_layer.mlp.down_proj = Mock()
+
+    result = get_mlp_weights(layer_index, mock_layer)
+
+    for k in ["gate", "up", "down"]:
+        assert (layer_index, f"{k}_proj") in result
+        assert (
+            result[(layer_index, f"{k}_proj")][0]
+            == f"model.layers.{layer_index}.mlp.{k}_proj"
+        )
+
+
+def test_get_attn_weights_llama_compatibility():
+    mock_layer = Mock()
+    mock_layer.self_attn.query_key_value = Mock()
+    mock_layer.self_attn.o_proj = Mock()
+
+    result = get_attn_weights(2, mock_layer)
+
+    expected = {
+        (2, "q_proj"): (
+            "model.layers.2.self_attn.q_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
+        (2, "k_proj"): (
+            "model.layers.2.self_attn.k_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
+        (2, "v_proj"): (
+            "model.layers.2.self_attn.v_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
+        (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
+    }
+    assert result == expected
+
+
+def test_get_mlp_weights_llama_compatibility():
+    mock_layer = Mock()
+    mock_layer.mlp.gate_up_proj = Mock()
+    mock_layer.mlp.down_proj = Mock()
+
+    result = get_mlp_weights(3, mock_layer)
+
+    expected = {
+        (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
+        (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
+        (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
+    }
+    assert result == expected
+
+
+def test_get_attn_weights_gemma_compatibility():
+    mock_layer = Mock()
+    mock_layer.self_attn.query_key_value = Mock()
+    mock_layer.self_attn.o_proj = Mock()
+
+    result = get_attn_weights(2, mock_layer)
+
+    expected = {
+        (2, "q_proj"): (
+            "model.layers.2.self_attn.q_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
+        (2, "k_proj"): (
+            "model.layers.2.self_attn.k_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
+        (2, "v_proj"): (
+            "model.layers.2.self_attn.v_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
+        (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
+    }
+    assert result == expected
+
+
+def test_get_mlp_weights_gemma_compatibility():
+    mock_layer = Mock()
+    mock_layer.mlp.gate_proj = Mock()
+    mock_layer.mlp.up_proj = Mock()
+    mock_layer.mlp.down_proj = Mock()
+
+    # ensure that the mock_layer.mlp.gate_up_proj attribute does not exist.
+    # This is necessary because the use of `Mock` automatically creates any
+    # attributes that are accessed, even if they don't exist in the actual
+    # implementation. If `gate_up_proj` were created, `get_mlp_weights` might
+    # follow the wrong execution path and return an incorrect result.
+    del mock_layer.mlp.gate_up_proj
+
+    result = get_mlp_weights(3, mock_layer)
+
+    expected = {
+        (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj),
+        (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj),
+        (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
+    }
+    assert result == expected
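
A hypothetical way to run the new adapter-mapping tests locally; the file location is an assumption about where this new test module lands in the server test suite:

# run only the new adapter mapping tests (path is an assumption)
import pytest
raise SystemExit(pytest.main(["-q", "server/tests/utils/test_adapter.py"]))
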
@@ -4,7 +4,7 @@

 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Set, Tuple

 import torch

@@ -31,14 +31,3 @@ class AdapterConfig(ABC):
         weight_names: Tuple[str],
     ) -> Tuple[ModuleMap, Set[str]]:
         pass
-
-    @abstractmethod
-    def load_batched_adapter_weights(
-        self,
-        model: "Model",
-        module_map: ModuleMap,
-        layer_type: str,
-        unused_weight_names: Set[str],
-        dynamic: bool,
-    ) -> Optional[AdapterWeights]:
-        pass

@@ -102,22 +102,6 @@ class LoraConfig(AdapterConfig):
             adapter_weight_names.add(lora_b_name)
         return module_map, adapter_weight_names

-    def load_batched_adapter_weights(
-        self,
-        model: "Model",
-        module_map: Dict[str, Dict],
-        layer_type: str,
-        unused_weight_names: Set[str],
-        dynamic: bool,
-    ) -> Optional[AdapterWeights]:
-        return LoraWeights.load(
-            self,
-            model,
-            module_map,
-            layer_type,
-            unused_weight_names,
-        )
-
     @classmethod
     def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
         hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)

@@ -192,22 +176,38 @@ class LoraWeights(AdapterWeights):
     def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
         return [BatchLoraWeights]

+    # prepare pre-loaded lora weights for use in the model.
+    #
+    # this method processes and organizes lora weights for a specific layer type across all layers:
+    # - uses `config` (LoraConfig) to apply lora-specific settings like scaling factor.
+    # - retrieves weights from `module_map` based on the `layer_type`.
+    # - processes `nlayers` number of layers.
+    # - converts weights to the specified `dtype`.
+    # - shards weights across `world_size` number of processes using the `process_group`.
+    # - maps weights to specific layers using `target_to_layer`.
+    # - tracks `unused_weight_names` to identify any unused weights.
+    #
+    # the method handles weight transposition, scaling, and padding to ensure compatibility
+    # with SGMV or BGMV operations.
     @classmethod
-    def load(
+    def prepare_weights(
         cls,
         config: LoraConfig,
-        model: "Model",
         module_map: Dict[str, Dict],
         layer_type: str,
         unused_weight_names: Set[str],
+        nlayers: int,
+        dtype: torch.dtype,
+        world_size: int,
+        process_group: ProcessGroup,
+        target_to_layer: Dict[str, Tuple[str, torch.Tensor]],
    ) -> Optional[AdapterWeights]:
-        nlayers = model.get_num_layers_for_type(layer_type)
         lora_a_list = [None] * nlayers
         lora_b_list = [None] * nlayers

         for layer_id in range(nlayers):
             key = (layer_id, layer_type)
-            weight_name, layer = model.target_to_layer[key]
+            weight_name, layer = target_to_layer[key]
             base_weight = layer.base_layer.linear.weight
             base_device = base_weight.device

@@ -216,10 +216,10 @@ class LoraWeights(AdapterWeights):
                 return None

             lora_a, lora_a_name = module_map[weight_name]["lora_A"]
-            lora_a = lora_a.to(base_device, model.dtype)
+            lora_a = lora_a.to(base_device, dtype)

             lora_b, lora_b_name = module_map[weight_name]["lora_B"]
-            lora_b = lora_b.to(base_device, model.dtype)
+            lora_b = lora_b.to(base_device, dtype)

             scale = get_scaling_factor(
                 config.lora_alpha,

@@ -236,12 +236,8 @@ class LoraWeights(AdapterWeights):
             lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale

         # pad lora ranks to be compatible with sgmv
-        lora_a_list = [
-            pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list
-        ]
-        lora_b_list = [
-            pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list
-        ]
+        lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]
+        lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]

         if lora_a_list:
             # update rank if it was padded

@@ -252,8 +248,8 @@ class LoraWeights(AdapterWeights):
             *shard_lora_weights(
                 weights_a=lora_a_list,
                 weights_b=lora_b_list,
-                split_dim=0 if model.is_row_parallel(layer_type) else 1,
-                process_group=model.process_group,
+                split_dim=0 if layer_type in {"o_proj", "down_proj", "lm_head"} else 1,
+                process_group=process_group,
             ),
             config,
         )
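
The comment block above summarizes what `prepare_weights` does per layer: cast to the target dtype, transpose, fold the LoRA scaling factor into the B matrix, and pad the rank so the SGMV/BGMV kernels and sharding work out. A minimal sketch of that per-layer arithmetic, with `get_scaling_factor` and `pad_rank` written here as simplified stand-ins for the helpers the diff calls (the real padding rule may differ):

import torch

def get_scaling_factor(lora_alpha: int, r: int) -> float:
    # classic LoRA scaling, folded into the B matrix
    return lora_alpha / r

def pad_rank(w: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
    # simplified stand-in: pad the rank dimension to a multiple of world_size
    # so the later shard_lora_weights split is even
    rank = w.size(dim)
    remainder = rank % world_size
    if remainder == 0:
        return w
    pad = [0, 0] * w.dim()
    pad[2 * (w.dim() - 1 - dim) + 1] = world_size - remainder
    return torch.nn.functional.pad(w, pad)

# one layer's worth of work, mirroring the loop body in prepare_weights
r, hidden = 16, 4096
lora_a = torch.randn(r, hidden, dtype=torch.float16)  # lora_A: (r, in_features)
lora_b = torch.randn(hidden, r, dtype=torch.float16)  # lora_B: (out_features, r)
scale = get_scaling_factor(lora_alpha=32, r=r)

lora_a = lora_a.transpose(0, 1)          # (in_features, r)
lora_b = lora_b.transpose(0, 1) * scale  # (r, out_features), scaled

lora_a = pad_rank(lora_a, dim=1, world_size=3)
lora_b = pad_rank(lora_b, dim=0, world_size=3)
print(lora_a.shape, lora_b.shape)  # rank padded from 16 to 18 in this example
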
@@ -293,10 +289,6 @@ class BatchLoraWeights(BatchAdapterWeights):
             for rank_data in self.rank_data.values()
         )

-    @classmethod
-    def key(cls) -> str:
-        return "lora"
-
     @classmethod
     def load(
         self,

@@ -42,10 +42,6 @@ class BatchAdapterWeights(ABC):
     def has_adapter(self, adapter_index: int) -> bool:
         pass

-    @abstractclassmethod
-    def key(cls) -> str:
-        pass
-
     @abstractclassmethod
     def load(
         cls,

@@ -71,13 +67,6 @@ class LayerAdapterWeights:
             return
         del self.adapter_weights[adapter_idx]

-    @property
-    def max_speculative_tokens(self) -> int:
-        return max(
-            adapter_weights.speculative_tokens
-            for adapter_weights in self.adapter_weights.values()
-        )
-
     def is_empty(self) -> bool:
         return len(self.adapter_weights) == 0

@@ -101,7 +90,7 @@ class LayerAdapterWeights:
                 adapter_weights, meta, prefill, prefill_head_indices
             )
             if batched_weights is not None:
-                batch_data[batch_type.key()] = batched_weights
+                batch_data = batched_weights
         return batch_data

@@ -133,8 +122,7 @@ class AdapterBatchData:
     def ranks(self) -> Set[int]:
         # TODO(travis): refactor to be less coupled to lora implementation
         ranks = set()
-        for layer_data in self.data.values():
-            lora_data = layer_data.get("lora")
+        for lora_data in self.data.values():
             if lora_data is None:
                 continue

@@ -4,9 +4,10 @@ import typer

 from pathlib import Path
 from loguru import logger
-from typing import Optional
+from typing import Optional, List, Dict
 from enum import Enum
 from huggingface_hub import hf_hub_download
+from text_generation_server.utils.adapter import parse_lora_adapters

 from text_generation_server.utils.log import log_master

@@ -80,28 +81,19 @@ def serve(
     if otlp_endpoint is not None:
         setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)

-    lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
-
-    # split on comma and strip whitespace
-    lora_adapter_ids = (
-        [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
-    )
-
-    if len(lora_adapter_ids) > 0:
-        log_master(
-            logger.warning,
-            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected.",
-        )
+    lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS"))

     # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled
     # and warn the user
-    if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None:
-        log_master(
-            logger.warning,
-            f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs.",
-        )
-        global CUDA_GRAPHS
-        CUDA_GRAPHS = None
+    if lora_adapters:
+        logger.warning("LoRA adapters enabled (experimental feature).")
+
+        if "CUDA_GRAPHS" in os.environ:
+            logger.warning(
+                "LoRA adapters incompatible with CUDA Graphs. Disabling CUDA Graphs."
+            )
+            global CUDA_GRAPHS
+            CUDA_GRAPHS = None

     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value

@@ -117,7 +109,7 @@ def serve(
     )
     server.serve(
         model_id,
-        lora_adapter_ids,
+        lora_adapters,
         revision,
         sharded,
         quantize,

@@ -43,10 +43,7 @@ class LoraLinear(nn.Module):
     ) -> torch.Tensor:
         if adapter_data is None:
             return result
-        data = adapter_data.data.get(layer_type)
-        data: Optional["BatchLoraWeights"] = (
-            data.get("lora") if data is not None else None
-        )
+        data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)

         if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
             # In tensor-parallel configurations, each GPU processes a specific segment of the output.

@@ -6,7 +6,7 @@ from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from huggingface_hub import hf_hub_download, HfApi
-from typing import Optional, List
+from typing import Optional, List, Dict
 from pathlib import Path

 from text_generation_server.utils.speculate import get_speculate, set_speculate

@@ -33,6 +33,16 @@ from text_generation_server.models.custom_modeling.t5_modeling import (
     T5ForConditionalGeneration,
 )

+from text_generation_server.utils.adapter import (
+    AdapterParameters,
+    build_layer_weight_lookup,
+    load_and_merge_adapters,
+    AdapterInfo,
+)
+from text_generation_server.adapters.lora import LoraWeights
+
+
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.log import log_master

@@ -50,7 +60,7 @@ __all__ = [
     "Model",
     "CausalLM",
     "Seq2SeqLM",
-    "get_model",
+    "get_model_with_lora_adapters",
 ]

 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

@@ -1115,3 +1125,116 @@ def get_model(
     )

     raise ValueError(f"Unsupported model type {model_type}")
+
+
+# get_model_with_lora_adapters wraps the internal get_model function and adds support for loading adapters
+# this provides a post model loading hook to load adapters into the model after the model has been loaded
+def get_model_with_lora_adapters(
+    model_id: str,
+    lora_adapters: Optional[List[AdapterInfo]],
+    revision: Optional[str],
+    sharded: bool,
+    quantize: Optional[str],
+    speculate: Optional[int],
+    dtype: Optional[str],
+    trust_remote_code: bool,
+    max_input_tokens: int,
+    adapter_to_index: Dict[str, int],
+):
+    lora_adapter_ids = [adapter.id for adapter in lora_adapters]
+    model = get_model(
+        model_id,
+        lora_adapter_ids,
+        revision,
+        sharded,
+        quantize,
+        speculate,
+        dtype,
+        trust_remote_code,
+        max_input_tokens,
+    )
+
+    if len(lora_adapters) > 0:
+        target_to_layer = build_layer_weight_lookup(model.model)
+
+        for index, adapter in enumerate(lora_adapters):
+            # The AdapterParameters object allows for merging multiple adapters into a single adapter.
+            # At the moment, we only support loading a single adapter into the model, but we keep the
+            # AdapterParameters object for easier extension in the future.
+            adapter_parameters = AdapterParameters(
+                adapter_info=[adapter],
+                # when merging multiple adapters we can weight them differently
+                # if this is not set, all adapters will be weighted equally
+                # see: text_generation_server.utils.merges.strategies for impl
+                weights=None,
+                merge_strategy=0,
+                density=1.0,
+                majority_sign_method=0,
+            )
+
+            adapter_index = index + 1
+            adapter_to_index[adapter.id] = adapter_index
+
+            logger.info(
+                f"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}"
+            )
+            weight_names = tuple([v[0] for v in target_to_layer.values()])
+            (
+                module_map,
+                adapter_config,
+                adapter_weight_names,
+                adapter_tokenizer,
+            ) = load_and_merge_adapters(
+                model.model_id,
+                adapter_parameters,
+                adapter_index,
+                weight_names,
+                False,
+            )
+
+            unused_weight_names = adapter_weight_names.copy()
+
+            adapter_layers = [
+                "q_proj",
+                "k_proj",
+                "v_proj",
+                "o_proj",
+                "gate_proj",
+                "up_proj",
+                "down_proj",
+            ]
+
+            for layer_name in adapter_layers:
+                nlayers = (
+                    1 if layer_name == "lm_head" else len(model.model.model.layers)
+                )
+                adapter_weights = LoraWeights.prepare_weights(
+                    config=adapter_config,
+                    module_map=module_map,
+                    layer_type=layer_name,
+                    unused_weight_names=unused_weight_names,
+                    nlayers=nlayers,
+                    dtype=model.dtype,
+                    world_size=model.world_size,
+                    process_group=model.process_group,
+                    target_to_layer=target_to_layer,
+                )
+
+                if adapter_weights is None:
+                    continue
+
+                model.layer_to_adapter_weights[layer_name].add_adapter(
+                    adapter_index, adapter_weights
+                )
+
+            if len(unused_weight_names) > 0:
+                logger.warning(
+                    f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
+                )
+
+            if adapter_tokenizer is not None:
+                model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)
+
+            model.loaded_adapters.add(adapter_index)
+
+    return model
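
A sketch of how the new factory is meant to be called, mirroring the server.py wiring later in this diff; the model id and adapter specs are placeholders and the keyword style is only for readability:

from text_generation_server.models import get_model_with_lora_adapters
from text_generation_server.utils.adapter import parse_lora_adapters

adapter_to_index = {}
model = get_model_with_lora_adapters(
    model_id="org/base-model",  # placeholder
    lora_adapters=parse_lora_adapters("org/math-lora,local=/data/adapters/local"),
    revision=None,
    sharded=False,
    quantize=None,
    speculate=None,
    dtype=None,
    trust_remote_code=False,
    max_input_tokens=1024,
    adapter_to_index=adapter_to_index,  # filled with {adapter_id: index}
)
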
@@ -1,7 +1,6 @@
 import math
 import os
 import time
-import itertools
 import torch
 import torch.distributed

@@ -1700,72 +1699,3 @@ class FlashCausalLM(Model):
         forward_ns = start_decode - start
         decode_ns = time.time_ns() - start_decode
         return generations, batch, (forward_ns, decode_ns)
-
-    @property
-    def supports_adapter_loading(self) -> bool:
-        return True
-
-    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
-        layer_weights = {}
-
-        prefix = "model.layers"
-
-        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
-        # that have a language_model inside of the larger model.
-        if hasattr(self.model, "language_model"):
-            _model = self.model.language_model
-        elif hasattr(self.model, "text_model"):
-            _model = self.model.text_model
-        else:
-            _model = self.model
-
-        for i, layer in enumerate(_model.model.layers):
-            layer_weights[(i, "q_proj")] = (
-                f"{prefix}.{i}.self_attn.q_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "k_proj")] = (
-                f"{prefix}.{i}.self_attn.k_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "v_proj")] = (
-                f"{prefix}.{i}.self_attn.v_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "o_proj")] = (
-                f"{prefix}.{i}.self_attn.o_proj",
-                layer.self_attn.o_proj,
-            )
-
-            # TODO: this is a hack to avoid the gate_proj for
-            # FlashStarcoder2 that doesnt have these layers
-            if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"):
-                layer_weights[(i, "gate_proj")] = (
-                    f"{prefix}.{i}.mlp.gate_proj",
-                    layer.mlp.gate_up_proj,
-                )
-                layer_weights[(i, "up_proj")] = (
-                    f"{prefix}.{i}.mlp.up_proj",
-                    layer.mlp.gate_up_proj,
-                )
-                layer_weights[(i, "down_proj")] = (
-                    f"{prefix}.{i}.mlp.down_proj",
-                    layer.mlp.down_proj,
-                )
-
-        layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
-        return layer_weights
-
-    @property
-    def adapter_layers(self) -> List[str]:
-        return ADAPTER_LAYERS
-
-    @property
-    def default_traced_adapter_layers(self) -> List[str]:
-        return ["q_proj", "v_proj"]
-
-    def get_num_layers_for_type(self, layer_type: str) -> int:
-        return 1 if layer_type == "lm_head" else len(self.model.model.layers)
-
-    def is_row_parallel(self, layer_type: str) -> bool:
-        return layer_type in ROW_PARALLEL

@@ -1,85 +0,0 @@
-import torch
-
-from typing import Optional, Tuple, Dict, List
-
-from text_generation_server.models import FlashCausalLM
-
-
-ADAPTER_LAYERS = [
-    "q_proj",
-    "k_proj",
-    "v_proj",
-    "o_proj",
-    "gate_proj",
-    "up_proj",
-    "down_proj",
-]
-ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
-
-
-class FlashMistral(FlashCausalLM):
-    @property
-    def supports_adapter_loading(self) -> bool:
-        return True
-
-    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
-        layer_weights = {}
-
-        prefix = "model.layers"
-
-        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
-        # that have a language_model inside of the larger model.
-        if hasattr(self.model, "text_model"):
-            _model = self.model.text_model
-        else:
-            _model = self.model
-
-        for i, layer in enumerate(_model.model.layers):
-            layer_weights[(i, "q_proj")] = (
-                f"{prefix}.{i}.self_attn.q_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "k_proj")] = (
-                f"{prefix}.{i}.self_attn.k_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "v_proj")] = (
-                f"{prefix}.{i}.self_attn.v_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "o_proj")] = (
-                f"{prefix}.{i}.self_attn.o_proj",
-                layer.self_attn.o_proj,
-            )
-
-            # TODO: this is a hack to avoid the gate_proj for
-            # FlashStarcoder2 that doesnt have these layers
-            if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"):
-                layer_weights[(i, "gate_proj")] = (
-                    f"{prefix}.{i}.mlp.gate_proj",
-                    layer.mlp.gate_up_proj,
-                )
-                layer_weights[(i, "up_proj")] = (
-                    f"{prefix}.{i}.mlp.up_proj",
-                    layer.mlp.gate_up_proj,
-                )
-                layer_weights[(i, "down_proj")] = (
-                    f"{prefix}.{i}.mlp.down_proj",
-                    layer.mlp.down_proj,
-                )
-
-        layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
-        return layer_weights
-
-    @property
-    def adapter_layers(self) -> List[str]:
-        return ADAPTER_LAYERS
-
-    @property
-    def default_traced_adapter_layers(self) -> List[str]:
-        return ["q_proj", "v_proj"]
-
-    def get_num_layers_for_type(self, layer_type: str) -> int:
-        return 1 if layer_type == "lm_head" else len(self.model.model.layers)
-
-    def is_row_parallel(self, layer_type: str) -> bool:
-        return layer_type in ROW_PARALLEL

@@ -4,20 +4,12 @@ import torch
 from abc import ABC, abstractmethod
 from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict
 from collections import defaultdict
-from transformers import PreTrainedTokenizerBase, PretrainedConfig
+from transformers import PreTrainedTokenizerBase

 from text_generation_server.models.types import Batch, Generation
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.pb.generate_pb2 import InfoResponse
 from text_generation_server.adapters.weights import LayerAdapterWeights
-from text_generation_server.utils.adapter import (
-    load_and_merge_adapters,
-    AdapterParameters,
-    AdapterSource,
-)
-from text_generation_server.utils.log import log_master
-from loguru import logger
-

 BASE_MODEL_ADAPTER_ID = "__base_model__"

@@ -61,7 +53,6 @@ class Model(ABC):
         self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
             LayerAdapterWeights
         )
-        self.target_to_layer = None
         self.loaded_adapters = set()
         self.static_adapter_id = adapter_id

@@ -142,140 +133,3 @@ class Model(ABC):
             raise RuntimeError(
                 f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
             )
-
-    @property
-    def supports_adapter_loading(self) -> bool:
-        return False
-
-    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
-        return {}
-
-    @property
-    def adapter_layers(self) -> List[str]:
-        return []
-
-    @property
-    def default_traced_adapter_layers(self) -> List[str]:
-        return []
-
-    def get_num_layers_for_type(self, layer_type: str) -> int:
-        return 0
-
-    def is_row_parallel(self, layer_type: str) -> bool:
-        return False
-
-    @property
-    def max_speculative_tokens(self) -> int:
-        return max(
-            [
-                weights.max_speculative_tokens
-                for weights in self.layer_to_adapter_weights.values()
-            ],
-            default=0,
-        )
-
-    def load_adapter(
-        self,
-        adapter_parameters: AdapterParameters,
-        adapter_source: AdapterSource,
-        adapter_index: int,
-        api_token: str,
-        dynamic: bool = True,
-    ):
-        """Loads adapter weights from disk / host memory on the GPU.
-
-        adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded
-        into model. Otherwise, the adapter weights are applied during the forward
-        pass and stored separately from the base model parameters.
-        """
-        if self.target_to_layer is None:
-            self.target_to_layer = self.adapter_target_to_layer()
-        if adapter_index in self.loaded_adapters:
-            # Adapter already loaded
-            return
-
-        if not self.supports_adapter_loading:
-            raise ValueError("This model does not support adapter loading.")
-
-        if dynamic and not self.dynamic_adapter_loading_enabled:
-            raise ValueError(
-                f"This model was initialized with the adapter {self.static_adapter_id} "
-                f"and therefore does not support dynamic adapter loading. "
-                f"Please initialize a new model instance from the base model in "
-                f"order to use the dynamic adapter loading feature."
-            )
-
-        log_master(
-            logger.info,
-            f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}",
-        )
-        weight_names = tuple([v[0] for v in self.target_to_layer.values()])
-        (
-            module_map,
-            adapter_config,
-            adapter_weight_names,
-            adapter_tokenizer,
-        ) = load_and_merge_adapters(
-            self.model_id,
-            adapter_parameters,
-            adapter_source,
-            adapter_index,
-            weight_names,
-            api_token,
-            False,
-        )
-
-        unused_weight_names = adapter_weight_names.copy()
-        for layer_name in self.adapter_layers:
-            adapter_weights = adapter_config.load_batched_adapter_weights(
-                self,
-                module_map,
-                layer_name,
-                unused_weight_names,
-                dynamic,
-            )
-
-            if adapter_weights is None:
-                continue
-
-            layer_weights = self.layer_to_adapter_weights[layer_name]
-            layer_weights.add_adapter(adapter_index, adapter_weights)
-
-        if len(unused_weight_names) > 0:
-            log_master(
-                logger.warning,
-                f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}",
-            )
-
-        if adapter_tokenizer is not None:
-            self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)
-
-        self.loaded_adapters.add(adapter_index)
-
-    def offload_adapter(
-        self,
-        adapter_parameters: AdapterParameters,
-        adapter_source: AdapterSource,
-        adapter_index: int,
-    ):
-        """Offloads the adapter weights from GPU to CPU or disk."""
-        if adapter_index not in self.loaded_adapters:
-            # Adapter already offloaded
-            return
-
-        if not self.supports_adapter_loading:
-            raise ValueError("This model does not support adapter loading.")
-
-        if not self.dynamic_adapter_loading_enabled:
-            raise ValueError(
-                f"This model was initialized with the adapter {self.static_adapter_id} "
-                f"and therefore does not support dynamic adapter loading. "
-                f"Please initialize a new model instance from the base model in "
-                f"order to use the dynamic adapter loading feature."
-            )
-
-        for layer_name in self.adapter_layers:
-            if layer_name in self.layer_to_adapter_weights:
-                self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index)
-
-        self.loaded_adapters.remove(adapter_index)

@@ -9,11 +9,12 @@ from loguru import logger

 from grpc_reflection.v1alpha import reflection
 from pathlib import Path
-from typing import List, Optional
+from typing import List, Optional, Dict

 from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
-from text_generation_server.models import Model, get_model
+from text_generation_server.models import Model, get_model_with_lora_adapters
+from text_generation_server.utils.adapter import AdapterInfo

 try:
     from text_generation_server.models.pali_gemma import PaliGemmaBatch

@@ -30,9 +31,6 @@ except (ImportError, NotImplementedError):
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.globals import set_model_id, set_adapter_to_index
-from text_generation_server.utils.adapter import (
-    AdapterParameters,
-)


 class SignalHandler:

@@ -195,7 +193,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

 def serve(
     model_id: str,
-    lora_adapter_ids: Optional[List[str]],
+    lora_adapters: Optional[List[AdapterInfo]],
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],

@@ -207,7 +205,7 @@ def serve(
 ):
     async def serve_inner(
         model_id: str,
-        lora_adapter_ids: Optional[List[str]],
+        lora_adapters: Optional[List[AdapterInfo]],
         revision: Optional[str],
         sharded: bool = False,
         quantize: Optional[str] = None,

@@ -228,9 +226,9 @@ def serve(
         server_urls = [local_url]

         try:
-            model = get_model(
+            model = get_model_with_lora_adapters(
                 model_id,
-                lora_adapter_ids,
+                lora_adapters,
                 revision,
                 sharded,
                 quantize,

@@ -238,29 +236,9 @@ def serve(
                 dtype,
                 trust_remote_code,
                 max_input_tokens,
+                adapter_to_index,
             )
-
-            if len(lora_adapter_ids) > 0:
-                for index, adapter_id in enumerate(lora_adapter_ids):
-                    # TODO: improve non merged adapter loading and long term
-                    # improve adapter loading as a whole
-                    adapter_parameters = AdapterParameters(
-                        adapter_ids=[adapter_id],
-                        weights=None,  # will be set to 1
-                        merge_strategy=0,
-                        density=1.0,
-                        majority_sign_method=0,
-                    )
-                    adapter_index = index + 1
-                    adapter_to_index[adapter_id] = adapter_index
-                    model.load_adapter(
-                        adapter_parameters,
-                        None,  # adapter_source
-                        adapter_index,
-                        None,  # api_token
-                        False,  # dynamic
-                    )
         except Exception:
             logger.exception("Error when initializing model")
             raise

@@ -297,7 +275,7 @@ def serve(
     asyncio.run(
         serve_inner(
             model_id,
-            lora_adapter_ids,
+            lora_adapters,
            revision,
            sharded,
            quantize,

@@ -5,12 +5,11 @@
 import warnings
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import TYPE_CHECKING, Set, Tuple
+from typing import TYPE_CHECKING, Set, Tuple, Optional, List

 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer

-from text_generation_server.pb import generate_pb2
 from text_generation_server.utils.merges.strategies import merge_adapters

 from text_generation_server.utils import hub

@@ -24,9 +23,15 @@ if TYPE_CHECKING:
 BASE_MODEL_ADAPTER_ID = "__base_model__"


+@dataclass
+class AdapterInfo:
+    id: str
+    path: Optional[str]
+
+
 @dataclass
 class AdapterParameters:
-    adapter_ids: Tuple[str]
+    adapter_info: Tuple[AdapterInfo]
     weights: Tuple[float]
     merge_strategy: NotImplemented
     density: float

@@ -40,37 +45,47 @@ class AdapterSource:
     revision: str


+def parse_lora_adapters(lora_adapters: Optional[str]) -> List[AdapterInfo]:
+    if not lora_adapters:
+        return []
+
+    adapter_list = []
+    for adapter in lora_adapters.split(","):
+        parts = adapter.strip().split("=")
+        if len(parts) == 1:
+            adapter_list.append(AdapterInfo(id=parts[0], path=None))
+        elif len(parts) == 2:
+            adapter_list.append(AdapterInfo(id=parts[0], path=parts[1]))
+        else:
+            raise ValueError(f"Invalid LoRA adapter format: {adapter}")
+    return adapter_list
+
+
 def load_and_merge_adapters(
     model_id: str,
     adapter_parameters: AdapterParameters,
-    adapter_source: str,
     adapter_index: int,
     weight_names: Tuple[str],
-    api_token: str,
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
-    if len(adapter_parameters.adapter_ids) == 1:
+
+    if len(adapter_parameters.adapter_info) == 1:
+        adapter_info = next(iter(adapter_parameters.adapter_info))
         return load_module_map(
             model_id,
-            adapter_parameters.adapter_ids[0],
-            adapter_source,
+            adapter_info.id,
+            adapter_info.path,
             weight_names,
-            api_token,
             trust_remote_code,
         )

-    adapter_params = AdapterParametersContainer(
-        adapter_parameters, adapter_source, adapter_index
-    )
-    return _load_and_merge(
-        model_id, adapter_params, weight_names, api_token, trust_remote_code
-    )
+    adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index)
+    return _load_and_merge(model_id, adapter_params, weight_names, trust_remote_code)


 @dataclass
 class AdapterParametersContainer:
     adapter_parameters: AdapterParameters
-    adapter_source: str
     adapter_index: int

     def __hash__(self) -> int:
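
For reference, `parse_lora_adapters` above turns the raw LORA_ADAPTERS string into AdapterInfo entries; a small usage sketch (adapter ids and paths are placeholders):

from text_generation_server.utils.adapter import AdapterInfo, parse_lora_adapters

adapters = parse_lora_adapters("org/hub-adapter, local-adapter=/data/adapters/local-adapter")
assert adapters == [
    AdapterInfo(id="org/hub-adapter", path=None),
    AdapterInfo(id="local-adapter", path="/data/adapters/local-adapter"),
]
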
@@ -82,7 +97,6 @@ def _load_and_merge(
     model_id: str,
     adapter_params: AdapterParametersContainer,
     weight_names: Tuple[str],
-    api_token: str,
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
     params = adapter_params.adapter_parameters

@@ -90,17 +104,16 @@
     adapters_to_merge = []
     merged_weight_names = set()
     tokenizer = None
-    for adapter_id in params.adapter_ids:
-        if adapter_id == BASE_MODEL_ADAPTER_ID:
+    for adapter in params.adapter_info:
+        if adapter.id == BASE_MODEL_ADAPTER_ID:
             raise ValueError("Base model adapter cannot be merged.")

         module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
             load_module_map(
                 model_id,
-                adapter_id,
-                adapter_params.adapter_source,
+                adapter.id,
+                adapter.path,
                 weight_names,
-                api_token,
                 trust_remote_code,
             )
         )

@@ -159,25 +172,28 @@ def check_architectures(
 def load_module_map(
     model_id: str,
     adapter_id: str,
-    adapter_source: str,
+    adapter_path: Optional[str],
     weight_names: Tuple[str],
-    api_token: str,
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
     revision = "main"

-    adapter_config = LoraConfig.load(adapter_id, api_token)
-    if adapter_config.base_model_name_or_path != model_id:
+    adapter_config = LoraConfig.load(adapter_path or adapter_id, None)
+
+    if not adapter_path and adapter_config.base_model_name_or_path != model_id:
         check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)

-    adapter_filenames = hub._cached_adapter_weight_files(
-        adapter_id, revision=revision, extension=".safetensors"
+    adapter_filenames = (
+        hub._adapter_weight_files_from_dir(adapter_path, extension=".safetensors")
+        if adapter_path
+        else hub._cached_adapter_weight_files(
+            adapter_id, revision=revision, extension=".safetensors"
+        )
     )

     try:
         adapter_tokenizer = AutoTokenizer.from_pretrained(
             adapter_config.config_path,
-            token=api_token,
             trust_remote_code=trust_remote_code,
         )
     except Exception:

@@ -194,3 +210,87 @@
         adapter_weights, weight_names
     )
     return module_map, adapter_config, adapter_weight_names, adapter_tokenizer
+
+
+def get_attn_weights(i, layer):
+    qkv = layer.self_attn.query_key_value
+    weights = {}
+
+    for k in ["q", "k", "v"]:
+        key = (i, f"{k}_proj")
+        value = (f"model.layers.{i}.self_attn.{k}_proj", qkv)
+        weights[key] = value
+
+    weights[(i, "o_proj")] = (
+        f"model.layers.{i}.self_attn.o_proj",
+        layer.self_attn.o_proj,
+    )
+
+    return weights
+
+
+def get_mlp_weights(i, layer):
+    weights = {}
+    if hasattr(layer, "mlp"):
+        mlp = layer.mlp
+        if hasattr(mlp, "gate_up_proj"):
+            # handle combined gate_up_proj (e.g., for some LLaMA variants)
+            weights.update(
+                {
+                    (i, "gate_proj"): (
+                        f"model.layers.{i}.mlp.gate_proj",
+                        mlp.gate_up_proj,
+                    ),
+                    (i, "up_proj"): (f"model.layers.{i}.mlp.up_proj", mlp.gate_up_proj),
+                }
+            )
+        else:
+            # handle separate gate_proj, up_proj, and down_proj (e.g., for Gemma)
+            if hasattr(mlp, "gate_proj"):
+                weights[(i, "gate_proj")] = (
+                    f"model.layers.{i}.mlp.gate_proj",
+                    mlp.gate_proj,
+                )
+            if hasattr(mlp, "up_proj"):
+                weights[(i, "up_proj")] = (f"model.layers.{i}.mlp.up_proj", mlp.up_proj)

+        if hasattr(mlp, "down_proj"):
+            weights[(i, "down_proj")] = (
+                f"model.layers.{i}.mlp.down_proj",
+                mlp.down_proj,
+            )
+
+    return weights
+
+
+# build_layer_weight_lookup creates a mapping of model layers to their corresponding
+# weight tensors and paths. It builds a dictionary that maps layer identifiers to tuples
+# containing the weight tensor path and the actual layer object. This mapping is needed
+# for the lora adapter to know which weights to update when applying the adapter.
+def build_layer_weight_lookup(model):
+    if hasattr(model, "language_model"):
+        m = model.language_model.model
+    elif hasattr(model, "text_model"):
+        m = model.text_model.model
+    else:
+        m = model.model
+
+    layer_weights = {}
+
+    for i, layer in enumerate(m.layers):
+        attn_weights = get_attn_weights(i, layer)
+        mlp_weights = get_mlp_weights(i, layer)
+
+        layer_weights.update(attn_weights)
+        layer_weights.update(mlp_weights)
+
+    lm_head = None
+    if hasattr(m, "lm_head"):
+        lm_head = m.lm_head
+    elif hasattr(model, "lm_head"):
+        lm_head = model.lm_head
+
+    if lm_head:
+        layer_weights[(0, "lm_head")] = ("lm_head", lm_head)
+
+    return layer_weights
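
In the same spirit as the new test file, a small mock-based illustration (not part of the commit) of the lookup structure `build_layer_weight_lookup` produces, and how `get_model_with_lora_adapters` derives `weight_names` from it:

from unittest.mock import Mock

from text_generation_server.utils.adapter import build_layer_weight_lookup

model = Mock(spec=["model"])  # no language_model/text_model attribute
layer = Mock()
layer.self_attn.query_key_value = Mock()
layer.self_attn.o_proj = Mock()
layer.mlp.gate_up_proj = Mock()
layer.mlp.down_proj = Mock()
model.model.layers = [layer]

lookup = build_layer_weight_lookup(model)
# keys are (layer_index, proj_name); values are (weight path, layer module)
assert lookup[(0, "q_proj")][0] == "model.layers.0.self_attn.q_proj"
weight_names = tuple(v[0] for v in lookup.values())  # what get_model_with_lora_adapters passes on
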