2024-06-25 12:46:27 -06:00
|
|
|
# Origin: https://github.com/predibase/lorax
|
|
|
|
# Path: lorax/server/lorax_server/adapters/weights.py
|
|
|
|
# License: Apache License Version 2.0, January 2004
|
|
|
|
|
|
|
|
from abc import ABC, abstractclassmethod, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Type

import torch
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class AdapterBatchMetadata:
    """Index metadata that accompanies a batch of adapter requests.

    The bracketed note above each field is the length recorded for it by the
    original author; nothing in this class validates those lengths.
    """

    # [batch_size]
    adapter_indices: torch.Tensor

    # [num_adapters]
    adapter_set: Set[int]

    # [num_segments + 1]
    # NOTE(review): presumably segment boundary offsets -- confirm with the
    # code that builds this tensor.
    adapter_segments: torch.Tensor

    # [num_segments]
    # Maps a segment index to an adapter index, i.e.:
    #   segment_indices[s] == adapter_indices[i]
    segment_indices: List[int]
|
|
|
|
|
|
|
|
|
|
|
|
class AdapterWeights(ABC):
    """Abstract base class for the weights of a single adapter."""

    # `abc.abstractclassmethod` has been deprecated since Python 3.3; stacking
    # `classmethod` on top of `abstractmethod` is the supported spelling.
    @classmethod
    @abstractmethod
    def get_batch_types(cls) -> List[Type["BatchAdapterWeights"]]:
        """Return the batch classes these weights can be grouped under.

        `LayerAdapterWeights.get_data` uses each returned class as a bucket
        key and then calls that class's `load` on the bucket.
        """

    @property
    def speculative_tokens(self) -> int:
        """Number of speculative tokens; the base implementation returns 0."""
        return 0
|
|
|
|
|
|
|
|
|
|
|
|
class BatchAdapterWeights(ABC):
    """Abstract base class for adapter weights batched together for one pass."""

    # This is an instance-level query (it takes `self`), so it must be a plain
    # abstract method rather than an abstract *class* method.
    @abstractmethod
    def has_adapter(self, adapter_index: int) -> bool:
        """Report whether ``adapter_index`` is present in this batch.

        The contract is implied by the name; concrete behaviour is defined by
        subclasses.
        """

    # `abc.abstractclassmethod` has been deprecated since Python 3.3; stacking
    # `classmethod` on top of `abstractmethod` is the supported spelling.
    @classmethod
    @abstractmethod
    def load(
        cls,
        adapter_weights: Dict[int, "AdapterWeights"],
        meta: "AdapterBatchMetadata",
        prefill: bool,
        prefill_head_indices: Optional[torch.Tensor],
    ) -> Optional["BatchAdapterWeights"]:
        """Build the batched weights for one bucket of adapters.

        Args:
            adapter_weights: adapter index -> weights, restricted to adapters
                that listed this class in ``get_batch_types()``.
            meta: metadata of the batch being prepared.
            prefill: flag forwarded unchanged from the caller.
            prefill_head_indices: ``AdapterBatchData.from_meta`` passes ``None``
                for every layer except ``"lm_head"``, hence ``Optional``.

        Returns:
            The batched weights, or ``None`` when there is nothing to batch.
        """
