import copy
from abc import ABC
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union

from text_generation_server.utils.merges.utils import (
    calculate_majority_sign_mask,
    disjoint_merge,
    prune,
)
import torch

if TYPE_CHECKING:
    from text_generation_server.adapters.lora import LoraConfig
    from text_generation_server.utils.adapter import ModuleMap


class AdapterParameters:
    """Plain container for the settings that control an adapter merge.

    Values are stored exactly as passed; no validation happens here.
    """

    def __init__(
        self, adapter_ids, weights, merge_strategy, density, majority_sign_method
    ):
        # identifiers of the adapters to merge (not read by this module)
        self.adapter_ids = adapter_ids
        # one weight per adapter; None/empty means uniform weights of 1.0
        self.weights = weights
        # name of an entry in `strategy_registry`; non-string values
        # (e.g. None) select the default "linear" strategy
        self.merge_strategy = merge_strategy
        # forwarded to the merge strategy as `density`
        self.density = density
        # forwarded to sign-electing strategies; non-string values select "total"
        self.majority_sign_method = majority_sign_method


def _apply_weights(
    tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor
) -> torch.Tensor:
    """Scale each task tensor by its matching weight.

    Args:
        tensors: either an already-stacked tensor whose leading dimension
            indexes tasks, or a list of task tensors that is stacked along a
            new leading dimension.
        w: weights indexed by task along its leading dimension (presumably
            1-D with one entry per task -- TODO confirm for all callers).

    Returns:
        The stacked task tensors multiplied element-wise by their weights.
        The task dimension is kept; callers reduce over ``dim=0``.
    """
    t = tensors if isinstance(tensors, torch.Tensor) else torch.stack(tensors, dim=0)

    # Weights are built on the CPU in merge_adapters while task tensors may
    # live on another device; a non-scalar CPU tensor cannot be multiplied
    # with a tensor on a different device, so align devices first. Only the
    # device is changed so dtype promotion behaves exactly as before.
    w = w.to(t.device)

    # Append trailing singleton dims so `w` broadcasts over each task tensor
    # and the multiplication applies element-wise per task.
    while w.dim() < t.dim():
        w = w.unsqueeze(-1)
    return t * w


def _sign_elected_merge(
    task_tensors: List[torch.Tensor],
    weights: torch.Tensor,
    majority_sign_method: str,
) -> torch.Tensor:
    """Shared tail of the TIES-style strategies: elect signs, weight, merge.

    The sign mask is computed from the (already pruned) *unweighted* task
    tensors; the weighted tensors are then combined with `disjoint_merge`.
    """
    stacked = torch.stack(task_tensors, dim=0)

    # elect sign before applying weights
    majority_sign_mask = calculate_majority_sign_mask(
        stacked, method=majority_sign_method
    )
    weighted_task_tensors = _apply_weights(stacked, weights)

    # disjoint merge
    return disjoint_merge(weighted_task_tensors, majority_sign_mask)


class MergeStrategy(ABC):
    """Base class for strategies that combine per-adapter task tensors."""

    def merge(
        self, task_tensors: List[torch.Tensor], weights: torch.Tensor
    ) -> torch.Tensor:
        """Combine ``task_tensors`` (one per adapter) using ``weights``."""
        raise NotImplementedError()


class LinearMerge(MergeStrategy):
    """Weighted sum of the task tensors."""

    def __init__(self, **kwargs):
        # Linear merging has no hyper-parameters; extra config keys such as
        # density / majority_sign_method are accepted and ignored.
        pass

    def merge(
        self, task_tensors: List[torch.Tensor], weights: torch.Tensor
    ) -> torch.Tensor:
        weighted_task_tensors = _apply_weights(task_tensors, weights)
        return weighted_task_tensors.sum(dim=0)


class TiesMerge(MergeStrategy):
    """Prune by magnitude, elect a majority sign, then merge disjointly."""

    def __init__(self, density: float, majority_sign_method: str = "total", **kwargs):
        self.density = density
        self.majority_sign_method = majority_sign_method

    def merge(
        self, task_tensors: List[torch.Tensor], weights: torch.Tensor
    ) -> torch.Tensor:
        # sparsify (magnitude-based pruning controlled by `density`)
        pruned = [
            prune(tensor, self.density, method="magnitude") for tensor in task_tensors
        ]
        return _sign_elected_merge(pruned, weights, self.majority_sign_method)


