fix: refactors and helpful comments

Author: drbh
Date:   2024-06-24 13:39:56 +00:00
Parent: a07b612989
Commit: a7556ba800
3 changed files with 43 additions and 33 deletions

Changed file 1 of 3

@@ -1,12 +1,13 @@
 import math
 import os
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple, List
 
 import torch
 import torch.distributed
 from accelerate import init_empty_weights
 from torch import nn
 from torch.nn import functional as F
+from torch.distributed import ProcessGroup
 
 from text_generation_server.utils.sgmv import (
     add_lora_a_bgmv,
@@ -26,7 +27,9 @@ if TYPE_CHECKING:
 
 
 class LoraLinear(nn.Module):
-    def __init__(self, base_layer, layer_id, process_group):
+    def __init__(
+        self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup
+    ):
         super().__init__()
         self.base_layer = base_layer
         self.layer_id = layer_id
@@ -49,6 +52,18 @@ class LoraLinear(nn.Module):
         )
 
         if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
+            # In tensor-parallel configurations, each GPU processes a specific segment of the output.
+            # The 'result' tensor represents the full output, which can vary in size based on
+            # the layer type (e.g., attention vs. feed-forward layers). We define the current
+            # segment using start_idx and end_idx. If the segment size doesn't match this GPU's
+            # slice of 'result', we create a zero tensor of the correct size for LoRA computation.
+            # This approach ensures accurate LoRA application across various layer sizes and
+            # configurations, adapting to different model architectures and parallelization strategies.
+            #
+            # Example scenarios where this is necessary:
+            # 1. The adapter's size doesn't evenly divide across GPUs.
+            # 2. We're processing the last segment which might be smaller.
+            # 3. Different projection layers (q, k, v) have different sizes.
             if end_idx - start_idx != result.shape[1]:
                 proj = torch.zeros_like(result[:, start_idx:end_idx])
             else:
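The comment block added in this hunk boils down to a small piece of index arithmetic. Here is a minimal standalone sketch of it; the helper name (slice_for_rank) and the tensor shapes are illustrative assumptions, and only the if/else branch mirrors the code in the diff:

import torch


def slice_for_rank(result: torch.Tensor, start_idx: int, end_idx: int) -> torch.Tensor:
    # `result` is this rank's shard of the layer output: (tokens, local_out_features).
    # `start_idx`/`end_idx` are the adapter segment bounds assigned to this rank.
    if end_idx - start_idx != result.shape[1]:
        # The segment width differs from the shard width (uneven split, last
        # segment, or a differently sized projection), so fall back to a zero
        # buffer of the segment's width as the base for the LoRA contribution.
        return torch.zeros_like(result[:, start_idx:end_idx])
    return result[:, start_idx:end_idx]


# Hypothetical shapes: a shard holding 512 output features, with an adapter
# segment that covers only the first 384 columns of it.
result = torch.randn(4, 512)
proj = slice_for_rank(result, 0, 384)  # zeros of shape (4, 384)
full = slice_for_rank(result, 0, 512)  # the shard itself, shape (4, 512)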
@@ -149,13 +164,27 @@ class LoraLinear(nn.Module):
 
 
 class TensorParallelMultiAdapterLinear(LoraLinear):
-    def __init__(self, base_layer, layer_id, layer_names, sizes, process_group):
+    def __init__(
+        self,
+        base_layer: nn.Module,
+        layer_id: int,
+        layer_names: List[str],
+        sizes: List[int],
+        process_group: ProcessGroup,
+    ):
         super().__init__(base_layer, layer_id, process_group)
         self.layer_names = layer_names
         self.sizes = sizes
 
     @classmethod
-    def load(cls, base_layer, layer_id, layer_names, sizes, process_group):
+    def load(
+        cls,
+        base_layer: nn.Module,
+        layer_id: int,
+        layer_names: List[str],
+        sizes: List[int],
+        process_group: ProcessGroup,
+    ):
         return TensorParallelMultiAdapterLinear(
             base_layer, layer_id, layer_names, sizes, process_group
         )
@@ -178,7 +207,12 @@ class TensorParallelMultiAdapterLinear(LoraLinear):
         offset = 0
         for i, layer_name in enumerate(self.layer_names):
             start_idx = offset // self.process_group.size()
+            # The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
+            # projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
+            # ensures correct slicing of the result tensor, accommodating variations like grouped-query
+            # attention where k_proj and v_proj differ from q_proj. This allows precise application of
+            # LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
+            # different projection sizes across layers and model architectures.
             if self.sizes is not None:
                 offset += self.sizes[i]
                 end_idx = offset // self.process_group.size()
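The 'sizes' comment above amounts to a running-offset computation over the concatenated projections. A rough sketch of how the per-rank bounds fall out is below; the function name (segment_bounds) and the grouped-query sizes (4096/1024/1024 over 2 ranks) are illustrative assumptions, not values taken from this change:

from typing import List, Tuple


def segment_bounds(sizes: List[int], world_size: int) -> List[Tuple[int, int]]:
    # Turn the full output sizes of consecutive projections (e.g. q_proj,
    # k_proj, v_proj) into per-rank column ranges, mirroring how the loop
    # above derives start_idx/end_idx from a running offset.
    bounds = []
    offset = 0
    for size in sizes:
        start_idx = offset // world_size
        offset += size
        end_idx = offset // world_size
        bounds.append((start_idx, end_idx))
    return bounds


# Hypothetical grouped-query layout: q=4096, k=1024, v=1024, sharded over 2 ranks.
print(segment_bounds([4096, 1024, 1024], world_size=2))
# [(0, 2048), (2048, 2560), (2560, 3072)]  (per-rank column ranges for q, k, v)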

Changed file 2 of 3

@@ -292,31 +292,3 @@ class Model(ABC):
         ]
         return weights_a, weights_b
 
-    def offload_adapter(
-        self,
-        adapter_parameters: AdapterParameters,
-        adapter_source: AdapterSource,
-        adapter_index: int,
-    ):
-        """Offloads the adapter weights from GPU to CPU or disk."""
-        if adapter_index not in self.loaded_adapters:
-            # Adapter already offloaded
-            return
-        if not self.supports_adapter_loading:
-            raise ValueError("This model does not support adapter loading.")
-        if not self.dynamic_adapter_loading_enabled:
-            raise ValueError(
-                f"This model was initialized with the adapter {self.static_adapter_id} "
-                f"and therefore does not support dynamic adapter loading. "
-                f"Please initialize a new model instance from the base model in "
-                f"order to use the dynamic adapter loading feature."
-            )
-        for layer_name in self.adapter_layers:
-            if layer_name in self.layer_to_adapter_weights:
-                self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index)
-        self.loaded_adapters.remove(adapter_index)

Changed file 3 of 3

@@ -1,3 +1,7 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/utils/merges/strategies.py
+# License: Apache License Version 2.0, January 2004
+
 import copy
 from abc import ABC
 from collections import defaultdict