fix: refactors and adjust flash llama lora logic
This commit is contained in:
parent 224455f389
commit 4f1543d3c7
server/text_generation_server/adapters/__init__.py

@@ -1,31 +1,13 @@
-import json
-from pathlib import Path
-from typing import Dict, Optional
-
-from text_generation_server.adapters.config import AdapterConfig
-from text_generation_server.adapters.lora import LoraConfig
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/adapters/__init__.py
+# License: Apache License Version 2.0, January 2004
+
 from text_generation_server.adapters.weights import (
     AdapterBatchData,
     AdapterBatchMetadata,
 )
 
-
-def load_adapter_config(
-    config_path: Optional[Path],
-    adapter_config_path: Optional[Path],
-    api_token: str,
-) -> AdapterConfig:
-    if adapter_config_path is not None and adapter_config_path.exists():
-        return LoraConfig.load(str(adapter_config_path.parent), api_token)
-
-    raise ValueError(
-        f"No valid adapter config file found: "
-        f"tried {adapter_config_path} and {config_path}"
-    )
-
-
 __all__ = [
     "AdapterBatchData",
     "AdapterBatchMetadata",
-    "load_adapter_config",
 ]
server/text_generation_server/adapters/config.py

@@ -1,3 +1,7 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/adapters/config.py
+# License: Apache License Version 2.0, January 2004
+
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
@@ -10,7 +14,10 @@ if TYPE_CHECKING:
     from text_generation_server.models.model import Model
 
 
-ModuleMap = Dict[str, Dict[str, Tuple[torch.Tensor, str]]]
+@dataclass
+class ModuleMap:
+    module_name: str
+    module_weights: Dict[str, Tuple[torch.Tensor, str]]
 
 
 @dataclass
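The bare type alias becomes a small dataclass, so call sites can name the module they are mapping instead of threading anonymous nested dicts around. A minimal sketch of the before/after shapes (the tensor values and key names here are illustrative, not from the diff):

from dataclasses import dataclass
from typing import Dict, Tuple

import torch


@dataclass
class ModuleMap:
    module_name: str
    module_weights: Dict[str, Tuple[torch.Tensor, str]]


# Before: an anonymous nested dict, keyed twice by strings.
old_style = {"q_proj": {"lora_A": (torch.zeros(8, 16), "model.q_proj.lora_A")}}

# After: the outer key becomes a named field and the payload stays explicit.
new_style = ModuleMap(
    module_name="q_proj",
    module_weights={"lora_A": (torch.zeros(8, 16), "model.q_proj.lora_A")},
)
print(new_style.module_name)  # "q_proj"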
@@ -20,7 +27,7 @@ class AdapterConfig(ABC):
     @abstractmethod
     def map_weights_for_model(
         self,
-        adapter_weights: Dict,
+        adapter_weights: Dict[int, AdapterWeights],
         weight_names: Tuple[str],
     ) -> Tuple[ModuleMap, Set[str]]:
         pass
@@ -29,7 +36,7 @@ class AdapterConfig(ABC):
     def load_batched_adapter_weights(
         self,
         model: "Model",
-        module_map: Dict[str, Dict],
+        module_map: ModuleMap,
         layer_type: str,
         unused_weight_names: Set[str],
         dynamic: bool,
server/text_generation_server/adapters/lora.py

@@ -1,3 +1,7 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/adapters/lora.py
+# License: Apache License Version 2.0, January 2004
+
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union
@@ -25,8 +29,6 @@ from text_generation_server.utils.sgmv import (
 if TYPE_CHECKING:
     from text_generation_server.models.model import Model
 
-EMPTY_TENSOR = torch.tensor([])
-
 
 @dataclass
 class LoraConfig(AdapterConfig):
@@ -38,7 +40,7 @@ class LoraConfig(AdapterConfig):
 
     def map_weights_for_model(
         self,
-        adapter_weights: Dict,
+        adapter_weights: Dict[int, AdapterWeights],
        weight_names: Tuple[str],
     ) -> Tuple[ModuleMap, Set[str]]:
         adapter_weight_names = set()
@@ -262,7 +264,7 @@ class BatchLoraWeights(BatchAdapterWeights):
         if not adapter_weights:
             return None
 
-        first_weights = list(adapter_weights.values())[0]
+        first_weights = next(iter(adapter_weights.values()))
         device = first_weights.weights_a.device
         segment_indices = meta.segment_indices
 
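Taking the first value with next(iter(...)) avoids materializing every adapter's weights into a throwaway list just to read one element. A quick illustration of the equivalence (placeholder values, not the real weight objects):

adapter_weights = {0: "weights_for_adapter_0", 1: "weights_for_adapter_1"}

# Old: copies all values into a list, then indexes the first.
first = list(adapter_weights.values())[0]

# New: pulls only the first value; no intermediate list is built.
first = next(iter(adapter_weights.values()))
assert first == "weights_for_adapter_0"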
@@ -293,7 +295,7 @@ class BatchLoraWeights(BatchAdapterWeights):
                 (
                     adapter_weights[idx].weights_a.data_ptr()
                     if idx in adapter_weights
-                    else EMPTY_TENSOR.data_ptr()
+                    else 0
                 )
                 for idx in segment_indices
             ],
@@ -305,7 +307,7 @@ class BatchLoraWeights(BatchAdapterWeights):
                 (
                     adapter_weights[idx].weights_b.data_ptr()
                     if idx in adapter_weights
-                    else EMPTY_TENSOR.data_ptr()
+                    else 0
                 )
                 for idx in segment_indices
             ],
@@ -319,7 +321,7 @@ class BatchLoraWeights(BatchAdapterWeights):
                 (
                     adapter_weights[idx].weights_a_t.data_ptr()
                     if idx in adapter_weights
-                    else EMPTY_TENSOR.data_ptr()
+                    else 0
                 )
                 for idx in segment_indices
             ],
@@ -331,7 +333,7 @@ class BatchLoraWeights(BatchAdapterWeights):
                 (
                     adapter_weights[idx].weights_b_t.data_ptr()
                     if idx in adapter_weights
-                    else EMPTY_TENSOR.data_ptr()
+                    else 0
                 )
                 for idx in segment_indices
             ],
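With the module-level EMPTY_TENSOR constant gone, segments that have no adapter now encode a literal null pointer (0) in the per-segment pointer arrays, instead of the data pointer of a shared empty tensor; downstream dispatch can then test for 0 directly rather than comparing against a sentinel tensor's address. A minimal sketch of how such a pointer list is assembled (variable names follow the diff; the surrounding torch.tensor construction is assumed from context):

import torch

# Per-segment adapter weights keyed by adapter index; some segments
# (index 1 here) run the base model and have no adapter at all.
adapter_weights = {0: torch.randn(16, 8), 2: torch.randn(16, 8)}
segment_indices = [0, 1, 2]

# 0 serves as a null device pointer: a natural "no adapter here" marker.
lora_a_ptr = torch.tensor(
    [
        (adapter_weights[idx].data_ptr() if idx in adapter_weights else 0)
        for idx in segment_indices
    ],
    dtype=torch.int64,
)
assert lora_a_ptr[1].item() == 0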
server/text_generation_server/adapters/weights.py

@@ -1,3 +1,7 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/adapters/weights.py
+# License: Apache License Version 2.0, January 2004
+
 from abc import ABC, abstractclassmethod
 from collections import defaultdict
 from dataclasses import dataclass
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

@@ -59,7 +59,7 @@ def load_attention(config, prefix, weights, layer_id):
 
     # if specific model type, load the correct attention
     if config.model_type == "phi3":
-        base_layer = TensorParallelColumnLinear.load_qkv(
+        return TensorParallelColumnLinear.load_qkv(
             config,
             prefix=f"{prefix}.qkv_proj",
             weights=weights,
@@ -68,7 +68,7 @@ def load_attention(config, prefix, weights, layer_id):
             num_key_value_heads=config.num_key_value_heads,
         )
     elif config.model_type == "baichuan":
-        base_layer = TensorParallelColumnLinear.load_qkv(
+        return TensorParallelColumnLinear.load_qkv(
             config,
             prefix=f"{prefix}.W_pack",
             weights=weights,
@@ -76,28 +76,28 @@ def load_attention(config, prefix, weights, layer_id):
             num_heads=config.num_attention_heads,
             num_key_value_heads=config.num_key_value_heads,
         )
-
-    # otherwise, load the default attention based on the number of heads
-    base_layer = TensorParallelColumnLinear.load_multi(
-        config,
-        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-        dim=0,
-        weights=weights,
-        bias=bias,
-    )
-
-    head_size = config.hidden_size // config.num_attention_heads
-    return TensorParallelMultiAdapterLinear.load(
-        base_layer,
-        layer_id,
-        ["q_proj", "k_proj", "v_proj"],
-        sizes=[
-            head_size * config.num_attention_heads,
-            head_size * config.num_key_value_heads,
-            head_size * config.num_key_value_heads,
-        ],
-        process_group=weights.process_group,
-    )
+    else:
+        # otherwise, load the default attention based on the number of heads
+        base_layer = TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            dim=0,
+            weights=weights,
+            bias=bias,
+        )
+
+        head_size = config.hidden_size // config.num_attention_heads
+        return TensorParallelMultiAdapterLinear.load(
+            base_layer,
+            layer_id,
+            ["q_proj", "k_proj", "v_proj"],
+            sizes=[
+                head_size * config.num_attention_heads,
+                head_size * config.num_key_value_heads,
+                head_size * config.num_key_value_heads,
+            ],
+            process_group=weights.process_group,
+        )
 
 
 class FlashLlamaAttention(torch.nn.Module):
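The refactor replaces the shared base_layer variable with early returns: phi3 and baichuan, whose checkpoints store a single fused QKV weight, now return the plain column-parallel layer directly, and only the default split-q/k/v path is wrapped in TensorParallelMultiAdapterLinear for LoRA. A minimal sketch of the resulting control flow, with stub classes standing in for the real TGI layers:

class ColumnLinear:
    """Stub for TensorParallelColumnLinear."""

class MultiAdapterLinear:
    """Stub for TensorParallelMultiAdapterLinear."""
    def __init__(self, base_layer):
        self.base_layer = base_layer

def load_attention_sketch(model_type: str):
    if model_type in ("phi3", "baichuan"):
        # Fused-QKV checkpoints: return the plain layer, no LoRA wrapper.
        return ColumnLinear()
    # Default split q/k/v path: wrap the base layer for per-adapter LoRA.
    return MultiAdapterLinear(base_layer=ColumnLinear())

assert isinstance(load_attention_sketch("phi3"), ColumnLinear)
assert isinstance(load_attention_sketch("llama"), MultiAdapterLinear)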
@@ -240,7 +240,7 @@ class LlamaMLP(nn.Module):
         # Fuse gate and up proj
         bias = getattr(config, "mlp_bias", False)
         if config.model_type == "phi3":
-            gate_up_proj = TensorParallelColumnLinear.load_gate_up(
+            self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
                 config,
                 prefix=f"{prefix}.gate_up_proj",
                 weights=weights,
@@ -255,16 +255,16 @@ class LlamaMLP(nn.Module):
                 bias=bias,
             )
 
         self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
             gate_up_proj,
             index,
             ["gate_proj", "up_proj"],
             sizes=[
                 config.intermediate_size,
                 config.intermediate_size,
             ],
             process_group=weights.process_group,
         )
 
         down_proj = TensorParallelRowLinear.load(
             config,
@@ -273,12 +273,15 @@ class LlamaMLP(nn.Module):
             bias=bias,
         )
 
-        self.down_proj = TensorParallelAdapterRowLinear.load(
-            down_proj,
-            index,
-            "down_proj",
-            process_group=weights.process_group,
-        )
+        if config.model_type == "phi3":
+            self.down_proj = down_proj
+        else:
+            self.down_proj = TensorParallelAdapterRowLinear.load(
+                down_proj,
+                index,
+                "down_proj",
+                process_group=weights.process_group,
+            )
 
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
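Together with the gate_up_proj change above, phi3's fused MLP projections now bypass the LoRA adapter wrapper: down_proj stays a plain row-parallel layer for phi3 and is wrapped in TensorParallelAdapterRowLinear for every other model type. A condensed sketch of the branch, again with a stub in place of the real layer:

class AdapterRowLinear:
    """Stub for TensorParallelAdapterRowLinear."""
    def __init__(self, base_layer, index, name):
        self.base_layer, self.index, self.name = base_layer, index, name

def wrap_down_proj(model_type: str, down_proj, index: int):
    # phi3 keeps the plain row-parallel layer; others get the LoRA wrapper.
    if model_type == "phi3":
        return down_proj
    return AdapterRowLinear(down_proj, index, "down_proj")

plain_layer = object()
assert wrap_down_proj("phi3", plain_layer, 0) is plain_layer
assert isinstance(wrap_down_proj("llama", plain_layer, 0), AdapterRowLinear)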
@@ -471,9 +474,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             weights=weights,
         )
 
-    def get_lora_index(self, adapter_id):
-        return self.model.layers[0].self_attn.key_to_index[adapter_id]
-
     def forward(
         self,
         input_ids: torch.Tensor,
server/text_generation_server/utils/adapter.py

@@ -1,3 +1,7 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/utils/adapter.py
+# License: Apache License Version 2.0, January 2004
+
 import warnings
 from dataclasses import dataclass
 from functools import lru_cache
@@ -20,22 +24,20 @@ if TYPE_CHECKING:
 BASE_MODEL_ADAPTER_ID = "__base_model__"
 
 
+@dataclass
 class AdapterParameters:
-    def __init__(
-        self, adapter_ids, weights, merge_strategy, density, majority_sign_method
-    ):
-        self.adapter_ids = adapter_ids
-        self.weights = weights
-        self.merge_strategy = merge_strategy
-        self.density = density
-        self.majority_sign_method = majority_sign_method
+    adapter_ids: Tuple[str]
+    weights: Tuple[float]
+    merge_strategy: NotImplemented
+    density: float
+    majority_sign_method: NotImplemented
 
 
+@dataclass
 class AdapterSource:
-    def __init__(self, adapter_id: str, model_id: str, revision: str):
-        self.adapter_id = adapter_id
-        self.model_id = model_id
-        self.revision = revision
+    adapter_id: str
+    model_id: str
+    revision: str
 
 
 def load_and_merge_adapters(
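Converting these hand-rolled classes to @dataclass removes the __init__ boilerplate while keeping construction identical, and generated __repr__/__eq__ come for free; the NotImplemented annotations appear to mark merge-related fields whose types are not wired up yet. A small usage sketch (field values are illustrative):

from dataclasses import dataclass


@dataclass
class AdapterSource:
    adapter_id: str
    model_id: str
    revision: str


src = AdapterSource(adapter_id="my-lora", model_id="base-model", revision="main")
print(src)  # AdapterSource(adapter_id='my-lora', model_id='base-model', revision='main')
assert src == AdapterSource("my-lora", "base-model", "main")  # generated __eq__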
@@ -65,11 +67,11 @@ def load_and_merge_adapters
     )
 
 
+@dataclass
 class AdapterParametersContainer:
-    def __init__(self, adapter_parameters, adapter_source, adapter_index):
-        self.adapter_parameters = adapter_parameters
-        self.adapter_source = adapter_source
-        self.adapter_index = adapter_index
+    adapter_parameters: AdapterParameters
+    adapter_source: str
+    adapter_index: int
 
     def __hash__(self) -> int:
         return self.adapter_index
@@ -123,7 +125,7 @@ def check_architectures(
 ):
     try:
         if not adapter_config.base_model_name_or_path:
-            # Avoid execuation latency caused by the network connection retrying for AutoConfig.from_pretrained(None)
+            # Avoid execution latency caused by the network connection retrying for AutoConfig.from_pretrained(None)
             return
 
         expected_config = AutoConfig.from_pretrained(
server/text_generation_server/utils/peft.py

@@ -1,5 +1,5 @@
 import os
-import json
+from typing import Union
 from loguru import logger
 import torch
 
@@ -45,7 +45,9 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
     tokenizer.save_pretrained(cache_dir)
 
 
-def download_peft(model_id, revision, trust_remote_code):
+def download_peft(
+    model_id: Union[str, os.PathLike], revision: str, trust_remote_code: bool
+):
     torch_dtype = torch.float16
     try:
         _model = AutoPeftModelForCausalLM.from_pretrained(
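The typed signature swaps the (presumably unused) json import for typing.Union and documents that a model identifier may be a hub id string or a local path-like object. A tiny sketch of what the widened parameter type accepts (the function body is illustrative only):

import os
from pathlib import Path
from typing import Union


def describe_model(model_id: Union[str, os.PathLike], revision: str) -> str:
    # os.fspath normalizes both plain strings and path-like objects.
    return f"{os.fspath(model_id)}@{revision}"


assert describe_model("org/model", "main") == "org/model@main"
assert describe_model(Path("/models/local"), "main") == "/models/local@main"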
server/text_generation_server/utils/segments.py

@@ -1,3 +1,7 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/utils/segments.py
+# License: Apache License Version 2.0, January 2004
+
 from typing import List, Tuple, Union
 
 import torch
server/text_generation_server/utils/sgmv.py

@@ -1,3 +1,7 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/utils/sgmv.py
+# License: Apache License Version 2.0, January 2004
+
 import os
 import warnings
 from functools import lru_cache