# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/lora.py
# License: Apache License Version 2.0, January 2004

from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union

import torch
from peft import LoraConfig as _LoraConfig
from torch.distributed import ProcessGroup

from text_generation_server.adapters.config import AdapterConfig, ModuleMap
from text_generation_server.adapters.weights import (
    AdapterBatchMetadata,
    AdapterWeights,
    BatchAdapterWeights,
)
from text_generation_server.utils.sgmv import (
    BGMV_MAX_RANK,
    MAX_RANK_CUSTOM,
    get_tmp_tensors,
    orient_for_rank,
    pad_rank,
    use_cutlass_shrink,
)

if TYPE_CHECKING:
    from text_generation_server.models.model import Model


def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
    """Return the ``(start, stop)`` slice bounds owned by ``rank``.

    ``size`` is split into ``world_size`` equal blocks using floor division,
    so when ``size`` is not a multiple of ``world_size`` the trailing
    ``size % world_size`` elements are not assigned to any rank.
    """
    block_size = size // world_size
    start = offset + rank * block_size
    stop = offset + (rank + 1) * block_size
    return start, stop


def shard_on_dim(t: torch.Tensor, dim: int, process_group: ProcessGroup):
    """Return the slice of ``t`` along ``dim`` that belongs to this process.

    The rank and world size are read from ``process_group``.

    Raises:
        NotImplementedError: if ``dim`` is neither 0 nor 1.
    """
    world_size = process_group.size()
    rank = process_group.rank()

    size = t.shape[dim]
    start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)

    if dim == 0:
        tensor = t[start:stop]
    elif dim == 1:
        tensor = t[:, start:stop]
    else:
        raise NotImplementedError("Let's make that generic when needed")

    return tensor