|
|
|
|
|
|
|
|
|
|
|
|
class LayerAdapterWeights:
    """Adapter weights that apply to a particular layer."""

    def __init__(self):
        # adapter index -> weights registered for this layer
        self.adapter_weights: Dict[int, "AdapterWeights"] = {}

    def add_adapter(self, adapter_idx: int, weights: "AdapterWeights"):
        """Register ``weights`` under ``adapter_idx``, replacing any previous entry."""
        self.adapter_weights[adapter_idx] = weights

    def remove_adapter(self, adapter_idx: int):
        """Forget the weights stored for ``adapter_idx``; unknown indices are ignored."""
        self.adapter_weights.pop(adapter_idx, None)

    def is_empty(self) -> bool:
        """Return ``True`` when no adapter is registered for this layer."""
        return not self.adapter_weights

    def get_data(
        self,
        meta: "AdapterBatchMetadata",
        prefill: bool,
        prefill_head_indices: Optional[torch.Tensor],
    ) -> Optional["BatchAdapterWeights"]:
        """Load the batched weights for this layer.

        Adapters are grouped by every batch class they report through
        ``get_batch_types()``, and each class's ``load`` is called on its
        group.

        Args:
            meta: metadata of the batch being prepared (forwarded to ``load``).
            prefill: forwarded to ``load`` unchanged.
            prefill_head_indices: forwarded to ``load`` unchanged; may be ``None``.

        Returns:
            The object returned by the last ``load`` call that did not return
            ``None``, or ``None`` when no call produced anything. The previous
            fallback was an empty dict, which matched neither the successful
            return value nor the ``is None`` check in
            ``AdapterBatchData.ranks``; ``None`` is equally falsy.
        """
        # bucket adapters by batch class
        weights_by_batch_type: Dict[
            Type["BatchAdapterWeights"], Dict[int, "AdapterWeights"]
        ] = defaultdict(dict)
        for adapter_index, weights in self.adapter_weights.items():
            for batch_type in weights.get_batch_types():
                weights_by_batch_type[batch_type][adapter_index] = weights

        batch_data: Optional["BatchAdapterWeights"] = None
        for batch_type, bucket in weights_by_batch_type.items():
            batched_weights = batch_type.load(
                bucket, meta, prefill, prefill_head_indices
            )
            if batched_weights is not None:
                # A later batch type overrides an earlier one (kept as-is).
                batch_data = batched_weights
        return batch_data
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class AdapterBatchData:
    """Batch metadata plus the per-layer batched adapter weights."""

    meta: "AdapterBatchMetadata"

    # layer type -> adapter type -> batch weight data
    # NOTE(review): `ranks()` reads `.rank_data` directly off each value of this
    # dict, so in practice the values look like batch weight objects rather
    # than nested dicts -- confirm and tighten this annotation.
    data: Dict[str, Dict[str, "BatchAdapterWeights"]]

    prefill: bool

    @staticmethod
    def from_meta(
        meta: "AdapterBatchMetadata",
        weights: Dict[str, "LayerAdapterWeights"],
        prefill: bool,
        prefill_head_indices: Optional[torch.Tensor],
    ) -> "AdapterBatchData":
        """Collect the batched data of every non-empty layer.

        Args:
            meta: metadata of the batch being prepared.
            weights: layer name -> adapter weights registered for that layer.
            prefill: stored on the result and forwarded to ``get_data``.
            prefill_head_indices: forwarded only to the ``"lm_head"`` layer;
                every other layer receives ``None``.

        Returns:
            A new ``AdapterBatchData`` whose ``data`` has one entry per layer
            for which ``is_empty()`` is false.
        """
        data = {
            layer_name: layer_weights.get_data(
                meta,
                prefill,
                prefill_head_indices if layer_name == "lm_head" else None,
            )
            for layer_name, layer_weights in weights.items()
            if not layer_weights.is_empty()
        }
        return AdapterBatchData(meta=meta, data=data, prefill=prefill)

    def ranks(self) -> Set[int]:
        """Return the distinct ``rank`` values found across all layers.

        Entries that expose no ``rank_data`` (``None``, an empty dict, or a
        batch type without that attribute) are skipped instead of raising
        ``AttributeError``.
        """
        # TODO(travis): refactor to be less coupled to lora implementation
        ranks: Set[int] = set()
        for lora_data in self.data.values():
            rank_data = getattr(lora_data, "rank_data", None)
            if rank_data is None:
                continue
            ranks.update(entry.rank for entry in rank_data.values())
        return ranks

    def layer_names(self) -> Set[str]:
        """Return the names of the layers that have an entry in ``data``."""
        return set(self.data)

    def adapter_keys(self) -> Set[str]:
        """Return the union of the keys of every per-layer entry.

        ``None`` entries are skipped, mirroring ``ranks()``.
        """
        # NOTE(review): an entry that is neither None nor mapping-like still
        # raises AttributeError here, as it did before -- decide what this
        # method should report for such entries.
        adapter_keys: Set[str] = set()
        for layer_data in self.data.values():
            if layer_data is None:
                continue
            adapter_keys.update(layer_data.keys())
        return adapter_keys

    @property
    def max_rank(self) -> int:
        """Largest value in ``ranks()``, or 0 when there are none."""
        return max(self.ranks(), default=0)
|