fix: refactors and adjust flash llama lora logic

This commit is contained in:
drbh 2024-06-19 16:13:42 +00:00
parent 224455f389
commit 4f1543d3c7
9 changed files with 101 additions and 94 deletions

View File

@ -1,31 +1,13 @@
import json
from pathlib import Path
from typing import Dict, Optional
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/__init__.py
# License: Apache License Version 2.0, January 2004
from text_generation_server.adapters.config import AdapterConfig
from text_generation_server.adapters.lora import LoraConfig
from text_generation_server.adapters.weights import (
AdapterBatchData,
AdapterBatchMetadata,
)
def load_adapter_config(
config_path: Optional[Path],
adapter_config_path: Optional[Path],
api_token: str,
) -> AdapterConfig:
if adapter_config_path is not None and adapter_config_path.exists():
return LoraConfig.load(str(adapter_config_path.parent), api_token)
raise ValueError(
f"No valid adapter config file found: "
f"tried {adapter_config_path} and {config_path}"
)
__all__ = [
"AdapterBatchData",
"AdapterBatchMetadata",
"load_adapter_config",
]

View File

@ -1,3 +1,7 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/config.py
# License: Apache License Version 2.0, January 2004
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
@ -10,7 +14,10 @@ if TYPE_CHECKING:
from text_generation_server.models.model import Model
ModuleMap = Dict[str, Dict[str, Tuple[torch.Tensor, str]]]
@dataclass
class ModuleMap:
module_name: str
module_weights: Dict[str, Tuple[torch.Tensor, str]]
@dataclass
@ -20,7 +27,7 @@ class AdapterConfig(ABC):
@abstractmethod
def map_weights_for_model(
self,
adapter_weights: Dict,
adapter_weights: Dict[int, AdapterWeights],
weight_names: Tuple[str],
) -> Tuple[ModuleMap, Set[str]]:
pass
@ -29,7 +36,7 @@ class AdapterConfig(ABC):
def load_batched_adapter_weights(
self,
model: "Model",
module_map: Dict[str, Dict],
module_map: ModuleMap,
layer_type: str,
unused_weight_names: Set[str],
dynamic: bool,

View File

@ -1,3 +1,7 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/lora.py
# License: Apache License Version 2.0, January 2004
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union
@ -25,8 +29,6 @@ from text_generation_server.utils.sgmv import (
if TYPE_CHECKING:
from text_generation_server.models.model import Model
EMPTY_TENSOR = torch.tensor([])
@dataclass
class LoraConfig(AdapterConfig):
@ -38,7 +40,7 @@ class LoraConfig(AdapterConfig):
def map_weights_for_model(
self,
adapter_weights: Dict,
adapter_weights: Dict[int, AdapterWeights],
weight_names: Tuple[str],
) -> Tuple[ModuleMap, Set[str]]:
adapter_weight_names = set()
@ -262,7 +264,7 @@ class BatchLoraWeights(BatchAdapterWeights):
if not adapter_weights:
return None
first_weights = list(adapter_weights.values())[0]
first_weights = next(iter(adapter_weights.values()))
device = first_weights.weights_a.device
segment_indices = meta.segment_indices
@ -293,7 +295,7 @@ class BatchLoraWeights(BatchAdapterWeights):
(
adapter_weights[idx].weights_a.data_ptr()
if idx in adapter_weights
else EMPTY_TENSOR.data_ptr()
else 0
)
for idx in segment_indices
],
@ -305,7 +307,7 @@ class BatchLoraWeights(BatchAdapterWeights):
(
adapter_weights[idx].weights_b.data_ptr()
if idx in adapter_weights
else EMPTY_TENSOR.data_ptr()
else 0
)
for idx in segment_indices
],
@ -319,7 +321,7 @@ class BatchLoraWeights(BatchAdapterWeights):
(
adapter_weights[idx].weights_a_t.data_ptr()
if idx in adapter_weights
else EMPTY_TENSOR.data_ptr()
else 0
)
for idx in segment_indices
],
@ -331,7 +333,7 @@ class BatchLoraWeights(BatchAdapterWeights):
(
adapter_weights[idx].weights_b_t.data_ptr()
if idx in adapter_weights
else EMPTY_TENSOR.data_ptr()
else 0
)
for idx in segment_indices
],

View File

@ -1,3 +1,7 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/weights.py
# License: Apache License Version 2.0, January 2004
from abc import ABC, abstractclassmethod
from collections import defaultdict
from dataclasses import dataclass

View File