def shard_lora_weights(
    weights_a: List[torch.Tensor],
    weights_b: List[torch.Tensor],
    split_dim: int,
    process_group: ProcessGroup,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Shard per-layer LoRA A and B matrices for the current process.

    Every A matrix is sliced along ``split_dim``; every B matrix is always
    sliced along dim 1.
    """
    # [hidden_size, r]
    weights_a = [
        shard_on_dim(w, dim=split_dim, process_group=process_group) for w in weights_a
    ]

    # [r, hidden_size]
    weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b]

    return weights_a, weights_b


@dataclass
class LoraConfig(AdapterConfig):
    """Adapter configuration carrying the LoRA-specific hyperparameters."""

    r: int
    target_modules: Optional[Union[List[str], str]]
    fan_in_fan_out: bool
    lora_alpha: int
    use_rslora: bool

    def map_weights_for_model(
        self,
        adapter_weights: Dict[int, AdapterWeights],
        weight_names: Tuple[str],
    ) -> Tuple[ModuleMap, Set[str]]:
        """Pair up the LoRA A/B entries of ``adapter_weights`` per module name.

        For each name in ``weight_names`` the keys
        ``base_model.model.<name>.lora_A.weight`` and
        ``base_model.model.<name>.lora_B.weight`` are looked up in
        ``adapter_weights``. Names for which either key is missing are
        skipped silently.

        Returns:
            A tuple ``(module_map, adapter_weight_names)`` where ``module_map``
            maps each matched name to ``{"lora_A": (weight, key),
            "lora_B": (weight, key)}`` and ``adapter_weight_names`` is the set
            of all keys that were consumed.
        """
        adapter_weight_names = set()
        module_map = {}
        for weight_name in weight_names:
            lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
            lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
            if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
                continue

            module_map[weight_name] = {
                "lora_A": (adapter_weights[lora_a_name], lora_a_name),
                "lora_B": (adapter_weights[lora_b_name], lora_b_name),
            }
            adapter_weight_names.add(lora_a_name)
            adapter_weight_names.add(lora_b_name)
        return module_map, adapter_weight_names

    @classmethod
    def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
        """Build a ``LoraConfig`` from the PEFT config of ``adapter_id``.

        ``use_rslora`` falls back to ``False`` when the loaded PEFT config
        object has no such attribute.
        """
        hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
        return cls(
            base_model_name_or_path=hf_config.base_model_name_or_path,
            r=hf_config.r,
            target_modules=hf_config.target_modules,
            fan_in_fan_out=hf_config.fan_in_fan_out,
            lora_alpha=hf_config.lora_alpha,
            use_rslora=getattr(hf_config, "use_rslora", False),
        )


class LoraWeights(AdapterWeights):
    """LoRA weights for a single adapter merged across all layers."""

    def __init__(
        self,
        weights_a: List[torch.Tensor],
        weights_b: List[torch.Tensor],
        adapter_config: LoraConfig,
    ):
        """Stack the per-layer A and B matrices into two tensors.

        Args:
            weights_a: one A matrix per layer; the rank is read from dim 1.
            weights_b: one B matrix per layer; the rank is read from dim 0.
            adapter_config: config stored as ``self.adapter_config``.
        """
        self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
        # Guard on weights_b itself: it is weights_b[0] that is indexed here.
        self.lora_b_r = weights_b[0].size(0) if len(weights_b) > 0 else 1

        self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
        self._is_transposed = False

        # [num_layers, hidden_size, r]
        weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
        self._weights_a = torch.stack(weights_a)

        # [num_layers, r, hidden_size]
        self._weights_b = torch.stack(weights_b)

        self.adapter_config = adapter_config

    @property
    def weights_a(self) -> torch.Tensor:
        """Stacked A weights in the non-transposed state (flips back if needed)."""
        if self._is_transposed:
            self._transpose_weights()
        return self._weights_a

    @property
    def weights_b(self) -> torch.Tensor:
        """Stacked B weights in the non-transposed state (flips back if needed)."""
        if self._is_transposed:
            self._transpose_weights()
        return self._weights_b

    @property
    def weights_a_t(self) -> torch.Tensor:
        """Stacked A weights in the transposed state (flips if needed)."""
        if not self._is_transposed:
            self._transpose_weights()
        return self._weights_a

    @property
    def weights_b_t(self) -> torch.Tensor:
        """Stacked B weights in the transposed state (flips if needed)."""
        if not self._is_transposed:
            self._transpose_weights()
        return self._weights_b

    def _transpose_weights(self):
        """Toggle the stored tensors between the two orientations.

        Both stored tensors are replaced together, so the properties above
        never return A and B in mismatched states.
        """
        if self._use_cutlass_shrink:
            # A is only re-oriented when the cutlass shrink is in use; when it
            # is not, both SGMV and BGMV use the same orientation for A.
            self._weights_a = self._weights_a.transpose(1, 2).contiguous()
        self._weights_b = self._weights_b.transpose(1, 2).contiguous()
        self._is_transposed = not self._is_transposed

    @classmethod
    def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
        """Return the batch-level weight classes associated with this type."""
        return [BatchLoraWeights]

    @classmethod
    def prepare_weights(
        cls,
        config: LoraConfig,
        module_map: Dict[str, Dict],
        layer_type: str,
        unused_weight_names: Set[str],
        nlayers: int,
        dtype: torch.dtype,
        world_size: int,
        process_group: ProcessGroup,
        target_to_layer: Dict[str, Tuple[str, torch.Tensor]],
    ) -> Optional[AdapterWeights]:
        """Prepare pre-loaded LoRA weights for use in the model.

        This method processes and organizes LoRA weights for a specific layer
        type across all layers:

        - uses ``config`` (LoraConfig) to apply LoRA-specific settings like
          the scaling factor.
        - retrieves weights from ``module_map`` based on the ``layer_type``.
        - processes ``nlayers`` number of layers.
        - converts weights to the specified ``dtype``.
        - shards weights across ``world_size`` number of processes using the
          ``process_group``.
        - maps weights to specific layers using ``target_to_layer``, which is
          indexed by ``(layer_id, layer_type)``.
        - tracks ``unused_weight_names`` to identify any unused weights; the
          names consumed here are discarded from that set in place.

        The method handles weight transposition, scaling, and padding to
        ensure compatibility with SGMV or BGMV operations.

        Returns:
            A ``LoraWeights`` instance, or ``None`` as soon as one layer has no
            entry in ``module_map``.

        Side effects:
            ``config.r`` is overwritten with the padded rank.
        """
        lora_a_list = [None] * nlayers
        lora_b_list = [None] * nlayers

        for layer_id in range(nlayers):
            key = (layer_id, layer_type)
            weight_name, layer = target_to_layer[key]
            base_weight = layer.base_layer.linear.weight
            base_device = base_weight.device

            if weight_name not in module_map:
                # There is no LoRA weight for this layer type in the adapter
                return None

            lora_a, lora_a_name = module_map[weight_name]["lora_A"]
            lora_a = lora_a.to(base_device, dtype)

            lora_b, lora_b_name = module_map[weight_name]["lora_B"]
            lora_b = lora_b.to(base_device, dtype)

            # NOTE(review): config.r is overwritten with the padded rank at
            # the end of this method. If the same config object is passed to a
            # later call, that call computes the scale from the padded rank
            # rather than the adapter's original rank - confirm this is
            # intended.
            scale = get_scaling_factor(
                config.lora_alpha,
                config.r,
                uses_rslora=config.use_rslora,
            )

            unused_weight_names.discard(lora_a_name)
            unused_weight_names.discard(lora_b_name)

            # Merge scaling factor into lora_b due to associativity of matrix
            # multiplication: (A * B) * C = A * (B * C)
            lora_a_list[layer_id] = lora_a.transpose(0, 1)
            lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale

        # pad lora ranks to be compatible with sgmv
        lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]
        lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]

        if lora_a_list:
            # update rank if it was padded
            padded_rank = lora_a_list[0].size(1)
            config.r = padded_rank

        return LoraWeights(
            *shard_lora_weights(
                weights_a=lora_a_list,
                weights_b=lora_b_list,
                split_dim=0 if layer_type in {"o_proj", "down_proj", "lm_head"} else 1,
                process_group=process_group,
            ),
            config,
        )


@dataclass
class RankSegments:
    """Per-rank kernel inputs for all batch segments sharing one LoRA rank."""

    rank: int

    # data_ptr() values of the A / B weight tensors, one per segment
    lora_a_ptr: torch.Tensor
    lora_b_ptr: torch.Tensor

    # prefill (sgmv)
    tmp_shrink: torch.Tensor
    tmp_expand: torch.Tensor
    segment_starts: torch.Tensor
    segment_ends: torch.Tensor

    # decode (bgmv)
    indices: torch.Tensor


def _data_ptr_tensor(adapter_weights, segment_indices, attr, device) -> torch.Tensor:
    """Collect ``data_ptr()`` of ``adapter_weights[idx].<attr>`` per segment.

    Segments whose adapter index is not in ``adapter_weights`` get pointer 0.
    ``attr`` names one of the ``LoraWeights`` weight properties; reading those
    properties may re-orient the stored tensors as a side effect.
    """
    return torch.tensor(
        [
            getattr(adapter_weights[idx], attr).data_ptr()
            if idx in adapter_weights
            else 0
            for idx in segment_indices
        ],
        dtype=torch.int64,
        device=device,
    )


@dataclass
class BatchLoraWeights(BatchAdapterWeights):
    """LoRA weights and kernel metadata for the adapters present in a batch."""

    lora_a: Dict[int, torch.Tensor]
    lora_b: Dict[int, torch.Tensor]
    adapter_index_configs: Dict[int, LoraConfig]
    rank_data: Dict[int, RankSegments]
    use_sgmv: bool

    def has_adapter(self, adapter_index: int) -> bool:
        """Return whether ``adapter_index`` has LoRA weights in this batch."""
        return adapter_index in self.adapter_index_configs

    def can_vectorize(self, pg: ProcessGroup) -> bool:
        """Return True when every rank, divided by the group size, is within
        ``MAX_RANK_CUSTOM``."""
        return all(
            rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
            for rank_data in self.rank_data.values()
        )

    @classmethod
    def load(
        cls,
        adapter_weights: Dict[int, AdapterWeights],
        meta: AdapterBatchMetadata,
        prefill: bool,
        prefill_head_indices: Optional[torch.Tensor],
    ) -> Optional["BatchLoraWeights"]:
        """Assemble the batch-level LoRA state from per-adapter weights.

        Args:
            adapter_weights: weights keyed by adapter index. Entries exposing a
                ``lora_weights`` attribute are unwrapped first; anything that
                is not a ``LoraWeights`` afterwards is ignored.
            meta: batch metadata; ``segment_indices``, ``adapter_segments`` and
                ``adapter_indices`` are read from it.
            prefill: forces the SGMV path when truthy.
            prefill_head_indices: when given, SGMV segment bounds are
                recomputed from these indices instead of being taken from
                ``meta.adapter_segments``.

        Returns:
            A ``BatchLoraWeights``, or ``None`` when no LoRA weights remain
            after filtering.
        """
        adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
        adapter_weights = {
            k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
        }
        if not adapter_weights:
            return None

        first_weights = next(iter(adapter_weights.values()))
        device = first_weights.weights_a.device
        segment_indices = meta.segment_indices

        lora_a = {
            idx: adapter_weights[idx].weights_a
            for idx in segment_indices
            if idx in adapter_weights
        }
        lora_b = {
            idx: adapter_weights[idx].weights_b
            for idx in segment_indices
            if idx in adapter_weights
        }

        max_rank = max(
            (
                adapter_weights[idx].lora_a_r
                for idx in segment_indices
                if idx in adapter_weights
            ),
            default=0,
        )

        # SGMV reads the default orientation, BGMV the transposed one. The
        # property reads below happen in the same order as the pointers are
        # built (all A first, then all B) because they transpose lazily.
        use_sgmv = bool(prefill or max_rank > BGMV_MAX_RANK)
        if use_sgmv:
            a_attr, b_attr = "weights_a", "weights_b"
        else:
            a_attr, b_attr = "weights_a_t", "weights_b_t"
        lora_a_ptr = _data_ptr_tensor(adapter_weights, segment_indices, a_attr, device)
        lora_b_ptr = _data_ptr_tensor(adapter_weights, segment_indices, b_attr, device)

        adapter_index_configs = {
            idx: adapter_weights[idx].adapter_config
            for idx in segment_indices
            if idx in adapter_weights
        }

        adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}

        # Group segment positions by the LoRA rank of their adapter.
        rank_indices = defaultdict(list)
        for segment_idx, adapter_idx in enumerate(segment_indices):
            if adapter_idx not in adapter_weights:
                continue
            rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)

        if prefill_head_indices is not None:
            j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
            for head_index in prefill_head_indices:
                # j cannot go out of bounds as that would mean there are
                # tokens without corresponding adapters
                if head_index < meta.adapter_segments[j]:
                    prefill_head_segment_ends[-1] += 1
                else:
                    prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
                    prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
                    j += 1

        rank_data = {}
        for rank, indices in rank_indices.items():
            tmp_shrink = None
            tmp_expand = None
            segment_starts = None
            segment_ends = None
            batch_indices = None

            if use_sgmv:
                lora_a_ptr_indices = lora_a_ptr[indices]
                tmp_shrink, tmp_expand = get_tmp_tensors(
                    lora_a_ptr_indices.size(0), rank, device
                )
                segment_starts = meta.adapter_segments[indices]
                segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
                if prefill_head_indices is not None:
                    for i, segment_index in enumerate(indices):
                        segment_starts[i] = prefill_head_segment_starts[segment_index]
                        segment_ends[i] = prefill_head_segment_ends[segment_index]
            else:
                # Kept under its own name so the dict being iterated by the
                # enclosing loop is not rebound mid-iteration.
                segments_in_rank = set(indices)
                batch_indices = [
                    adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
                ]
                # Batch entries whose segment has a different rank become -1.
                batch_indices = [
                    idx if idx in segments_in_rank else -1 for idx in batch_indices
                ]
                batch_indices = torch.tensor(
                    batch_indices, dtype=torch.int64, device=device
                )

            rank_data[rank] = RankSegments(
                rank=rank,
                tmp_shrink=tmp_shrink,
                tmp_expand=tmp_expand,
                lora_a_ptr=lora_a_ptr[indices],
                lora_b_ptr=lora_b_ptr[indices],
                segment_starts=segment_starts,
                segment_ends=segment_ends,
                indices=batch_indices,
            )

        return BatchLoraWeights(
            lora_a=lora_a,
            lora_b=lora_b,
            adapter_index_configs=adapter_index_configs,
            rank_data=rank_data,
            use_sgmv=use_sgmv,
        )


def get_scaling_factor(
    lora_alpha: int,
    r: int,
    uses_rslora: bool = False,
) -> float:
    """Computes the scaling factor for the lora weights.

    Returns ``lora_alpha / sqrt(r)`` when ``uses_rslora`` is set, otherwise
    ``lora_alpha / r``.
    """
    if uses_rslora:
        return lora_alpha / (r**0.5)
    return lora_alpha / r


def _convert_lora(v: AdapterWeights) -> AdapterWeights:
    """Unwrap ``v.lora_weights`` when present; otherwise return ``v`` as is."""
    if hasattr(v, "lora_weights"):
        return v.lora_weights
    return v