472 lines
16 KiB
Python
472 lines
16 KiB
Python
# Origin: https://github.com/predibase/lorax
|
|
# Path: lorax/server/lorax_server/adapters/lora.py
|
|
# License: Apache License Version 2.0, January 2004
|
|
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from typing import Dict, List, Optional, Set, Tuple, Type, Union
|
|
|
|
import torch
|
|
from peft import LoraConfig as _LoraConfig
|
|
from torch.distributed import ProcessGroup
|
|
|
|
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
|
|
|
|
from text_generation_server.adapters.weights import (
|
|
AdapterBatchMetadata,
|
|
AdapterWeights,
|
|
BatchAdapterWeights,
|
|
)
|
|
from text_generation_server.utils.sgmv import (
|
|
BGMV_MAX_RANK,
|
|
MAX_RANK_CUSTOM,
|
|
get_tmp_tensors,
|
|
orient_for_rank,
|
|
pad_rank,
|
|
use_cutlass_shrink,
|
|
)
|
|
|
|
|
|
def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
    """Return the half-open ``(start, stop)`` range owned by ``rank``.

    The range ``[offset, offset + size)`` is cut into ``world_size`` equal
    blocks of ``size // world_size`` elements and ``rank`` receives block
    number ``rank``. When ``size`` is not divisible by ``world_size`` the
    trailing remainder is not assigned to any rank.
    """
    block_size = size // world_size
    start = offset + block_size * rank
    return start, start + block_size
|
|
|
|
|
|
def shard_on_dim(
    t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
):
    """Return this process's slice of ``t`` along dimension ``dim``.

    The dimension is split into ``process_group.size()`` equal blocks (see
    ``get_start_stop_idxs_for_rank``). The block selected by
    ``process_group.rank()`` is returned as a view of ``t``.

    Args:
        t: Tensor to shard.
        dim: Dimension to split. Any valid dimension index is accepted,
            including negative ones.
        process_group: Group providing ``size()`` and ``rank()``.

    Returns:
        A view of ``t`` covering this rank's block along ``dim``.
    """
    world_size = process_group.size()
    rank = process_group.rank()

    size = t.shape[dim]
    start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)

    # ``narrow`` yields the same view as basic slicing along ``dim``, but works
    # for every dimension instead of only the first two.
    return t.narrow(dim, start, stop - start)