@ -59,7 +59,7 @@ def load_attention(config, prefix, weights, layer_id):
# if specific model type, load the correct attention
if config.model_type == "phi3":
base_layer = TensorParallelColumnLinear.load_qkv(
return TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.qkv_proj",
weights=weights,
@ -68,7 +68,7 @@ def load_attention(config, prefix, weights, layer_id):
num_key_value_heads=config.num_key_value_heads,
)
elif config.model_type == "baichuan":
base_layer = TensorParallelColumnLinear.load_qkv(
return TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.W_pack",
weights=weights,
@ -76,28 +76,28 @@ def load_attention(config, prefix, weights, layer_id):
num_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
)
else:
# otherwise, load the default attention based on the number of heads
base_layer = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=bias,
)
# otherwise, load the default attention based on the number of heads
base_layer = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=bias,
)
head_size = config.hidden_size // config.num_attention_heads
return TensorParallelMultiAdapterLinear.load(
base_layer,
layer_id,
["q_proj", "k_proj", "v_proj"],
sizes=[
head_size * config.num_attention_heads,
head_size * config.num_key_value_heads,
head_size * config.num_key_value_heads,
],
process_group=weights.process_group,
)
head_size = config.hidden_size // config.num_attention_heads
return TensorParallelMultiAdapterLinear.load(
base_layer,
layer_id,
["q_proj", "k_proj", "v_proj"],
sizes=[
head_size * config.num_attention_heads,
head_size * config.num_key_value_heads,
head_size * config.num_key_value_heads,
],
process_group=weights.process_group,
)
class FlashLlamaAttention(torch.nn.Module):
@ -240,7 +240,7 @@ class LlamaMLP(nn.Module):
# Fuse gate and up proj
bias = getattr(config, "mlp_bias", False)
if config.model_type == "phi3":
gate_up_proj = TensorParallelColumnLinear.load_gate_up(
self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
config,
prefix=f"{prefix}.gate_up_proj",
weights=weights,
@ -255,16 +255,16 @@ class LlamaMLP(nn.Module):
bias=bias,
)
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
gate_up_proj,
index,
["gate_proj", "up_proj"],
sizes=[
config.intermediate_size,
config.intermediate_size,
],
process_group=weights.process_group,
)
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
gate_up_proj,
index,
["gate_proj", "up_proj"],
sizes=[
config.intermediate_size,
config.intermediate_size,
],
process_group=weights.process_group,
)
down_proj = TensorParallelRowLinear.load(
config,
@ -273,12 +273,15 @@ class LlamaMLP(nn.Module):
bias=bias,
)
self.down_proj = TensorParallelAdapterRowLinear.load(
down_proj,
index,
"down_proj",
process_group=weights.process_group,
)
if config.model_type == "phi3":
self.down_proj = down_proj
else:
self.down_proj = TensorParallelAdapterRowLinear.load(
down_proj,
index,
"down_proj",
process_group=weights.process_group,
)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
@ -471,9 +474,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
weights=weights,
)
def get_lora_index(self, adapter_id):
return self.model.layers[0].self_attn.key_to_index[adapter_id]
def forward(
self,
input_ids: torch.Tensor,

View File

@ -1,3 +1,7 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/utils/adapter.py
# License: Apache License Version 2.0, January 2004
import warnings
from dataclasses import dataclass
from functools import lru_cache
@ -20,22 +24,20 @@ if TYPE_CHECKING:
BASE_MODEL_ADAPTER_ID = "__base_model__"
@dataclass
class AdapterParameters:
def __init__(
self, adapter_ids, weights, merge_strategy, density, majority_sign_method
):
self.adapter_ids = adapter_ids
self.weights = weights
self.merge_strategy = merge_strategy
self.density = density
self.majority_sign_method = majority_sign_method
adapter_ids: Tuple[str]
weights: Tuple[float]
merge_strategy: NotImplemented
density: float
majority_sign_method: NotImplemented
@dataclass
class AdapterSource:
def __init__(self, adapter_id: str, model_id: str, revision: str):
self.adapter_id = adapter_id
self.model_id = model_id
self.revision = revision
adapter_id: str
model_id: str
revision: str
def load_and_merge_adapters(
@ -65,11 +67,11 @@ def load_and_merge_adapters(
)
@dataclass
class AdapterParametersContainer:
def __init__(self, adapter_parameters, adapter_source, adapter_index):
self.adapter_parameters = adapter_parameters
self.adapter_source = adapter_source
self.adapter_index = adapter_index
adapter_parameters: AdapterParameters
adapter_source: str
adapter_index: int
def __hash__(self) -> int:
return self.adapter_index
@ -123,7 +125,7 @@ def check_architectures(
):
try:
if not adapter_config.base_model_name_or_path:
# Avoid execuation latency caused by the network connection retrying for AutoConfig.from_pretrained(None)
# Avoid execution latency caused by the network connection retrying for AutoConfig.from_pretrained(None)
return
expected_config = AutoConfig.from_pretrained(

View File

@ -1,5 +1,5 @@
import os
import json
from typing import Union
from loguru import logger
import torch
@ -45,7 +45,9 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
tokenizer.save_pretrained(cache_dir)
def download_peft(model_id, revision, trust_remote_code):
def download_peft(
model_id: Union[str, os.PathLike], revision: str, trust_remote_code: bool
):
torch_dtype = torch.float16
try:
_model = AutoPeftModelForCausalLM.from_pretrained(

View File

@ -1,3 +1,7 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/utils/segments.py
# License: Apache License Version 2.0, January 2004
from typing import List, Tuple, Union
import torch

View File

@ -1,3 +1,7 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/utils/sgmv.py
# License: Apache License Version 2.0, January 2004
import os
import warnings
from functools import lru_cache