fix: refactor and adjust flash llama lora logic
parent 224455f389
commit 4f1543d3c7
@@ -1,31 +1,13 @@
-import json
-from pathlib import Path
-from typing import Dict, Optional
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/adapters/__init__.py
+# License: Apache License Version 2.0, January 2004
 
-from text_generation_server.adapters.config import AdapterConfig
-from text_generation_server.adapters.lora import LoraConfig
 from text_generation_server.adapters.weights import (
     AdapterBatchData,
     AdapterBatchMetadata,
 )
 
-
-def load_adapter_config(
-    config_path: Optional[Path],
-    adapter_config_path: Optional[Path],
-    api_token: str,
-) -> AdapterConfig:
-    if adapter_config_path is not None and adapter_config_path.exists():
-        return LoraConfig.load(str(adapter_config_path.parent), api_token)
-
-    raise ValueError(
-        f"No valid adapter config file found: "
-        f"tried {adapter_config_path} and {config_path}"
-    )
-
-
 __all__ = [
     "AdapterBatchData",
     "AdapterBatchMetadata",
-    "load_adapter_config",
 ]
@@ -1,3 +1,7 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/adapters/config.py
+# License: Apache License Version 2.0, January 2004
+
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
@@ -10,7 +14,10 @@ if TYPE_CHECKING:
     from text_generation_server.models.model import Model
 
 
-ModuleMap = Dict[str, Dict[str, Tuple[torch.Tensor, str]]]
+@dataclass
+class ModuleMap:
+    module_name: str
+    module_weights: Dict[str, Tuple[torch.Tensor, str]]
 
 
 @dataclass
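Note: `ModuleMap` changes here from a bare type alias into a dataclass, so downstream code reaches a module's weights through named fields instead of nested dict lookups. A minimal sketch of the new shape, with a made-up weight entry rather than real adapter data:

    import torch
    from dataclasses import dataclass
    from typing import Dict, Tuple

    @dataclass
    class ModuleMap:
        module_name: str
        module_weights: Dict[str, Tuple[torch.Tensor, str]]

    # hypothetical entry: one LoRA matrix plus the file it came from
    m = ModuleMap(
        module_name="q_proj",
        module_weights={"lora_A": (torch.zeros(16, 4096), "adapter_model.safetensors")},
    )
    print(m.module_name, list(m.module_weights))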
@@ -20,7 +27,7 @@ class AdapterConfig(ABC):
     @abstractmethod
     def map_weights_for_model(
         self,
-        adapter_weights: Dict,
+        adapter_weights: Dict[int, AdapterWeights],
         weight_names: Tuple[str],
     ) -> Tuple[ModuleMap, Set[str]]:
         pass
@@ -29,7 +36,7 @@ class AdapterConfig(ABC):
     def load_batched_adapter_weights(
         self,
         model: "Model",
-        module_map: Dict[str, Dict],
+        module_map: ModuleMap,
         layer_type: str,
         unused_weight_names: Set[str],
         dynamic: bool,
@@ -1,3 +1,7 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/adapters/lora.py
+# License: Apache License Version 2.0, January 2004
+
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union
@@ -25,8 +29,6 @@ from text_generation_server.utils.sgmv import (
 if TYPE_CHECKING:
     from text_generation_server.models.model import Model
 
-EMPTY_TENSOR = torch.tensor([])
-
 
 @dataclass
 class LoraConfig(AdapterConfig):
@@ -38,7 +40,7 @@ class LoraConfig(AdapterConfig):
 
     def map_weights_for_model(
         self,
-        adapter_weights: Dict,
+        adapter_weights: Dict[int, AdapterWeights],
         weight_names: Tuple[str],
     ) -> Tuple[ModuleMap, Set[str]]:
         adapter_weight_names = set()
@@ -262,7 +264,7 @@ class BatchLoraWeights(BatchAdapterWeights):
         if not adapter_weights:
             return None
 
-        first_weights = list(adapter_weights.values())[0]
+        first_weights = next(iter(adapter_weights.values()))
         device = first_weights.weights_a.device
         segment_indices = meta.segment_indices
 
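Note: `next(iter(...))` returns the first value without copying every value into a throwaway list, so with many loaded adapters the old spelling did O(n) work for one element. A quick illustration with plain placeholder values:

    adapter_weights = {0: "w0", 1: "w1", 2: "w2"}

    # old: builds a full list just to index element 0
    first = list(adapter_weights.values())[0]

    # new: stops after the first element
    first = next(iter(adapter_weights.values()))
    assert first == "w0"  # dicts iterate in insertion order (Python 3.7+)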
@@ -293,7 +295,7 @@ class BatchLoraWeights(BatchAdapterWeights):
                 (
                     adapter_weights[idx].weights_a.data_ptr()
                     if idx in adapter_weights
-                    else EMPTY_TENSOR.data_ptr()
+                    else 0
                 )
                 for idx in segment_indices
             ],
@@ -305,7 +307,7 @@ class BatchLoraWeights(BatchAdapterWeights):
                 (
                     adapter_weights[idx].weights_b.data_ptr()
                     if idx in adapter_weights
-                    else EMPTY_TENSOR.data_ptr()
+                    else 0
                 )
                 for idx in segment_indices
             ],
@@ -319,7 +321,7 @@ class BatchLoraWeights(BatchAdapterWeights):
                 (
                     adapter_weights[idx].weights_a_t.data_ptr()
                     if idx in adapter_weights
-                    else EMPTY_TENSOR.data_ptr()
+                    else 0
                 )
                 for idx in segment_indices
             ],
@@ -331,7 +333,7 @@ class BatchLoraWeights(BatchAdapterWeights):
                 (
                     adapter_weights[idx].weights_b_t.data_ptr()
                     if idx in adapter_weights
-                    else EMPTY_TENSOR.data_ptr()
+                    else 0
                 )
                 for idx in segment_indices
             ],
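Note: the same `else 0` substitution is applied for `weights_a`, `weights_b`, `weights_a_t`, and `weights_b_t`: a segment with no adapter now contributes a literal null pointer rather than the address of a shared empty tensor, which is what lets the module-level `EMPTY_TENSOR` constant be deleted earlier in this file. A rough sketch of the pattern, using a stand-in weights object rather than the real TGI classes:

    import torch

    class FakeAdapterWeights:
        def __init__(self):
            self.weights_a = torch.zeros(8, 8)

    adapter_weights = {0: FakeAdapterWeights()}  # adapter index -> weights
    segment_indices = [0, 1]                     # segment 1 has no adapter

    lora_a_ptr = torch.tensor(
        [
            (
                adapter_weights[idx].weights_a.data_ptr()
                if idx in adapter_weights
                else 0  # null pointer marks "no adapter" for this segment
            )
            for idx in segment_indices
        ]
    )
    print(lora_a_ptr)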
@@ -1,3 +1,7 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/adapters/weights.py
+# License: Apache License Version 2.0, January 2004
+
 from abc import ABC, abstractclassmethod
 from collections import defaultdict
 from dataclasses import dataclass
@@ -59,7 +59,7 @@ def load_attention(config, prefix, weights, layer_id):
 
     # if specific model type, load the correct attention
     if config.model_type == "phi3":
-        base_layer = TensorParallelColumnLinear.load_qkv(
+        return TensorParallelColumnLinear.load_qkv(
             config,
             prefix=f"{prefix}.qkv_proj",
             weights=weights,
@@ -68,7 +68,7 @@ def load_attention(config, prefix, weights, layer_id):
             num_key_value_heads=config.num_key_value_heads,
         )
     elif config.model_type == "baichuan":
-        base_layer = TensorParallelColumnLinear.load_qkv(
+        return TensorParallelColumnLinear.load_qkv(
             config,
             prefix=f"{prefix}.W_pack",
             weights=weights,
@@ -76,28 +76,28 @@ def load_attention(config, prefix, weights, layer_id):
             num_heads=config.num_attention_heads,
             num_key_value_heads=config.num_key_value_heads,
         )
-    else:
-        # otherwise, load the default attention based on the number of heads
-        base_layer = TensorParallelColumnLinear.load_multi(
-            config,
-            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-            dim=0,
-            weights=weights,
-            bias=bias,
-        )
+
+    # otherwise, load the default attention based on the number of heads
+    base_layer = TensorParallelColumnLinear.load_multi(
+        config,
+        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+        dim=0,
+        weights=weights,
+        bias=bias,
+    )
 
-        head_size = config.hidden_size // config.num_attention_heads
-        return TensorParallelMultiAdapterLinear.load(
-            base_layer,
-            layer_id,
-            ["q_proj", "k_proj", "v_proj"],
-            sizes=[
-                head_size * config.num_attention_heads,
-                head_size * config.num_key_value_heads,
-                head_size * config.num_key_value_heads,
-            ],
-            process_group=weights.process_group,
-        )
+    head_size = config.hidden_size // config.num_attention_heads
+    return TensorParallelMultiAdapterLinear.load(
+        base_layer,
+        layer_id,
+        ["q_proj", "k_proj", "v_proj"],
+        sizes=[
+            head_size * config.num_attention_heads,
+            head_size * config.num_key_value_heads,
+            head_size * config.num_key_value_heads,
+        ],
+        process_group=weights.process_group,
+    )
 
 
 class FlashLlamaAttention(torch.nn.Module):
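Note: after this refactor, phi3 and baichuan return their fused QKV layer immediately, so only the default split q/k/v path is wrapped in `TensorParallelMultiAdapterLinear`; in the old shape the fused branches assigned `base_layer` while the wrapping and `return` sat under the `else`, leaving those branches without a usable return value. A stripped-down sketch of the control flow, using stub loaders rather than the real TGI layer classes:

    def load_attention(model_type: str) -> str:
        # fused-QKV checkpoints return early and skip the LoRA wrapper
        if model_type in ("phi3", "baichuan"):
            return "fused_qkv_layer"

        # default llama-style path: always wrapped for multi-adapter LoRA
        base_layer = "split_qkv_layer"
        return f"multi_adapter({base_layer})"

    assert load_attention("phi3") == "fused_qkv_layer"
    assert load_attention("llama") == "multi_adapter(split_qkv_layer)"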
@@ -240,7 +240,7 @@ class LlamaMLP(nn.Module):
         # Fuse gate and up proj
         bias = getattr(config, "mlp_bias", False)
         if config.model_type == "phi3":
-            gate_up_proj = TensorParallelColumnLinear.load_gate_up(
+            self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
                 config,
                 prefix=f"{prefix}.gate_up_proj",
                 weights=weights,
@@ -255,16 +255,16 @@ class LlamaMLP(nn.Module):
                 bias=bias,
             )
 
-        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
-            gate_up_proj,
-            index,
-            ["gate_proj", "up_proj"],
-            sizes=[
-                config.intermediate_size,
-                config.intermediate_size,
-            ],
-            process_group=weights.process_group,
-        )
+            self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+                gate_up_proj,
+                index,
+                ["gate_proj", "up_proj"],
+                sizes=[
+                    config.intermediate_size,
+                    config.intermediate_size,
+                ],
+                process_group=weights.process_group,
+            )
 
         down_proj = TensorParallelRowLinear.load(
             config,
@@ -273,12 +273,15 @@ class LlamaMLP(nn.Module):
             bias=bias,
         )
 
-        self.down_proj = TensorParallelAdapterRowLinear.load(
-            down_proj,
-            index,
-            "down_proj",
-            process_group=weights.process_group,
-        )
+        if config.model_type == "phi3":
+            self.down_proj = down_proj
+        else:
+            self.down_proj = TensorParallelAdapterRowLinear.load(
+                down_proj,
+                index,
+                "down_proj",
+                process_group=weights.process_group,
+            )
 
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
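Note: together with the `gate_up_proj` change above, phi3's fused MLP projections now bypass the LoRA wrappers entirely; only the split llama-style layout gets adapter hooks. A condensed sketch of the resulting branch, with placeholder objects standing in for the real layers:

    def wire_down_proj(model_type: str, down_proj):
        # fused phi3 checkpoints keep the plain row-parallel layer
        if model_type == "phi3":
            return down_proj
        # everything else is wrapped so LoRA adapters can hook "down_proj"
        return ("adapter_row_linear", down_proj)

    assert wire_down_proj("phi3", "layer") == "layer"
    assert wire_down_proj("llama", "layer") == ("adapter_row_linear", "layer")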
@@ -471,9 +474,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             weights=weights,
         )
 
-    def get_lora_index(self, adapter_id):
-        return self.model.layers[0].self_attn.key_to_index[adapter_id]
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -1,3 +1,7 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/utils/adapter.py
+# License: Apache License Version 2.0, January 2004
+
 import warnings
 from dataclasses import dataclass
 from functools import lru_cache
@@ -20,22 +24,20 @@ if TYPE_CHECKING:
 BASE_MODEL_ADAPTER_ID = "__base_model__"
 
 
+@dataclass
 class AdapterParameters:
-    def __init__(
-        self, adapter_ids, weights, merge_strategy, density, majority_sign_method
-    ):
-        self.adapter_ids = adapter_ids
-        self.weights = weights
-        self.merge_strategy = merge_strategy
-        self.density = density
-        self.majority_sign_method = majority_sign_method
+    adapter_ids: Tuple[str]
+    weights: Tuple[float]
+    merge_strategy: NotImplemented
+    density: float
+    majority_sign_method: NotImplemented
 
 
+@dataclass
 class AdapterSource:
-    def __init__(self, adapter_id: str, model_id: str, revision: str):
-        self.adapter_id = adapter_id
-        self.model_id = model_id
-        self.revision = revision
+    adapter_id: str
+    model_id: str
+    revision: str
 
 
 def load_and_merge_adapters(
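Note: swapping the hand-written `__init__` for `@dataclass` is behavior-preserving: the decorator generates a positional constructor from the field declarations in the same order the old signature used (the unusual `NotImplemented` annotations appear to be placeholders for unsupported merge options). A small demonstration with dummy values:

    from dataclasses import dataclass
    from typing import Tuple

    @dataclass
    class AdapterParameters:
        adapter_ids: Tuple[str]
        weights: Tuple[float]
        merge_strategy: NotImplemented
        density: float
        majority_sign_method: NotImplemented

    # the generated __init__ mirrors the removed hand-written one
    params = AdapterParameters(("my-adapter",), (1.0,), None, 0.0, None)
    print(params.adapter_ids, params.density)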
@@ -65,11 +67,11 @@
     )
 
 
+@dataclass
 class AdapterParametersContainer:
-    def __init__(self, adapter_parameters, adapter_source, adapter_index):
-        self.adapter_parameters = adapter_parameters
-        self.adapter_source = adapter_source
-        self.adapter_index = adapter_index
+    adapter_parameters: AdapterParameters
+    adapter_source: str
+    adapter_index: int
 
     def __hash__(self) -> int:
         return self.adapter_index
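Note: keeping the explicit `__hash__` matters here: a default (non-frozen) dataclass with a generated `__eq__` would otherwise be unhashable, and this container is presumably used as a key on the `lru_cache`d loading path (note the `lru_cache` import above), so hashing by the unique `adapter_index` keeps it cacheable. A sketch of that interaction with a hypothetical cached loader and simplified fields:

    from dataclasses import dataclass
    from functools import lru_cache

    @dataclass
    class AdapterParametersContainer:
        adapter_source: str
        adapter_index: int

        # an explicitly defined __hash__ is preserved by @dataclass
        def __hash__(self) -> int:
            return self.adapter_index

    @lru_cache(maxsize=32)
    def load_module_map(container):  # hypothetical cached loader
        return f"loaded adapter #{container.adapter_index}"

    c = AdapterParametersContainer("hub", 0)
    print(load_module_map(c))  # computed on the first call
    print(load_module_map(c))  # second call hits the cache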
@@ -123,7 +125,7 @@ def check_architectures(
 ):
     try:
         if not adapter_config.base_model_name_or_path:
-            # Avoid execuation latency caused by the network connection retrying for AutoConfig.from_pretrained(None)
+            # Avoid execution latency caused by the network connection retrying for AutoConfig.from_pretrained(None)
             return
 
         expected_config = AutoConfig.from_pretrained(
@@ -1,5 +1,5 @@
 import os
 import json
+from typing import Union
 from loguru import logger
 import torch
-
@@ -45,7 +45,9 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
     tokenizer.save_pretrained(cache_dir)
 
 
-def download_peft(model_id, revision, trust_remote_code):
+def download_peft(
+    model_id: Union[str, os.PathLike], revision: str, trust_remote_code: bool
+):
     torch_dtype = torch.float16
     try:
         _model = AutoPeftModelForCausalLM.from_pretrained(
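Note: the widened annotation documents that `download_peft` accepts a local path as well as a hub id; `os.PathLike` covers `pathlib.Path`. A stubbed illustration of what the hint admits (the body and adapter names below are placeholders, not the real downloader):

    import os
    from pathlib import Path
    from typing import Union

    def download_peft(
        model_id: Union[str, os.PathLike], revision: str, trust_remote_code: bool
    ):
        # placeholder body; the real function fetches and saves the PEFT model
        print(f"would fetch {os.fspath(model_id)} at revision {revision}")

    download_peft("predibase/example-adapter", "main", False)
    download_peft(Path("/data/adapters/example"), "main", False)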
@@ -1,3 +1,7 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/utils/segments.py
+# License: Apache License Version 2.0, January 2004
+
 from typing import List, Tuple, Union
 
 import torch
@@ -1,3 +1,7 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/utils/sgmv.py
+# License: Apache License Version 2.0, January 2004
+
 import os
 import warnings
 from functools import lru_cache