fix: refactor and adjust flash llama lora logic

This commit is contained in:
drbh 2024-06-19 16:13:42 +00:00
parent 224455f389
commit 4f1543d3c7
9 changed files with 101 additions and 94 deletions

View File

@@ -1,31 +1,13 @@
-import json
-from pathlib import Path
-from typing import Dict, Optional
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/adapters/__init__.py
+# License: Apache License Version 2.0, January 2004

-from text_generation_server.adapters.config import AdapterConfig
-from text_generation_server.adapters.lora import LoraConfig
 from text_generation_server.adapters.weights import (
     AdapterBatchData,
     AdapterBatchMetadata,
 )


-def load_adapter_config(
-    config_path: Optional[Path],
-    adapter_config_path: Optional[Path],
-    api_token: str,
-) -> AdapterConfig:
-    if adapter_config_path is not None and adapter_config_path.exists():
-        return LoraConfig.load(str(adapter_config_path.parent), api_token)
-
-    raise ValueError(
-        f"No valid adapter config file found: "
-        f"tried {adapter_config_path} and {config_path}"
-    )
-
-
 __all__ = [
     "AdapterBatchData",
     "AdapterBatchMetadata",
-    "load_adapter_config",
 ]

View File

@@ -1,3 +1,7 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/adapters/config.py
+# License: Apache License Version 2.0, January 2004
+
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
@@ -10,7 +14,10 @@ if TYPE_CHECKING:
     from text_generation_server.models.model import Model

-ModuleMap = Dict[str, Dict[str, Tuple[torch.Tensor, str]]]
+@dataclass
+class ModuleMap:
+    module_name: str
+    module_weights: Dict[str, Tuple[torch.Tensor, str]]


 @dataclass
@@ -20,7 +27,7 @@ class AdapterConfig(ABC):
     @abstractmethod
     def map_weights_for_model(
         self,
-        adapter_weights: Dict,
+        adapter_weights: Dict[int, AdapterWeights],
         weight_names: Tuple[str],
     ) -> Tuple[ModuleMap, Set[str]]:
         pass
@@ -29,7 +36,7 @@ class AdapterConfig(ABC):
     def load_batched_adapter_weights(
         self,
         model: "Model",
-        module_map: Dict[str, Dict],
+        module_map: ModuleMap,
         layer_type: str,
         unused_weight_names: Set[str],
         dynamic: bool,
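Note (illustration only, not part of the diff): the ModuleMap type alias becomes a small dataclass here. A minimal standalone sketch of how it might be populated, using the two fields shown above; the module name, tensor shapes, and file name are made-up values.

from dataclasses import dataclass
from typing import Dict, Tuple

import torch


@dataclass
class ModuleMap:
    module_name: str
    module_weights: Dict[str, Tuple[torch.Tensor, str]]


# Collect the LoRA A/B tensors for one target module under its name; the second
# tuple element records which weights file each tensor came from.
q_proj_map = ModuleMap(
    module_name="model.layers.0.self_attn.q_proj",
    module_weights={
        "lora_A": (torch.zeros(16, 4096), "adapter_model.safetensors"),
        "lora_B": (torch.zeros(4096, 16), "adapter_model.safetensors"),
    },
)
print(q_proj_map.module_name, list(q_proj_map.module_weights))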

View File

@@ -1,3 +1,7 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/adapters/lora.py
+# License: Apache License Version 2.0, January 2004
+
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union
@@ -25,8 +29,6 @@ from text_generation_server.utils.sgmv import (
 if TYPE_CHECKING:
     from text_generation_server.models.model import Model

-EMPTY_TENSOR = torch.tensor([])
-

 @dataclass
 class LoraConfig(AdapterConfig):
@@ -38,7 +40,7 @@ class LoraConfig(AdapterConfig):
     def map_weights_for_model(
         self,
-        adapter_weights: Dict,
+        adapter_weights: Dict[int, AdapterWeights],
         weight_names: Tuple[str],
     ) -> Tuple[ModuleMap, Set[str]]:
         adapter_weight_names = set()
@@ -262,7 +264,7 @@ class BatchLoraWeights(BatchAdapterWeights):
         if not adapter_weights:
             return None

-        first_weights = list(adapter_weights.values())[0]
+        first_weights = next(iter(adapter_weights.values()))
         device = first_weights.weights_a.device
         segment_indices = meta.segment_indices
@@ -293,7 +295,7 @@ class BatchLoraWeights(BatchAdapterWeights):
                (
                    adapter_weights[idx].weights_a.data_ptr()
                    if idx in adapter_weights
-                   else EMPTY_TENSOR.data_ptr()
+                   else 0
                )
                for idx in segment_indices
            ],
@@ -305,7 +307,7 @@ class BatchLoraWeights(BatchAdapterWeights):
                (
                    adapter_weights[idx].weights_b.data_ptr()
                    if idx in adapter_weights
-                   else EMPTY_TENSOR.data_ptr()
+                   else 0
                )
                for idx in segment_indices
            ],
@@ -319,7 +321,7 @@ class BatchLoraWeights(BatchAdapterWeights):
                (
                    adapter_weights[idx].weights_a_t.data_ptr()
                    if idx in adapter_weights
-                   else EMPTY_TENSOR.data_ptr()
+                   else 0
                )
                for idx in segment_indices
            ],
@@ -331,7 +333,7 @@ class BatchLoraWeights(BatchAdapterWeights):
                (
                    adapter_weights[idx].weights_b_t.data_ptr()
                    if idx in adapter_weights
-                   else EMPTY_TENSOR.data_ptr()
+                   else 0
                )
                for idx in segment_indices
            ],
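Note (illustration only, not part of the diff): with the module-level EMPTY_TENSOR removed, segments that have no adapter now contribute a literal null pointer (0) to the pointer lists. A toy, self-contained version of that list construction, assuming made-up stand-ins for adapter_weights and segment_indices rather than the real server objects:

import torch

adapter_weights = {0: torch.randn(8, 4)}   # pretend only adapter index 0 is loaded
segment_indices = [0, -1, 0]               # -1 marks a segment with no adapter

lora_a_ptr = torch.tensor(
    [
        (
            adapter_weights[idx].data_ptr()
            if idx in adapter_weights
            # previously EMPTY_TENSOR.data_ptr(); 0 is now an explicit null pointer
            else 0
        )
        for idx in segment_indices
    ],
    dtype=torch.int64,
)
print(lora_a_ptr)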

View File

@@ -1,3 +1,7 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/adapters/weights.py
+# License: Apache License Version 2.0, January 2004
+
 from abc import ABC, abstractclassmethod
 from collections import defaultdict
 from dataclasses import dataclass

View File

@@ -59,7 +59,7 @@ def load_attention(config, prefix, weights, layer_id):
     # if specific model type, load the correct attention
     if config.model_type == "phi3":
-        base_layer = TensorParallelColumnLinear.load_qkv(
+        return TensorParallelColumnLinear.load_qkv(
             config,
             prefix=f"{prefix}.qkv_proj",
             weights=weights,
@@ -68,7 +68,7 @@ def load_attention(config, prefix, weights, layer_id):
             num_key_value_heads=config.num_key_value_heads,
         )
     elif config.model_type == "baichuan":
-        base_layer = TensorParallelColumnLinear.load_qkv(
+        return TensorParallelColumnLinear.load_qkv(
             config,
             prefix=f"{prefix}.W_pack",
             weights=weights,
@@ -76,28 +76,28 @@ def load_attention(config, prefix, weights, layer_id):
             num_heads=config.num_attention_heads,
             num_key_value_heads=config.num_key_value_heads,
         )
-
-    # otherwise, load the default attention based on the number of heads
-    base_layer = TensorParallelColumnLinear.load_multi(
-        config,
-        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-        dim=0,
-        weights=weights,
-        bias=bias,
-    )
+    else:
+        # otherwise, load the default attention based on the number of heads
+        base_layer = TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            dim=0,
+            weights=weights,
+            bias=bias,
+        )

     head_size = config.hidden_size // config.num_attention_heads
     return TensorParallelMultiAdapterLinear.load(
         base_layer,
         layer_id,
         ["q_proj", "k_proj", "v_proj"],
         sizes=[
             head_size * config.num_attention_heads,
             head_size * config.num_key_value_heads,
             head_size * config.num_key_value_heads,
         ],
         process_group=weights.process_group,
     )


 class FlashLlamaAttention(torch.nn.Module):
@@ -240,7 +240,7 @@ class LlamaMLP(nn.Module):
         # Fuse gate and up proj
         bias = getattr(config, "mlp_bias", False)
         if config.model_type == "phi3":
-            gate_up_proj = TensorParallelColumnLinear.load_gate_up(
+            self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
                 config,
                 prefix=f"{prefix}.gate_up_proj",
                 weights=weights,
@@ -255,16 +255,16 @@ class LlamaMLP(nn.Module):
                 bias=bias,
             )
-        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
-            gate_up_proj,
-            index,
-            ["gate_proj", "up_proj"],
-            sizes=[
-                config.intermediate_size,
-                config.intermediate_size,
-            ],
-            process_group=weights.process_group,
-        )
+            self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+                gate_up_proj,
+                index,
+                ["gate_proj", "up_proj"],
+                sizes=[
+                    config.intermediate_size,
+                    config.intermediate_size,
+                ],
+                process_group=weights.process_group,
+            )

         down_proj = TensorParallelRowLinear.load(
             config,
@@ -273,12 +273,15 @@ class LlamaMLP(nn.Module):
             bias=bias,
         )

-        self.down_proj = TensorParallelAdapterRowLinear.load(
-            down_proj,
-            index,
-            "down_proj",
-            process_group=weights.process_group,
-        )
+        if config.model_type == "phi3":
+            self.down_proj = down_proj
+        else:
+            self.down_proj = TensorParallelAdapterRowLinear.load(
+                down_proj,
+                index,
+                "down_proj",
+                process_group=weights.process_group,
+            )

         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
@@ -471,9 +474,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             weights=weights,
         )

-    def get_lora_index(self, adapter_id):
-        return self.model.layers[0].self_attn.key_to_index[adapter_id]
-
     def forward(
         self,
         input_ids: torch.Tensor,
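Note (illustration only, not part of the diff): the sizes passed to TensorParallelMultiAdapterLinear.load in load_attention are just head_size times the head counts of the fused q/k/v projections. A quick worked check with Llama-3-8B-style dimensions (assumed values, not taken from this commit):

# Assumed Llama-3-8B-style configuration values, for illustration only.
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8

head_size = hidden_size // num_attention_heads          # 128
sizes = [
    head_size * num_attention_heads,                    # q_proj output rows: 4096
    head_size * num_key_value_heads,                    # k_proj output rows: 1024
    head_size * num_key_value_heads,                    # v_proj output rows: 1024
]
print(head_size, sizes)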

View File

@@ -1,3 +1,7 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/utils/adapter.py
+# License: Apache License Version 2.0, January 2004
+
 import warnings
 from dataclasses import dataclass
 from functools import lru_cache
@@ -20,22 +24,20 @@ if TYPE_CHECKING:
 BASE_MODEL_ADAPTER_ID = "__base_model__"


+@dataclass
 class AdapterParameters:
-    def __init__(
-        self, adapter_ids, weights, merge_strategy, density, majority_sign_method
-    ):
-        self.adapter_ids = adapter_ids
-        self.weights = weights
-        self.merge_strategy = merge_strategy
-        self.density = density
-        self.majority_sign_method = majority_sign_method
+    adapter_ids: Tuple[str]
+    weights: Tuple[float]
+    merge_strategy: NotImplemented
+    density: float
+    majority_sign_method: NotImplemented


+@dataclass
 class AdapterSource:
-    def __init__(self, adapter_id: str, model_id: str, revision: str):
-        self.adapter_id = adapter_id
-        self.model_id = model_id
-        self.revision = revision
+    adapter_id: str
+    model_id: str
+    revision: str


 def load_and_merge_adapters(
@@ -65,11 +67,11 @@ def load_and_merge_adapters(
     )


+@dataclass
 class AdapterParametersContainer:
-    def __init__(self, adapter_parameters, adapter_source, adapter_index):
-        self.adapter_parameters = adapter_parameters
-        self.adapter_source = adapter_source
-        self.adapter_index = adapter_index
+    adapter_parameters: AdapterParameters
+    adapter_source: str
+    adapter_index: int

     def __hash__(self) -> int:
         return self.adapter_index
@@ -123,7 +125,7 @@ def check_architectures(
 ):
     try:
         if not adapter_config.base_model_name_or_path:
-            # Avoid execuation latency caused by the network connection retrying for AutoConfig.from_pretrained(None)
+            # Avoid execution latency caused by the network connection retrying for AutoConfig.from_pretrained(None)
             return
         expected_config = AutoConfig.from_pretrained(
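Note (illustration only, not part of the diff): the hand-written __init__ classes become plain dataclasses here. A standalone sketch of constructing them; field names follow the diff, the adapter and model ids are made up, and merge_strategy/majority_sign_method are simply passed None for the sketch.

from dataclasses import dataclass
from typing import Tuple


@dataclass
class AdapterParameters:
    adapter_ids: Tuple[str]
    weights: Tuple[float]
    merge_strategy: NotImplemented
    density: float
    majority_sign_method: NotImplemented


@dataclass
class AdapterSource:
    adapter_id: str
    model_id: str
    revision: str


params = AdapterParameters(
    adapter_ids=("some-org/some-lora-adapter",),
    weights=(1.0,),
    merge_strategy=None,
    density=1.0,
    majority_sign_method=None,
)
source = AdapterSource(
    adapter_id="some-org/some-lora-adapter",
    model_id="meta-llama/Meta-Llama-3-8B",
    revision="main",
)
print(params, source)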

View File

@@ -1,5 +1,5 @@
 import os
-import json
+from typing import Union

 from loguru import logger
 import torch
@@ -45,7 +45,9 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
     tokenizer.save_pretrained(cache_dir)


-def download_peft(model_id, revision, trust_remote_code):
+def download_peft(
+    model_id: Union[str, os.PathLike], revision: str, trust_remote_code: bool
+):
     torch_dtype = torch.float16
     try:
         _model = AutoPeftModelForCausalLM.from_pretrained(

View File

@@ -1,3 +1,7 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/utils/segments.py
+# License: Apache License Version 2.0, January 2004
+
 from typing import List, Tuple, Union

 import torch

View File

@@ -1,3 +1,7 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/utils/sgmv.py
+# License: Apache License Version 2.0, January 2004
+
 import os
 import warnings
 from functools import lru_cache