45 lines
1.1 KiB
Python
45 lines
1.1 KiB
Python
# Origin: https://github.com/predibase/lorax
|
|
# Path: lorax/server/lorax_server/adapters/config.py
|
|
# License: Apache License Version 2.0, January 2004
|
|
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
|
|
|
|
import torch
|
|
|
|
from text_generation_server.adapters.weights import AdapterWeights
|
|
|
|
if TYPE_CHECKING:
|
|
from text_generation_server.models.model import Model
|
|
|
|
|
|
@dataclass
class ModuleMap:
    """Weights associated with a single named module.

    Plain data holder: it performs no validation or conversion of the
    values it is constructed with.
    """

    # Name of the module these weights belong to.
    module_name: str

    # Maps a string key to a ``(tensor, string)`` pair.
    # NOTE(review): presumably the key identifies the weight's role and the
    # string in the pair is the weight's original name -- confirm against the
    # code that builds this mapping.
    module_weights: Dict[str, Tuple[torch.Tensor, str]]
|
|
|
|
|
|
@dataclass
class AdapterConfig(ABC):
    """Abstract base class for adapter configurations.

    Concrete subclasses must implement both abstract methods below; the
    class cannot be instantiated until they do.
    """

    # Name or path of the base model (as the field name states; how it is
    # resolved is up to the code consuming this config).
    base_model_name_or_path: str

    @abstractmethod
    def map_weights_for_model(
        self,
        adapter_weights: Dict[int, AdapterWeights],
        # Variable-length tuple of names; ``Tuple[str]`` would have meant
        # "exactly one string".
        weight_names: Tuple[str, ...],
    ) -> Tuple[ModuleMap, Set[str]]:
        """Map adapter weights onto the model's weight names.

        Args:
            adapter_weights: Adapter weights keyed by an integer.
                NOTE(review): presumably a layer index -- confirm with callers.
            weight_names: Names of the weights to map.

        Returns:
            A ``(module_map, names)`` pair, where ``names`` is a set of
            strings. NOTE(review): the sibling method's
            ``unused_weight_names`` parameter suggests this is the set of
            weight names left unmapped -- confirm in the implementations.
        """

    @abstractmethod
    def load_batched_adapter_weights(
        self,
        model: "Model",
        module_map: ModuleMap,
        layer_type: str,
        unused_weight_names: Set[str],
        dynamic: bool,
    ) -> Optional[AdapterWeights]:
        """Load the adapter weights for one layer type.

        Args:
            model: The model the adapter weights are loaded for.
            module_map: Mapped module weights (the type returned first by
                ``map_weights_for_model``).
            layer_type: Identifier of the layer type to load.
            unused_weight_names: Set of weight names not consumed so far.
                NOTE(review): whether implementations mutate this set is not
                visible here -- confirm.
            dynamic: Flag forwarded to the implementation; its exact
                semantics are defined by subclasses.

        Returns:
            The loaded ``AdapterWeights``, or ``None``.
        """
|