class DareLinearMerge(MergeStrategy):
    """Random pruning with rescaling, followed by a weighted sum."""

    def __init__(self, density: float, **kwargs):
        self.density = density

    def merge(
        self, task_tensors: List[torch.Tensor], weights: torch.Tensor
    ) -> torch.Tensor:
        # sparsify (random pruning, rescaled as implemented by `prune`)
        pruned = [
            prune(tensor, self.density, method="random", rescale=True)
            for tensor in task_tensors
        ]
        weighted_task_tensors = _apply_weights(pruned, weights)
        return weighted_task_tensors.sum(dim=0)


class DareTiesMerge(MergeStrategy):
    """Random pruning with rescaling, then sign election and disjoint merge."""

    def __init__(self, density: float, majority_sign_method: str = "total", **kwargs):
        self.density = density
        self.majority_sign_method = majority_sign_method

    def merge(
        self, task_tensors: List[torch.Tensor], weights: torch.Tensor
    ) -> torch.Tensor:
        # sparsify (random pruning, rescaled as implemented by `prune`)
        pruned = [
            prune(tensor, self.density, method="random", rescale=True)
            for tensor in task_tensors
        ]
        return _sign_elected_merge(pruned, weights, self.majority_sign_method)


strategy_registry: Dict[str, Type[MergeStrategy]] = {
    "linear": LinearMerge,
    "ties": TiesMerge,
    "dare_linear": DareLinearMerge,
    "dare_ties": DareTiesMerge,
}


def _resolve_strategy_name(merge_strategy) -> str:
    """Map ``AdapterParameters.merge_strategy`` to a key of `strategy_registry`.

    Non-string or empty values (e.g. None) select the default "linear".

    Raises:
        ValueError: if a non-empty string names no registered strategy.
    """
    if not isinstance(merge_strategy, str) or not merge_strategy:
        return "linear"
    name = merge_strategy.lower()
    if name not in strategy_registry:
        raise ValueError(
            f"unable to merge adapters, unknown merge strategy {merge_strategy!r}; "
            f"expected one of {sorted(strategy_registry)}"
        )
    return name


def _resolve_majority_sign_method(majority_sign_method) -> str:
    """Return the sign-election method name; non-string/empty -> "total"."""
    if not isinstance(majority_sign_method, str) or not majority_sign_method:
        return "total"
    return majority_sign_method.lower()


def _resolve_weights(weights, num_adapters: int) -> torch.Tensor:
    """Return a 1-D tensor with exactly one weight per adapter.

    None or an empty sequence means uniform weights of 1.0.

    Raises:
        ValueError: if the weights do not form one entry per adapter.
    """
    if weights is None or len(weights) == 0:
        return torch.ones(num_adapters)
    weights = torch.as_tensor(weights)
    if tuple(weights.shape) != (num_adapters,):
        raise ValueError(
            f"unable to merge adapters, expected {num_adapters} weights "
            f"but got shape {tuple(weights.shape)}"
        )
    return weights


def _group_by_weight_name(adapters: List[Tuple["ModuleMap", "LoraConfig"]]):
    """Regroup per-adapter module maps by weight name, key and parameter name.

    Returns:
        ``(module_maps, lora_configs, weight_name_to_adapter_idx)`` where
        ``module_maps[weight_name][k][param_name]`` lists the tensors
        contributed by the adapters, in adapter order;
        ``lora_configs`` lists each adapter's config in order; and
        ``weight_name_to_adapter_idx[weight_name]`` lists the indices of the
        adapters whose module map contains ``weight_name``.
    """
    module_maps: Dict[str, Dict[str, Dict[str, List[torch.Tensor]]]] = defaultdict(
        lambda: defaultdict(lambda: defaultdict(list))
    )
    lora_configs = []
    weight_name_to_adapter_idx = defaultdict(list)

    for idx, (module_map, lora_config) in enumerate(adapters):
        for weight_name, data in module_map.items():
            weight_name_to_adapter_idx[weight_name].append(idx)
            for k, (param_data, param_name) in data.items():
                module_maps[weight_name][k][param_name].append(param_data)
        lora_configs.append(lora_config)

    return module_maps, lora_configs, weight_name_to_adapter_idx