|
|
|
|
|
|
def shard_lora_weights(
    weights_a: List[torch.Tensor],
    weights_b: List[torch.Tensor],
    split_dim: int,
    process_group: ProcessGroup,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Shard per-layer LoRA A and B matrices for tensor parallelism.

    Every A matrix is split along ``split_dim``. Every B matrix is always
    split along dimension 1. The actual slicing is done by ``shard_on_dim``
    using ``process_group``.

    Returns:
        ``(sharded_a, sharded_b)``, lists in the same layer order as the inputs.
    """

    def _shard_all(tensors: List[torch.Tensor], dim: int) -> List[torch.Tensor]:
        return [shard_on_dim(w, dim=dim, process_group=process_group) for w in tensors]

    # A: [hidden_size, r] -- split along the caller-selected dimension
    sharded_a = _shard_all(weights_a, split_dim)
    # B: [r, hidden_size] -- split along its second dimension
    sharded_b = _shard_all(weights_b, 1)
    return sharded_a, sharded_b
|
|
|
|
|
|
@dataclass
class LoraConfig(AdapterConfig):
    """The subset of a PEFT LoRA config needed to serve a LoRA adapter."""

    r: int
    target_modules: Optional[Union[List[str], str]]
    fan_in_fan_out: bool
    lora_alpha: int
    use_rslora: bool

    def map_weights_for_model(
        self,
        adapter_weights: Dict[str, torch.Tensor],
        weight_names: Tuple[str, ...],
    ) -> Tuple[ModuleMap, Set[str]]:
        """Pair up the LoRA A/B tensors for each requested model weight.

        Args:
            adapter_weights: Adapter tensors keyed by their full checkpoint
                name (e.g. ``base_model.model.<weight>.lora_A.weight``).
            weight_names: Model weight names to look up in the adapter.

        Returns:
            ``(module_map, adapter_weight_names)``. ``module_map`` maps each
            weight name that has BOTH an A and a B tensor to
            ``{"lora_A": (tensor, name), "lora_B": (tensor, name)}``.
            ``adapter_weight_names`` is the set of checkpoint names consumed.
            Weight names missing either tensor are skipped.
        """
        adapter_weight_names = set()
        module_map = {}
        for weight_name in weight_names:
            lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
            lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
            if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
                continue

            module_map[weight_name] = {
                "lora_A": (adapter_weights[lora_a_name], lora_a_name),
                "lora_B": (adapter_weights[lora_b_name], lora_b_name),
            }
            adapter_weight_names.add(lora_a_name)
            adapter_weight_names.add(lora_b_name)
        return module_map, adapter_weight_names

    @classmethod
    def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
        """Build a ``LoraConfig`` from the PEFT config of ``adapter_id``."""
        hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
        return cls(
            base_model_name_or_path=hf_config.base_model_name_or_path,
            r=hf_config.r,
            target_modules=hf_config.target_modules,
            fan_in_fan_out=hf_config.fan_in_fan_out,
            lora_alpha=hf_config.lora_alpha,
            # Configs without a ``use_rslora`` attribute fall back to plain LoRA.
            use_rslora=getattr(hf_config, "use_rslora", False),
        )
|
|
|
|
|
|
class LoraWeights(AdapterWeights):
    """LoRA weights for a single adapter merged across all layers.

    The per-layer A and B matrices are stacked into one tensor each. The
    stored orientation is switched lazily by ``_transpose_weights``. Which
    orientation is used depends on whether callers read ``weights_a`` /
    ``weights_b`` or ``weights_a_t`` / ``weights_b_t``.
    """

    def __init__(
        self,
        weights_a: List[torch.Tensor],
        weights_b: List[torch.Tensor],
        adapter_config: LoraConfig,
    ):
        # Rank as seen from each side of the product. Each list is checked on
        # its own so an empty B list cannot raise IndexError.
        self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
        self.lora_b_r = weights_b[0].size(0) if len(weights_b) > 0 else 1

        self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
        self._is_transposed = False

        # [num_layers, hidden_size, r] before orientation; orient_for_rank may
        # re-orient each matrix depending on its rank.
        weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
        self._weights_a = torch.stack(weights_a)

        # [num_layers, r, hidden_size]
        self._weights_b = torch.stack(weights_b)

        self.adapter_config = adapter_config

    @property
    def weights_a(self) -> torch.Tensor:
        """Stacked A weights in the non-transposed layout."""
        if self._is_transposed:
            self._transpose_weights()
        return self._weights_a

    @property
    def weights_b(self) -> torch.Tensor:
        """Stacked B weights in the non-transposed layout."""
        if self._is_transposed:
            self._transpose_weights()
        return self._weights_b

    @property
    def weights_a_t(self) -> torch.Tensor:
        """Stacked A weights in the transposed layout."""
        if not self._is_transposed:
            self._transpose_weights()
        return self._weights_a

    @property
    def weights_b_t(self) -> torch.Tensor:
        """Stacked B weights in the transposed layout."""
        if not self._is_transposed:
            self._transpose_weights()
        return self._weights_b

    def _transpose_weights(self):
        """Flip the stored layout of both stacks and toggle the layout flag."""
        # Without the cutlass shrink, SGMV and BGMV use the same orientation,
        # so only the bookkeeping flag needs to change.
        if self._use_cutlass_shrink:
            self._weights_a = self._weights_a.transpose(1, 2).contiguous()
            self._weights_b = self._weights_b.transpose(1, 2).contiguous()
        self._is_transposed = not self._is_transposed

    @classmethod
    def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
        return [BatchLoraWeights]

    @classmethod
    def prepare_weights(
        cls,
        config: LoraConfig,
        module_map: Dict[str, Dict],
        layer_type: str,
        unused_weight_names: Set[str],
        nlayers: int,
        dtype: torch.dtype,
        world_size: int,
        process_group: ProcessGroup,
        target_to_layer: Dict[Tuple[int, str], Tuple[str, torch.nn.Module]],
    ) -> Optional[AdapterWeights]:
        """Collect, scale, pad and shard one layer type's LoRA weights.

        For every layer ``0..nlayers-1``, the ``(layer_id, layer_type)`` entry
        of ``target_to_layer`` gives the model weight name and the layer. The
        matching A/B tensors from ``module_map`` are moved to the layer's base
        weight device and cast to ``dtype``. The LoRA scaling factor is folded
        into B. The consumed names are removed from ``unused_weight_names``.

        The ranks are then padded with ``pad_rank`` and ``config.r`` is
        overwritten with the padded rank. Finally the weights are sharded over
        ``process_group``: A is split on dim 0 for ``o_proj``/``down_proj``/
        ``lm_head`` and on dim 1 otherwise.

        Returns:
            A ``LoraWeights`` instance. Returns ``None`` if any layer has no
            LoRA weight for ``layer_type`` in the adapter, or if ``nlayers``
            is 0.
        """
        lora_a_list = [None] * nlayers
        lora_b_list = [None] * nlayers

        for layer_id in range(nlayers):
            key = (layer_id, layer_type)
            weight_name, layer = target_to_layer[key]
            base_device = layer.base_layer.linear.weight.device

            if weight_name not in module_map:
                # There is no LoRA weight for this layer type in the adapter
                return None

            lora_a, lora_a_name = module_map[weight_name]["lora_A"]
            lora_a = lora_a.to(base_device, dtype)

            lora_b, lora_b_name = module_map[weight_name]["lora_B"]
            lora_b = lora_b.to(base_device, dtype)

            # Scale with the adapter's real (unpadded) rank, read from the
            # tensor: lora_A is [r, hidden_size] before the transpose below.
            # ``config.r`` must not be used here. It is overwritten with the
            # padded rank at the end of this method, and the same config is
            # reused for every layer type, so later layer types would get an
            # under-scaled B.
            scale = get_scaling_factor(
                config.lora_alpha,
                lora_a.size(0),
                uses_rslora=config.use_rslora,
            )

            unused_weight_names.discard(lora_a_name)
            unused_weight_names.discard(lora_b_name)

            # Merge scaling factor into lora_b due to associativity of matrix multiplication:
            # (A * B) * C = A * (B * C)
            lora_a_list[layer_id] = lora_a.transpose(0, 1)
            lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale

        if not lora_a_list:
            # nlayers == 0: there is nothing to stack, so no weights to return.
            return None

        # pad lora ranks to be compatible with sgmv
        lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]
        lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]

        # update rank if it was padded
        config.r = lora_a_list[0].size(1)

        return LoraWeights(
            *shard_lora_weights(
                weights_a=lora_a_list,
                weights_b=lora_b_list,
                split_dim=0 if layer_type in {"o_proj", "down_proj", "lm_head"} else 1,
                process_group=process_group,
            ),
            config,
        )
|
|
|
|
|
|
@dataclass
class RankSegments:
    """Kernel metadata for the batch segments whose adapters share one rank."""

    rank: int

    # int64 tensors holding one ``data_ptr()`` per segment of this rank
    lora_a_ptr: torch.Tensor
    lora_b_ptr: torch.Tensor

    # prefill (sgmv); None when the batch is served by bgmv
    tmp_shrink: Optional[torch.Tensor]
    tmp_expand: Optional[torch.Tensor]
    segment_starts: Optional[torch.Tensor]
    segment_ends: Optional[torch.Tensor]

    # decode (bgmv); None when the batch is served by sgmv
    indices: Optional[torch.Tensor]
|
|
|
|
|
@dataclass
class BatchLoraWeights(BatchAdapterWeights):
    """LoRA weights and kernel metadata for every adapter in one batch."""

    lora_a: Dict[int, torch.Tensor]
    lora_b: Dict[int, torch.Tensor]
    adapter_index_configs: Dict[int, LoraConfig]
    rank_data: Dict[int, RankSegments]
    use_sgmv: bool

    def has_adapter(self, adapter_index: int) -> bool:
        """Return True if ``adapter_index`` has LoRA weights in this batch."""
        return adapter_index in self.adapter_index_configs

    def can_vectorize(self, pg: ProcessGroup) -> bool:
        """Return True if every rank, divided by the group size, is at most MAX_RANK_CUSTOM."""
        return all(
            rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
            for rank_data in self.rank_data.values()
        )

    @classmethod
    def load(
        cls,
        adapter_weights: Dict[int, AdapterWeights],
        meta: AdapterBatchMetadata,
        prefill: bool,
        prefill_head_indices: Optional[torch.Tensor],
    ) -> Optional["BatchLoraWeights"]:
        """Build the batch-level LoRA state for the adapters in ``meta``.

        Each entry of ``adapter_weights`` is unwrapped with ``_convert_lora``
        and only ``LoraWeights`` entries are kept. SGMV is used for prefill or
        when the largest rank exceeds ``BGMV_MAX_RANK``; BGMV is used otherwise.
        Segments are grouped by ``lora_a_r`` and one ``RankSegments`` is built
        per rank.

        Returns:
            ``None`` if no LoRA weights remain after filtering, otherwise a new
            instance.
        """
        adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
        adapter_weights = {
            k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
        }
        if not adapter_weights:
            return None

        first_weights = next(iter(adapter_weights.values()))
        device = first_weights.weights_a.device
        segment_indices = meta.segment_indices

        lora_a = {
            idx: adapter_weights[idx].weights_a
            for idx in segment_indices
            if idx in adapter_weights
        }
        lora_b = {
            idx: adapter_weights[idx].weights_b
            for idx in segment_indices
            if idx in adapter_weights
        }

        max_rank = max(
            (
                adapter_weights[idx].lora_a_r
                for idx in segment_indices
                if idx in adapter_weights
            ),
            default=0,
        )

        def _ptr_tensor(get_weights) -> torch.Tensor:
            # One data pointer per segment; 0 marks segments without LoRA weights.
            return torch.tensor(
                [
                    get_weights(adapter_weights[idx]).data_ptr()
                    if idx in adapter_weights
                    else 0
                    for idx in segment_indices
                ],
                dtype=torch.int64,
                device=device,
            )

        # SGMV reads the regular layout and BGMV the transposed one. Reading the
        # property switches each adapter's stored layout if needed. All A
        # pointers are taken before any B pointer.
        use_sgmv = bool(prefill or max_rank > BGMV_MAX_RANK)
        if use_sgmv:
            lora_a_ptr = _ptr_tensor(lambda w: w.weights_a)
            lora_b_ptr = _ptr_tensor(lambda w: w.weights_b)
        else:
            lora_a_ptr = _ptr_tensor(lambda w: w.weights_a_t)
            lora_b_ptr = _ptr_tensor(lambda w: w.weights_b_t)

        adapter_index_configs = {
            idx: adapter_weights[idx].adapter_config
            for idx in segment_indices
            if idx in adapter_weights
        }

        adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}

        # Group segment positions by the rank of their adapter.
        rank_indices = defaultdict(list)
        for segment_idx, adapter_idx in enumerate(segment_indices):
            if adapter_idx not in adapter_weights:
                continue
            rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)

        if prefill_head_indices is not None:
            j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
            for head_index in prefill_head_indices:
                # j cannot go out of bounds as that would mean there are tokens without corresponding adapters
                if head_index < meta.adapter_segments[j]:
                    prefill_head_segment_ends[-1] += 1
                else:
                    prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
                    prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
                    j += 1

        # Decode only: the per-token segment index does not depend on the rank,
        # so it is computed once instead of once per rank.
        token_segments = None
        if not use_sgmv and rank_indices:
            token_segments = [
                adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
            ]

        rank_data = {}
        for rank, indices in rank_indices.items():
            tmp_shrink = None
            tmp_expand = None
            segment_starts = None
            segment_ends = None
            batch_indices = None

            if use_sgmv:
                tmp_shrink, tmp_expand = get_tmp_tensors(len(indices), rank, device)
                segment_starts = meta.adapter_segments[indices]
                segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
                if prefill_head_indices is not None:
                    for i, segment_index in enumerate(indices):
                        segment_starts[i] = prefill_head_segment_starts[segment_index]
                        segment_ends[i] = prefill_head_segment_ends[segment_index]
            else:
                # Tokens whose segment does not belong to this rank get -1.
                rank_segment_set = set(indices)
                batch_indices = torch.tensor(
                    [idx if idx in rank_segment_set else -1 for idx in token_segments],
                    dtype=torch.int64,
                    device=device,
                )

            rank_data[rank] = RankSegments(
                rank=rank,
                tmp_shrink=tmp_shrink,
                tmp_expand=tmp_expand,
                lora_a_ptr=lora_a_ptr[indices],
                lora_b_ptr=lora_b_ptr[indices],
                segment_starts=segment_starts,
                segment_ends=segment_ends,
                indices=batch_indices,
            )

        return cls(
            lora_a=lora_a,
            lora_b=lora_b,
            adapter_index_configs=adapter_index_configs,
            rank_data=rank_data,
            use_sgmv=use_sgmv,
        )
|
|
|
|
|
|
def get_scaling_factor(
    lora_alpha: int,
    r: int,
    uses_rslora: bool = False,
) -> float:
    """Return the multiplier applied to the LoRA update.

    Plain LoRA divides ``lora_alpha`` by the rank ``r``. Rank-stabilized LoRA
    (``uses_rslora=True``) divides by ``sqrt(r)`` instead.
    """
    denominator = r**0.5 if uses_rslora else r
    return lora_alpha / denominator
|
|
|
|
|
|
def _convert_lora(v: AdapterWeights) -> AdapterWeights:
    """Return ``v.lora_weights`` when present, otherwise ``v`` unchanged."""
    try:
        return v.lora_weights
    except AttributeError:
        return v
|