fix: refactors and helpful comments
parent a07b612989
commit a7556ba800
@@ -1,12 +1,13 @@
 import math
 import os
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple, List
 
 import torch
 import torch.distributed
 from accelerate import init_empty_weights
 from torch import nn
 from torch.nn import functional as F
+from torch.distributed import ProcessGroup
 
 from text_generation_server.utils.sgmv import (
     add_lora_a_bgmv,
@@ -26,7 +27,9 @@ if TYPE_CHECKING:
 
 
 class LoraLinear(nn.Module):
-    def __init__(self, base_layer, layer_id, process_group):
+    def __init__(
+        self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup
+    ):
         super().__init__()
         self.base_layer = base_layer
         self.layer_id = layer_id
@@ -49,6 +52,18 @@ class LoraLinear(nn.Module):
         )
 
         if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
+            # In tensor-parallel configurations, each GPU processes a specific segment of the output.
+            # The 'result' tensor represents the full output, which can vary in size based on
+            # the layer type (e.g., attention vs. feed-forward layers). We define the current
+            # segment using start_idx and end_idx. If the segment size doesn't match this GPU's
+            # slice of 'result', we create a zero tensor of the correct size for LoRA computation.
+            # This approach ensures accurate LoRA application across various layer sizes and
+            # configurations, adapting to different model architectures and parallelization strategies.
+            #
+            # Example scenarios where this is necessary:
+            # 1. The adapter's size doesn't evenly divide across GPUs.
+            # 2. We're processing the last segment which might be smaller.
+            # 3. Different projection layers (q, k, v) have different sizes.
             if end_idx - start_idx != result.shape[1]:
                 proj = torch.zeros_like(result[:, start_idx:end_idx])
             else:
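Note: the comment block added in this hunk explains the zero-tensor fallback. Below is a minimal standalone sketch of that shape check, using made-up tensor sizes rather than values from the real forward pass (the result width, start_idx, and end_idx here are illustrative assumptions):

import torch

# Hypothetical shapes: this rank's slice of the layer output is 128 columns wide,
# while the segment assigned to the current layer is only 96 columns.
result = torch.zeros(4, 128)
start_idx, end_idx = 0, 96

if end_idx - start_idx != result.shape[1]:
    # Segment is narrower than the local slice: compute the LoRA contribution
    # into a correctly sized zero buffer instead of writing into result directly.
    proj = torch.zeros_like(result[:, start_idx:end_idx])
else:
    # Segment covers the whole local slice: accumulate into result in place.
    proj = result

print(proj.shape)  # torch.Size([4, 96]) in this mismatched-segment case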
@@ -149,13 +164,27 @@ class LoraLinear(nn.Module):
 
 
 class TensorParallelMultiAdapterLinear(LoraLinear):
-    def __init__(self, base_layer, layer_id, layer_names, sizes, process_group):
+    def __init__(
+        self,
+        base_layer: nn.Module,
+        layer_id: int,
+        layer_names: List[str],
+        sizes: List[int],
+        process_group: ProcessGroup,
+    ):
         super().__init__(base_layer, layer_id, process_group)
         self.layer_names = layer_names
         self.sizes = sizes
 
     @classmethod
-    def load(cls, base_layer, layer_id, layer_names, sizes, process_group):
+    def load(
+        cls,
+        base_layer: nn.Module,
+        layer_id: int,
+        layer_names: List[str],
+        sizes: List[int],
+        process_group: ProcessGroup,
+    ):
         return TensorParallelMultiAdapterLinear(
             base_layer, layer_id, layer_names, sizes, process_group
         )
@@ -178,7 +207,12 @@ class TensorParallelMultiAdapterLinear(LoraLinear):
         offset = 0
         for i, layer_name in enumerate(self.layer_names):
             start_idx = offset // self.process_group.size()
+            # The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
+            # projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
+            # ensures correct slicing of the result tensor, accommodating variations like grouped-query
+            # attention where k_proj and v_proj differ from q_proj. This allows precise application of
+            # LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
+            # different projection sizes across layers and model architectures.
             if self.sizes is not None:
                 offset += self.sizes[i]
                 end_idx = offset // self.process_group.size()
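Note: the comment added here documents how 'sizes' drives the slicing. Below is a small sketch of the offset arithmetic with invented projection widths for a grouped-query-attention layout and an assumed tensor-parallel world size of 2 (none of these numbers come from the actual model):

# Hypothetical fused QKV layout: q_proj is 4096 columns wide, k_proj and
# v_proj are 1024 each, and the output is sharded across 2 ranks.
layer_names = ["q_proj", "k_proj", "v_proj"]
sizes = [4096, 1024, 1024]
world_size = 2

offset = 0
for layer_name, size in zip(layer_names, sizes):
    start_idx = offset // world_size
    offset += size
    end_idx = offset // world_size
    # Each rank handles size // world_size columns of this projection.
    print(f"{layer_name}: local columns [{start_idx}, {end_idx})")

# Expected output:
# q_proj: local columns [0, 2048)
# k_proj: local columns [2048, 2560)
# v_proj: local columns [2560, 3072)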
@@ -292,31 +292,3 @@ class Model(ABC):
         ]
 
         return weights_a, weights_b
-
-    def offload_adapter(
-        self,
-        adapter_parameters: AdapterParameters,
-        adapter_source: AdapterSource,
-        adapter_index: int,
-    ):
-        """Offloads the adapter weights from GPU to CPU or disk."""
-        if adapter_index not in self.loaded_adapters:
-            # Adapter already offloaded
-            return
-
-        if not self.supports_adapter_loading:
-            raise ValueError("This model does not support adapter loading.")
-
-        if not self.dynamic_adapter_loading_enabled:
-            raise ValueError(
-                f"This model was initialized with the adapter {self.static_adapter_id} "
-                f"and therefore does not support dynamic adapter loading. "
-                f"Please initialize a new model instance from the base model in "
-                f"order to use the dynamic adapter loading feature."
-            )
-
-        for layer_name in self.adapter_layers:
-            if layer_name in self.layer_to_adapter_weights:
-                self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index)
-
-        self.loaded_adapters.remove(adapter_index)
@@ -1,3 +1,7 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/utils/merges/strategies.py
+# License: Apache License Version 2.0, January 2004
+
 import copy
 from abc import ABC
 from collections import defaultdict