def merge_adapters(
    adapters: List[Tuple["ModuleMap", "LoraConfig"]],
    merge_params: AdapterParameters,
) -> Tuple["ModuleMap", "LoraConfig"]:
    """Merge several LoRA adapters into a single module map and config.

    Args:
        adapters: ``(module_map, lora_config)`` pairs, one per adapter. Each
            module map is read as ``module_map[weight_name][k] ==
            (param_data, param_name)``.
        merge_params: merge settings. ``weights`` gives one weight per adapter
            (uniform 1.0 when None/empty); ``merge_strategy`` names an entry
            of ``strategy_registry`` (non-string values such as None select
            "linear"); ``density`` and ``majority_sign_method`` are forwarded
            to the strategy (the latter defaults to "total").

    Returns:
        ``(merged_module_map, merged_lora_config)`` where
        ``merged_module_map[weight_name][k] == (merged_tensor, param_name)``.

    Raises:
        ValueError: if no adapters are given, the weights do not match the
            adapters one-to-one, the merge strategy name is unknown, or the
            LoRA configs are incompatible.
    """
    if not adapters:
        raise ValueError("unable to merge adapters, no adapters were provided")

    weights = _resolve_weights(merge_params.weights, len(adapters))
    strategy_name = _resolve_strategy_name(merge_params.merge_strategy)
    merge_config = {
        "density": merge_params.density,
        "majority_sign_method": _resolve_majority_sign_method(
            merge_params.majority_sign_method
        ),
    }
    merge_strategy = strategy_registry[strategy_name](**merge_config)

    module_maps, lora_configs, weight_name_to_adapter_idx = _group_by_weight_name(
        adapters
    )

    # validate lora configs are compatible
    _validate_lora_configs(lora_configs)

    # merge tensors for each module such that we have a single ModuleMap:
    # dict[k] -> merged tensor
    merged_module_map: "ModuleMap" = defaultdict(dict)
    for weight_name, data in module_maps.items():
        # only the adapters that actually provide this weight contribute
        param_weights = weights[weight_name_to_adapter_idx[weight_name]]
        for k, param_data in data.items():
            for param_name, tensors in param_data.items():
                merged_tensor = merge_strategy.merge(tensors, param_weights)
                merged_module_map[weight_name][k] = (merged_tensor, param_name)

    # merge lora configs
    merged_lora_config = _merge_lora_configs(lora_configs)

    return merged_module_map, merged_lora_config


def _validate_lora_configs(lora_configs: List["LoraConfig"]):
    """Raise ValueError unless the LoRA configs can be merged together.

    Requires a single shared rank ``r`` and at least one config with a
    non-empty ``target_modules``.
    """
    # check that all configs have the same rank
    ranks = set(lora_config.r for lora_config in lora_configs)
    if len(ranks) > 1:
        raise ValueError(
            f"unable to merge adapters, lora configs have different ranks: {ranks}"
        )

    if all(len(lora_config.target_modules) == 0 for lora_config in lora_configs):
        raise ValueError(
            "unable to merge adapters, lora configs have no target modules"
        )


def _merge_lora_configs(lora_configs: List["LoraConfig"]) -> "LoraConfig":
    """Return a shallow copy of the first config with unioned target modules.

    ``target_modules`` on the result is the sorted union across all configs;
    the input configs are not modified.
    """
    merged_lora_config = copy.copy(lora_configs[0])

    # merge target modules as a union operation
    merged_target_modules = sorted(
        set(
            module
            for lora_config in lora_configs
            for module in lora_config.target_modules
        )
    )
    merged_lora_config.target_modules = merged_target_modules

    return merged_lora_config