diff --git a/launcher/src/main.rs b/launcher/src/main.rs index be7ae4b0..228b0e79 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1607,6 +1607,10 @@ fn main() -> Result<(), LauncherError> { // Download and convert lora adapters if any if let Some(lora_adapters) = &args.lora_adapters { for adapter in lora_adapters.split(',') { + // skip download if a path is provided + if adapter.contains('=') { + continue; + } download_convert_model( adapter, None, diff --git a/server/tests/utils/test_adapter.py b/server/tests/utils/test_adapter.py new file mode 100644 index 00000000..cc1b076d --- /dev/null +++ b/server/tests/utils/test_adapter.py @@ -0,0 +1,187 @@ +import pytest +from unittest.mock import Mock +from text_generation_server.utils.adapter import get_attn_weights, get_mlp_weights + + +def test_get_attn_weights(): + # create a mock layer + mock_layer = Mock() + mock_layer.self_attn.query_key_value = Mock() + mock_layer.self_attn.o_proj = Mock() + + # call the function + result = get_attn_weights(2, mock_layer) + + # assert the result + expected = { + (2, "q_proj"): ( + "model.layers.2.self_attn.q_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "k_proj"): ( + "model.layers.2.self_attn.k_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "v_proj"): ( + "model.layers.2.self_attn.v_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj), + } + assert result == expected + + +def test_get_mlp_weights_with_gate_up_proj(): + # create a mock layer with gate_up_proj + mock_layer = Mock() + mock_layer.mlp.gate_up_proj = Mock() + mock_layer.mlp.down_proj = Mock() + + # call the function + result = get_mlp_weights(3, mock_layer) + + # assert the result + expected = { + (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj), + (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj), + (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), + } + assert result == expected + + +def test_get_mlp_weights_without_gate_up_proj(): + # create a mock layer without gate_up_proj + mock_layer = Mock() + mock_layer.mlp = Mock(spec=[]) + + # call the function + result = get_mlp_weights(1, mock_layer) + + # assert the result + assert result == {} + + +@pytest.mark.parametrize("layer_index", [0, 1, 5]) +def test_get_attn_weights_different_layers(layer_index): + mock_layer = Mock() + mock_layer.self_attn.query_key_value = Mock() + mock_layer.self_attn.o_proj = Mock() + + result = get_attn_weights(layer_index, mock_layer) + + for k in ["q", "k", "v"]: + assert (layer_index, f"{k}_proj") in result + assert ( + result[(layer_index, f"{k}_proj")][0] + == f"model.layers.{layer_index}.self_attn.{k}_proj" + ) + + assert (layer_index, "o_proj") in result + assert ( + result[(layer_index, "o_proj")][0] + == f"model.layers.{layer_index}.self_attn.o_proj" + ) + + +@pytest.mark.parametrize("layer_index", [0, 1, 5]) +def test_get_mlp_weights_different_layers(layer_index): + mock_layer = Mock() + mock_layer.mlp.gate_up_proj = Mock() + mock_layer.mlp.down_proj = Mock() + + result = get_mlp_weights(layer_index, mock_layer) + + for k in ["gate", "up", "down"]: + assert (layer_index, f"{k}_proj") in result + assert ( + result[(layer_index, f"{k}_proj")][0] + == f"model.layers.{layer_index}.mlp.{k}_proj" + ) + + +def test_get_attn_weights_llama_compatibility(): + mock_layer = Mock() + mock_layer.self_attn.query_key_value = Mock() + mock_layer.self_attn.o_proj = 
Mock() + + result = get_attn_weights(2, mock_layer) + + expected = { + (2, "q_proj"): ( + "model.layers.2.self_attn.q_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "k_proj"): ( + "model.layers.2.self_attn.k_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "v_proj"): ( + "model.layers.2.self_attn.v_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj), + } + assert result == expected + + +def test_get_mlp_weights_llama_compatibility(): + mock_layer = Mock() + mock_layer.mlp.gate_up_proj = Mock() + mock_layer.mlp.down_proj = Mock() + + result = get_mlp_weights(3, mock_layer) + + expected = { + (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj), + (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj), + (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), + } + assert result == expected + + +def test_get_attn_weights_gemma_compatibility(): + mock_layer = Mock() + mock_layer.self_attn.query_key_value = Mock() + mock_layer.self_attn.o_proj = Mock() + + result = get_attn_weights(2, mock_layer) + + expected = { + (2, "q_proj"): ( + "model.layers.2.self_attn.q_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "k_proj"): ( + "model.layers.2.self_attn.k_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "v_proj"): ( + "model.layers.2.self_attn.v_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj), + } + assert result == expected + + +def test_get_mlp_weights_gemma_compatibility(): + mock_layer = Mock() + mock_layer.mlp.gate_proj = Mock() + mock_layer.mlp.up_proj = Mock() + mock_layer.mlp.down_proj = Mock() + + # ensure that the mock_layer.mlp.gate_up_proj attribute does not exist. + # This is necessary because the use of `Mock` automatically creates any + # attributes that are accessed, even if they don't exist in the actual + # implementation. If `gate_up_proj` were created, `get_mlp_weights` might + # follow the wrong execution path and return an incorrect result. 
+ del mock_layer.mlp.gate_up_proj + + result = get_mlp_weights(3, mock_layer) + + expected = { + (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj), + (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj), + (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), + } + assert result == expected diff --git a/server/text_generation_server/adapters/config.py b/server/text_generation_server/adapters/config.py index 5261d4b5..2ee53b12 100644 --- a/server/text_generation_server/adapters/config.py +++ b/server/text_generation_server/adapters/config.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Set, Tuple import torch @@ -31,14 +31,3 @@ class AdapterConfig(ABC): weight_names: Tuple[str], ) -> Tuple[ModuleMap, Set[str]]: pass - - @abstractmethod - def load_batched_adapter_weights( - self, - model: "Model", - module_map: ModuleMap, - layer_type: str, - unused_weight_names: Set[str], - dynamic: bool, - ) -> Optional[AdapterWeights]: - pass diff --git a/server/text_generation_server/adapters/lora.py b/server/text_generation_server/adapters/lora.py index 87543be2..ac143bb7 100644 --- a/server/text_generation_server/adapters/lora.py +++ b/server/text_generation_server/adapters/lora.py @@ -102,22 +102,6 @@ class LoraConfig(AdapterConfig): adapter_weight_names.add(lora_b_name) return module_map, adapter_weight_names - def load_batched_adapter_weights( - self, - model: "Model", - module_map: Dict[str, Dict], - layer_type: str, - unused_weight_names: Set[str], - dynamic: bool, - ) -> Optional[AdapterWeights]: - return LoraWeights.load( - self, - model, - module_map, - layer_type, - unused_weight_names, - ) - @classmethod def load(cls, adapter_id: str, api_token: str) -> "LoraConfig": hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token) @@ -192,22 +176,38 @@ class LoraWeights(AdapterWeights): def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]: return [BatchLoraWeights] + # prepare pre-loaded lora weights for use in the model. + # + # this method processes and organizes lora weights for a specific layer type across all layers: + # - uses `config` (LoraConfig) to apply lora-specific settings like scaling factor. + # - retrieves weights from `module_map` based on the `layer_type`. + # - processes `nlayers` number of layers. + # - converts weights to the specified `dtype`. + # - shards weights across `world_size` number of processes using the `process_group`. + # - maps weights to specific layers using `target_to_layer`. + # - tracks `unused_weight_names` to identify any unused weights. + # + # the method handles weight transposition, scaling, and padding to ensure compatibility + # with SGMV or BGMV operations. 
@classmethod - def load( + def prepare_weights( cls, config: LoraConfig, - model: "Model", module_map: Dict[str, Dict], layer_type: str, unused_weight_names: Set[str], + nlayers: int, + dtype: torch.dtype, + world_size: int, + process_group: ProcessGroup, + target_to_layer: Dict[str, Tuple[str, torch.Tensor]], ) -> Optional[AdapterWeights]: - nlayers = model.get_num_layers_for_type(layer_type) lora_a_list = [None] * nlayers lora_b_list = [None] * nlayers for layer_id in range(nlayers): key = (layer_id, layer_type) - weight_name, layer = model.target_to_layer[key] + weight_name, layer = target_to_layer[key] base_weight = layer.base_layer.linear.weight base_device = base_weight.device @@ -216,10 +216,10 @@ class LoraWeights(AdapterWeights): return None lora_a, lora_a_name = module_map[weight_name]["lora_A"] - lora_a = lora_a.to(base_device, model.dtype) + lora_a = lora_a.to(base_device, dtype) lora_b, lora_b_name = module_map[weight_name]["lora_B"] - lora_b = lora_b.to(base_device, model.dtype) + lora_b = lora_b.to(base_device, dtype) scale = get_scaling_factor( config.lora_alpha, @@ -236,12 +236,8 @@ class LoraWeights(AdapterWeights): lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale # pad lora ranks to be compatible with sgmv - lora_a_list = [ - pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list - ] - lora_b_list = [ - pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list - ] + lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list] + lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list] if lora_a_list: # update rank if it was padded @@ -252,8 +248,8 @@ class LoraWeights(AdapterWeights): *shard_lora_weights( weights_a=lora_a_list, weights_b=lora_b_list, - split_dim=0 if model.is_row_parallel(layer_type) else 1, - process_group=model.process_group, + split_dim=0 if layer_type in {"o_proj", "down_proj", "lm_head"} else 1, + process_group=process_group, ), config, ) @@ -293,10 +289,6 @@ class BatchLoraWeights(BatchAdapterWeights): for rank_data in self.rank_data.values() ) - @classmethod - def key(cls) -> str: - return "lora" - @classmethod def load( self, diff --git a/server/text_generation_server/adapters/weights.py b/server/text_generation_server/adapters/weights.py index 8f658756..da75dbcd 100644 --- a/server/text_generation_server/adapters/weights.py +++ b/server/text_generation_server/adapters/weights.py @@ -42,10 +42,6 @@ class BatchAdapterWeights(ABC): def has_adapter(self, adapter_index: int) -> bool: pass - @abstractclassmethod - def key(cls) -> str: - pass - @abstractclassmethod def load( cls, @@ -71,13 +67,6 @@ class LayerAdapterWeights: return del self.adapter_weights[adapter_idx] - @property - def max_speculative_tokens(self) -> int: - return max( - adapter_weights.speculative_tokens - for adapter_weights in self.adapter_weights.values() - ) - def is_empty(self) -> bool: return len(self.adapter_weights) == 0 @@ -101,7 +90,7 @@ class LayerAdapterWeights: adapter_weights, meta, prefill, prefill_head_indices ) if batched_weights is not None: - batch_data[batch_type.key()] = batched_weights + batch_data = batched_weights return batch_data @@ -133,8 +122,7 @@ class AdapterBatchData: def ranks(self) -> Set[int]: # TODO(travis): refactor to be less coupled to lora implementation ranks = set() - for layer_data in self.data.values(): - lora_data = layer_data.get("lora") + for lora_data in self.data.values(): if lora_data is None: continue diff --git a/server/text_generation_server/cli.py 
b/server/text_generation_server/cli.py index 8ec2a5ae..8b5520bf 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -4,9 +4,10 @@ import typer from pathlib import Path from loguru import logger -from typing import Optional +from typing import Optional, List, Dict from enum import Enum from huggingface_hub import hf_hub_download +from text_generation_server.utils.adapter import parse_lora_adapters from text_generation_server.utils.log import log_master @@ -80,28 +81,19 @@ def serve( if otlp_endpoint is not None: setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint) - lora_adapter_ids = os.getenv("LORA_ADAPTERS", None) - - # split on comma and strip whitespace - lora_adapter_ids = ( - [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else [] - ) - - if len(lora_adapter_ids) > 0: - log_master( - logger.warning, - f"LoRA adapters are enabled. This is an experimental feature and may not work as expected.", - ) + lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS")) # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled # and warn the user - if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None: - log_master( - logger.warning, - f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs.", - ) - global CUDA_GRAPHS - CUDA_GRAPHS = None + if lora_adapters: + logger.warning("LoRA adapters enabled (experimental feature).") + + if "CUDA_GRAPHS" in os.environ: + logger.warning( + "LoRA adapters incompatible with CUDA Graphs. Disabling CUDA Graphs." + ) + global CUDA_GRAPHS + CUDA_GRAPHS = None # Downgrade enum into str for easier management later on quantize = None if quantize is None else quantize.value @@ -117,7 +109,7 @@ def serve( ) server.serve( model_id, - lora_adapter_ids, + lora_adapters, revision, sharded, quantize, diff --git a/server/text_generation_server/layers/lora.py b/server/text_generation_server/layers/lora.py index 0bb6db41..df5e92da 100644 --- a/server/text_generation_server/layers/lora.py +++ b/server/text_generation_server/layers/lora.py @@ -43,10 +43,7 @@ class LoraLinear(nn.Module): ) -> torch.Tensor: if adapter_data is None: return result - data = adapter_data.data.get(layer_type) - data: Optional["BatchLoraWeights"] = ( - data.get("lora") if data is not None else None - ) + data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type) if has_sgmv() and data is not None and data.can_vectorize(self.process_group): # In tensor-parallel configurations, each GPU processes a specific segment of the output. 
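For reference, a minimal usage sketch of the LORA_ADAPTERS format now accepted by the launcher and server. It exercises parse_lora_adapters, which is added later in this diff (server/text_generation_server/utils/adapter.py); the adapter ids and paths below are placeholders. Entries of the form `id=path` resolve against a local directory, so the launcher skips the Hub download for them (see the launcher/src/main.rs hunk above).

    # Usage sketch only; adapter ids and paths are hypothetical.
    from text_generation_server.utils.adapter import AdapterInfo, parse_lora_adapters

    # A plain Hub id keeps the previous behavior (path=None, weights are downloaded).
    assert parse_lora_adapters("predibase/customer_support") == [
        AdapterInfo(id="predibase/customer_support", path=None)
    ]

    # An `id=path` entry points at a local directory and skips the download.
    assert parse_lora_adapters("my-adapter=/data/adapters/my-adapter") == [
        AdapterInfo(id="my-adapter", path="/data/adapters/my-adapter")
    ]

    # Entries are comma separated; whitespace around each entry is stripped.
    assert parse_lora_adapters(
        "predibase/customer_support, my-adapter=/data/adapters/my-adapter"
    ) == [
        AdapterInfo(id="predibase/customer_support", path=None),
        AdapterInfo(id="my-adapter", path="/data/adapters/my-adapter"),
    ]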
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 1cd13a2a..ac7a8f3e 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -6,7 +6,7 @@ from loguru import logger from transformers.configuration_utils import PretrainedConfig from transformers.models.auto import modeling_auto from huggingface_hub import hf_hub_download, HfApi -from typing import Optional, List +from typing import Optional, List, Dict from pathlib import Path from text_generation_server.utils.speculate import get_speculate, set_speculate @@ -33,6 +33,16 @@ from text_generation_server.models.custom_modeling.t5_modeling import ( T5ForConditionalGeneration, ) + +from text_generation_server.utils.adapter import ( + AdapterParameters, + build_layer_weight_lookup, + load_and_merge_adapters, + AdapterInfo, +) +from text_generation_server.adapters.lora import LoraWeights + + from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.log import log_master @@ -50,7 +60,7 @@ __all__ = [ "Model", "CausalLM", "Seq2SeqLM", - "get_model", + "get_model_with_lora_adapters", ] FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models." @@ -1115,3 +1125,116 @@ def get_model( ) raise ValueError(f"Unsupported model type {model_type}") + + +# get_model_with_lora_adapters wraps the internal get_model function and adds support for loading adapters +# this provides a post model loading hook to load adapters into the model after the model has been loaded +def get_model_with_lora_adapters( + model_id: str, + lora_adapters: Optional[List[AdapterInfo]], + revision: Optional[str], + sharded: bool, + quantize: Optional[str], + speculate: Optional[int], + dtype: Optional[str], + trust_remote_code: bool, + max_input_tokens: int, + adapter_to_index: Dict[str, int], +): + lora_adapter_ids = [adapter.id for adapter in lora_adapters] + model = get_model( + model_id, + lora_adapter_ids, + revision, + sharded, + quantize, + speculate, + dtype, + trust_remote_code, + max_input_tokens, + ) + + if len(lora_adapters) > 0: + target_to_layer = build_layer_weight_lookup(model.model) + + for index, adapter in enumerate(lora_adapters): + # The AdapterParameters object allows for merging multiple adapters into a single adapter. + # At the moment, we only support loading a single adapter into the model, but we keep the + # AdapterParameters object for easier extension in the future. 
+ adapter_parameters = AdapterParameters( + adapter_info=[adapter], + # when merging multiple adapters we can weight them differently + # if this is not set, all adapters will be weighted equally + # see: text_generation_server.utils.merges.strategies for impl + weights=None, + merge_strategy=0, + density=1.0, + majority_sign_method=0, + ) + + adapter_index = index + 1 + adapter_to_index[adapter.id] = adapter_index + + logger.info( + f"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}" + ) + weight_names = tuple([v[0] for v in target_to_layer.values()]) + ( + module_map, + adapter_config, + adapter_weight_names, + adapter_tokenizer, + ) = load_and_merge_adapters( + model.model_id, + adapter_parameters, + adapter_index, + weight_names, + False, + ) + + unused_weight_names = adapter_weight_names.copy() + + adapter_layers = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ] + + for layer_name in adapter_layers: + nlayers = ( + 1 if layer_name == "lm_head" else len(model.model.model.layers) + ) + adapter_weights = LoraWeights.prepare_weights( + config=adapter_config, + module_map=module_map, + layer_type=layer_name, + unused_weight_names=unused_weight_names, + nlayers=nlayers, + dtype=model.dtype, + world_size=model.world_size, + process_group=model.process_group, + target_to_layer=target_to_layer, + ) + + if adapter_weights is None: + continue + + model.layer_to_adapter_weights[layer_name].add_adapter( + adapter_index, adapter_weights + ) + + if len(unused_weight_names) > 0: + logger.warning( + f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}" + ) + + if adapter_tokenizer is not None: + model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer) + + model.loaded_adapters.add(adapter_index) + + return model diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 5db62431..2ca40959 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1,7 +1,6 @@ import math import os import time -import itertools import torch import torch.distributed @@ -1700,72 +1699,3 @@ class FlashCausalLM(Model): forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode return generations, batch, (forward_ns, decode_ns) - - @property - def supports_adapter_loading(self) -> bool: - return True - - def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: - layer_weights = {} - - prefix = "model.layers" - - # This accounts for VLMs (e.g. LlavaNext, Idefics2) - # that have a language_model inside of the larger model. 
- if hasattr(self.model, "language_model"): - _model = self.model.language_model - elif hasattr(self.model, "text_model"): - _model = self.model.text_model - else: - _model = self.model - - for i, layer in enumerate(_model.model.layers): - layer_weights[(i, "q_proj")] = ( - f"{prefix}.{i}.self_attn.q_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "k_proj")] = ( - f"{prefix}.{i}.self_attn.k_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "v_proj")] = ( - f"{prefix}.{i}.self_attn.v_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "o_proj")] = ( - f"{prefix}.{i}.self_attn.o_proj", - layer.self_attn.o_proj, - ) - - # TODO: this is a hack to avoid the gate_proj for - # FlashStarcoder2 that doesnt have these layers - if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"): - layer_weights[(i, "gate_proj")] = ( - f"{prefix}.{i}.mlp.gate_proj", - layer.mlp.gate_up_proj, - ) - layer_weights[(i, "up_proj")] = ( - f"{prefix}.{i}.mlp.up_proj", - layer.mlp.gate_up_proj, - ) - layer_weights[(i, "down_proj")] = ( - f"{prefix}.{i}.mlp.down_proj", - layer.mlp.down_proj, - ) - - layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head) - return layer_weights - - @property - def adapter_layers(self) -> List[str]: - return ADAPTER_LAYERS - - @property - def default_traced_adapter_layers(self) -> List[str]: - return ["q_proj", "v_proj"] - - def get_num_layers_for_type(self, layer_type: str) -> int: - return 1 if layer_type == "lm_head" else len(self.model.model.layers) - - def is_row_parallel(self, layer_type: str) -> bool: - return layer_type in ROW_PARALLEL diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py deleted file mode 100644 index 2b2bd2e0..00000000 --- a/server/text_generation_server/models/flash_mistral.py +++ /dev/null @@ -1,85 +0,0 @@ -import torch -from typing import Optional, Tuple, Dict, List - -from text_generation_server.models import FlashCausalLM - - -ADAPTER_LAYERS = [ - "q_proj", - "k_proj", - "v_proj", - "o_proj", - "gate_proj", - "up_proj", - "down_proj", -] -ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"} - - -class FlashMistral(FlashCausalLM): - @property - def supports_adapter_loading(self) -> bool: - return True - - def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: - layer_weights = {} - - prefix = "model.layers" - - # This accounts for VLMs (e.g. LlavaNext, Idefics2) - # that have a language_model inside of the larger model. 
- if hasattr(self.model, "text_model"): - _model = self.model.text_model - else: - _model = self.model - - for i, layer in enumerate(_model.model.layers): - layer_weights[(i, "q_proj")] = ( - f"{prefix}.{i}.self_attn.q_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "k_proj")] = ( - f"{prefix}.{i}.self_attn.k_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "v_proj")] = ( - f"{prefix}.{i}.self_attn.v_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "o_proj")] = ( - f"{prefix}.{i}.self_attn.o_proj", - layer.self_attn.o_proj, - ) - - # TODO: this is a hack to avoid the gate_proj for - # FlashStarcoder2 that doesnt have these layers - if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"): - layer_weights[(i, "gate_proj")] = ( - f"{prefix}.{i}.mlp.gate_proj", - layer.mlp.gate_up_proj, - ) - layer_weights[(i, "up_proj")] = ( - f"{prefix}.{i}.mlp.up_proj", - layer.mlp.gate_up_proj, - ) - layer_weights[(i, "down_proj")] = ( - f"{prefix}.{i}.mlp.down_proj", - layer.mlp.down_proj, - ) - - layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head) - return layer_weights - - @property - def adapter_layers(self) -> List[str]: - return ADAPTER_LAYERS - - @property - def default_traced_adapter_layers(self) -> List[str]: - return ["q_proj", "v_proj"] - - def get_num_layers_for_type(self, layer_type: str) -> int: - return 1 if layer_type == "lm_head" else len(self.model.model.layers) - - def is_row_parallel(self, layer_type: str) -> bool: - return layer_type in ROW_PARALLEL diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index e7748bb9..159139de 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -4,20 +4,12 @@ import torch from abc import ABC, abstractmethod from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict from collections import defaultdict -from transformers import PreTrainedTokenizerBase, PretrainedConfig +from transformers import PreTrainedTokenizerBase from text_generation_server.models.types import Batch, Generation from text_generation_server.utils.speculate import get_speculate from text_generation_server.pb.generate_pb2 import InfoResponse from text_generation_server.adapters.weights import LayerAdapterWeights -from text_generation_server.utils.adapter import ( - load_and_merge_adapters, - AdapterParameters, - AdapterSource, -) -from text_generation_server.utils.log import log_master -from loguru import logger - BASE_MODEL_ADAPTER_ID = "__base_model__" @@ -61,7 +53,6 @@ class Model(ABC): self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict( LayerAdapterWeights ) - self.target_to_layer = None self.loaded_adapters = set() self.static_adapter_id = adapter_id @@ -142,140 +133,3 @@ class Model(ABC): raise RuntimeError( f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}" ) - - @property - def supports_adapter_loading(self) -> bool: - return False - - def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: - return {} - - @property - def adapter_layers(self) -> List[str]: - return [] - - @property - def default_traced_adapter_layers(self) -> List[str]: - return [] - - def get_num_layers_for_type(self, layer_type: str) -> int: - return 0 - - def is_row_parallel(self, layer_type: str) -> bool: - return False - - @property - def max_speculative_tokens(self) -> int: - return max( - [ - weights.max_speculative_tokens 
- for weights in self.layer_to_adapter_weights.values() - ], - default=0, - ) - - def load_adapter( - self, - adapter_parameters: AdapterParameters, - adapter_source: AdapterSource, - adapter_index: int, - api_token: str, - dynamic: bool = True, - ): - """Loads adapter weights from disk / host memory on the GPU. - - adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded - into model. Otherwise, the adapter weights are applied during the forward - pass and stored separately from the base model parameters. - """ - if self.target_to_layer is None: - self.target_to_layer = self.adapter_target_to_layer() - if adapter_index in self.loaded_adapters: - # Adapter already loaded - return - - if not self.supports_adapter_loading: - raise ValueError("This model does not support adapter loading.") - - if dynamic and not self.dynamic_adapter_loading_enabled: - raise ValueError( - f"This model was initialized with the adapter {self.static_adapter_id} " - f"and therefore does not support dynamic adapter loading. " - f"Please initialize a new model instance from the base model in " - f"order to use the dynamic adapter loading feature." - ) - - log_master( - logger.info, - f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}", - ) - weight_names = tuple([v[0] for v in self.target_to_layer.values()]) - ( - module_map, - adapter_config, - adapter_weight_names, - adapter_tokenizer, - ) = load_and_merge_adapters( - self.model_id, - adapter_parameters, - adapter_source, - adapter_index, - weight_names, - api_token, - False, - ) - - unused_weight_names = adapter_weight_names.copy() - for layer_name in self.adapter_layers: - adapter_weights = adapter_config.load_batched_adapter_weights( - self, - module_map, - layer_name, - unused_weight_names, - dynamic, - ) - - if adapter_weights is None: - continue - - layer_weights = self.layer_to_adapter_weights[layer_name] - layer_weights.add_adapter(adapter_index, adapter_weights) - - if len(unused_weight_names) > 0: - log_master( - logger.warning, - f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}", - ) - - if adapter_tokenizer is not None: - self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer) - - self.loaded_adapters.add(adapter_index) - - def offload_adapter( - self, - adapter_parameters: AdapterParameters, - adapter_source: AdapterSource, - adapter_index: int, - ): - """Offloads the adapter weights from GPU to CPU or disk.""" - if adapter_index not in self.loaded_adapters: - # Adapter already offloaded - return - - if not self.supports_adapter_loading: - raise ValueError("This model does not support adapter loading.") - - if not self.dynamic_adapter_loading_enabled: - raise ValueError( - f"This model was initialized with the adapter {self.static_adapter_id} " - f"and therefore does not support dynamic adapter loading. " - f"Please initialize a new model instance from the base model in " - f"order to use the dynamic adapter loading feature." 
- ) - - for layer_name in self.adapter_layers: - if layer_name in self.layer_to_adapter_weights: - self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index) - - self.loaded_adapters.remove(adapter_index) diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index aee287c6..7ac54603 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -9,11 +9,12 @@ from loguru import logger from grpc_reflection.v1alpha import reflection from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Dict from text_generation_server.cache import Cache from text_generation_server.interceptor import ExceptionInterceptor -from text_generation_server.models import Model, get_model +from text_generation_server.models import Model, get_model_with_lora_adapters +from text_generation_server.utils.adapter import AdapterInfo try: from text_generation_server.models.pali_gemma import PaliGemmaBatch @@ -30,9 +31,6 @@ except (ImportError, NotImplementedError): from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor from text_generation_server.models.globals import set_model_id, set_adapter_to_index -from text_generation_server.utils.adapter import ( - AdapterParameters, -) class SignalHandler: @@ -195,7 +193,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): def serve( model_id: str, - lora_adapter_ids: Optional[List[str]], + lora_adapters: Optional[List[AdapterInfo]], revision: Optional[str], sharded: bool, quantize: Optional[str], @@ -207,7 +205,7 @@ def serve( ): async def serve_inner( model_id: str, - lora_adapter_ids: Optional[List[str]], + lora_adapters: Optional[List[AdapterInfo]], revision: Optional[str], sharded: bool = False, quantize: Optional[str] = None, @@ -228,9 +226,9 @@ def serve( server_urls = [local_url] try: - model = get_model( + model = get_model_with_lora_adapters( model_id, - lora_adapter_ids, + lora_adapters, revision, sharded, quantize, @@ -238,29 +236,9 @@ def serve( dtype, trust_remote_code, max_input_tokens, + adapter_to_index, ) - if len(lora_adapter_ids) > 0: - for index, adapter_id in enumerate(lora_adapter_ids): - # TODO: improve non merged adapter loading and long term - # improve adapter loading as a whole - adapter_parameters = AdapterParameters( - adapter_ids=[adapter_id], - weights=None, # will be set to 1 - merge_strategy=0, - density=1.0, - majority_sign_method=0, - ) - adapter_index = index + 1 - adapter_to_index[adapter_id] = adapter_index - model.load_adapter( - adapter_parameters, - None, # adapter_source - adapter_index, - None, # api_token - False, # dynamic - ) - except Exception: logger.exception("Error when initializing model") raise @@ -297,7 +275,7 @@ def serve( asyncio.run( serve_inner( model_id, - lora_adapter_ids, + lora_adapters, revision, sharded, quantize, diff --git a/server/text_generation_server/utils/adapter.py b/server/text_generation_server/utils/adapter.py index 4e2492de..1db5f77b 100644 --- a/server/text_generation_server/utils/adapter.py +++ b/server/text_generation_server/utils/adapter.py @@ -5,12 +5,11 @@ import warnings from dataclasses import dataclass from functools import lru_cache -from typing import TYPE_CHECKING, Set, Tuple +from typing import TYPE_CHECKING, Set, Tuple, Optional, List from safetensors.torch import load_file from transformers import AutoConfig, AutoTokenizer, 
PreTrainedTokenizer -from text_generation_server.pb import generate_pb2 from text_generation_server.utils.merges.strategies import merge_adapters from text_generation_server.utils import hub @@ -24,9 +23,15 @@ if TYPE_CHECKING: BASE_MODEL_ADAPTER_ID = "__base_model__" +@dataclass +class AdapterInfo: + id: str + path: Optional[str] + + @dataclass class AdapterParameters: - adapter_ids: Tuple[str] + adapter_info: Tuple[AdapterInfo] weights: Tuple[float] merge_strategy: NotImplemented density: float @@ -40,37 +45,47 @@ class AdapterSource: revision: str +def parse_lora_adapters(lora_adapters: Optional[str]) -> List[AdapterInfo]: + if not lora_adapters: + return [] + + adapter_list = [] + for adapter in lora_adapters.split(","): + parts = adapter.strip().split("=") + if len(parts) == 1: + adapter_list.append(AdapterInfo(id=parts[0], path=None)) + elif len(parts) == 2: + adapter_list.append(AdapterInfo(id=parts[0], path=parts[1])) + else: + raise ValueError(f"Invalid LoRA adapter format: {adapter}") + return adapter_list + + def load_and_merge_adapters( model_id: str, adapter_parameters: AdapterParameters, - adapter_source: str, adapter_index: int, weight_names: Tuple[str], - api_token: str, trust_remote_code: bool = False, ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: - if len(adapter_parameters.adapter_ids) == 1: + + if len(adapter_parameters.adapter_info) == 1: + adapter_info = next(iter(adapter_parameters.adapter_info)) return load_module_map( model_id, - adapter_parameters.adapter_ids[0], - adapter_source, + adapter_info.id, + adapter_info.path, weight_names, - api_token, trust_remote_code, ) - adapter_params = AdapterParametersContainer( - adapter_parameters, adapter_source, adapter_index - ) - return _load_and_merge( - model_id, adapter_params, weight_names, api_token, trust_remote_code - ) + adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index) + return _load_and_merge(model_id, adapter_params, weight_names, trust_remote_code) @dataclass class AdapterParametersContainer: adapter_parameters: AdapterParameters - adapter_source: str adapter_index: int def __hash__(self) -> int: @@ -82,7 +97,6 @@ def _load_and_merge( model_id: str, adapter_params: AdapterParametersContainer, weight_names: Tuple[str], - api_token: str, trust_remote_code: bool = False, ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: params = adapter_params.adapter_parameters @@ -90,17 +104,16 @@ def _load_and_merge( adapters_to_merge = [] merged_weight_names = set() tokenizer = None - for adapter_id in params.adapter_ids: - if adapter_id == BASE_MODEL_ADAPTER_ID: + for adapter in params.adapter_info: + if adapter.id == BASE_MODEL_ADAPTER_ID: raise ValueError("Base model adapter cannot be merged.") module_map, adapter_config, adapter_weight_names, adapter_tokenizer = ( load_module_map( model_id, - adapter_id, - adapter_params.adapter_source, + adapter.id, + adapter.path, weight_names, - api_token, trust_remote_code, ) ) @@ -159,25 +172,28 @@ def check_architectures( def load_module_map( model_id: str, adapter_id: str, - adapter_source: str, + adapter_path: Optional[str], weight_names: Tuple[str], - api_token: str, trust_remote_code: bool = False, ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: revision = "main" - adapter_config = LoraConfig.load(adapter_id, api_token) - if adapter_config.base_model_name_or_path != model_id: + adapter_config = LoraConfig.load(adapter_path or adapter_id, None) + + if not adapter_path and 
adapter_config.base_model_name_or_path != model_id: check_architectures(model_id, adapter_id, adapter_config, trust_remote_code) - adapter_filenames = hub._cached_adapter_weight_files( - adapter_id, revision=revision, extension=".safetensors" + adapter_filenames = ( + hub._adapter_weight_files_from_dir(adapter_path, extension=".safetensors") + if adapter_path + else hub._cached_adapter_weight_files( + adapter_id, revision=revision, extension=".safetensors" + ) ) try: adapter_tokenizer = AutoTokenizer.from_pretrained( adapter_config.config_path, - token=api_token, trust_remote_code=trust_remote_code, ) except Exception: @@ -194,3 +210,87 @@ def load_module_map( adapter_weights, weight_names ) return module_map, adapter_config, adapter_weight_names, adapter_tokenizer + + +def get_attn_weights(i, layer): + qkv = layer.self_attn.query_key_value + weights = {} + + for k in ["q", "k", "v"]: + key = (i, f"{k}_proj") + value = (f"model.layers.{i}.self_attn.{k}_proj", qkv) + weights[key] = value + + weights[(i, "o_proj")] = ( + f"model.layers.{i}.self_attn.o_proj", + layer.self_attn.o_proj, + ) + + return weights + + +def get_mlp_weights(i, layer): + weights = {} + if hasattr(layer, "mlp"): + mlp = layer.mlp + if hasattr(mlp, "gate_up_proj"): + # handle combined gate_up_proj (e.g., for some LLaMA variants) + weights.update( + { + (i, "gate_proj"): ( + f"model.layers.{i}.mlp.gate_proj", + mlp.gate_up_proj, + ), + (i, "up_proj"): (f"model.layers.{i}.mlp.up_proj", mlp.gate_up_proj), + } + ) + else: + # handle separate gate_proj, up_proj, and down_proj (e.g., for Gemma) + if hasattr(mlp, "gate_proj"): + weights[(i, "gate_proj")] = ( + f"model.layers.{i}.mlp.gate_proj", + mlp.gate_proj, + ) + if hasattr(mlp, "up_proj"): + weights[(i, "up_proj")] = (f"model.layers.{i}.mlp.up_proj", mlp.up_proj) + + if hasattr(mlp, "down_proj"): + weights[(i, "down_proj")] = ( + f"model.layers.{i}.mlp.down_proj", + mlp.down_proj, + ) + + return weights + + +# build_layer_weight_lookup creates a mapping of model layers to their corresponding +# weight tensors and paths. It builds a dictionary that maps layer identifiers to tuples +# containing the weight tensor path and the actual layer object. This mapping is needed +# for the lora adapter to know which weights to update when applying the adapter. +def build_layer_weight_lookup(model): + if hasattr(model, "language_model"): + m = model.language_model.model + elif hasattr(model, "text_model"): + m = model.text_model.model + else: + m = model.model + + layer_weights = {} + + for i, layer in enumerate(m.layers): + attn_weights = get_attn_weights(i, layer) + mlp_weights = get_mlp_weights(i, layer) + + layer_weights.update(attn_weights) + layer_weights.update(mlp_weights) + + lm_head = None + if hasattr(m, "lm_head"): + lm_head = m.lm_head + elif hasattr(model, "lm_head"): + lm_head = model.lm_head + + if lm_head: + layer_weights[(0, "lm_head")] = ("lm_head", lm_head) + + return layer_weights
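

To make the shape of the lookup concrete, here is a small illustrative sketch in the style of the new tests in server/tests/utils/test_adapter.py. The two-layer mock model is hypothetical; keys are (layer_index, projection_name) and values are (weight path, base module), which is exactly what LoraWeights.prepare_weights consumes through its target_to_layer argument.

    # Illustrative sketch only; the mock model below stands in for a loaded model.
    from unittest.mock import Mock

    from text_generation_server.utils.adapter import build_layer_weight_lookup

    # spec=["model"] keeps the language_model / text_model checks from
    # auto-creating attributes on the mock.
    mock_model = Mock(spec=["model"])
    mock_model.model.layers = [Mock(), Mock()]
    for layer in mock_model.model.layers:
        layer.self_attn.query_key_value = Mock()
        layer.self_attn.o_proj = Mock()
        layer.mlp.gate_up_proj = Mock()
        layer.mlp.down_proj = Mock()

    lookup = build_layer_weight_lookup(mock_model)

    # Keys are (layer_index, projection); values are (weight path, module).
    assert lookup[(1, "q_proj")][0] == "model.layers.1.self_attn.q_proj"
    assert lookup[(0, "gate_proj")][1] is mock_model.model.layers[0].mlp.gate_up_proj
    assert lookup[(0, "lm_head")][0] == "lm_head"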
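

Finally, a condensed view of how the refactored pieces are wired at startup, mirroring serve_inner in server.py above. The model id and token budget are placeholders, and the call loads a real model, so this is for orientation rather than a runnable test.

    # Condensed from the cli.py / server.py changes in this diff; values are placeholders.
    import os

    from text_generation_server.models import get_model_with_lora_adapters
    from text_generation_server.utils.adapter import parse_lora_adapters

    # e.g. LORA_ADAPTERS="predibase/customer_support,my-adapter=/data/adapters/my-adapter"
    lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS"))

    adapter_to_index = {}  # filled in by get_model_with_lora_adapters (1-based indices)
    model = get_model_with_lora_adapters(
        "meta-llama/Llama-2-7b-hf",  # placeholder base model id
        lora_adapters,
        revision=None,
        sharded=False,
        quantize=None,
        speculate=None,
        dtype=None,
        trust_remote_code=False,
        max_input_tokens=1024,
        adapter_to_index=adapter_to_index,
    )
    # After this call, each adapter's weights have been prepared with
    # LoraWeights.prepare_weights and registered in model.layer_to_adapter_weights.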