feat: refactor lora linear and remove adapter layers
This commit is contained in:
parent
da82c63a4f
commit
82fc879e17
|
@ -208,7 +208,7 @@ class LoraWeights(AdapterWeights):
|
|||
for layer_id in range(nlayers):
|
||||
key = (layer_id, layer_type)
|
||||
weight_name, layer = model.target_to_layer[key]
|
||||
base_weight = layer.base_layer.linear.weight
|
||||
base_weight = layer.linear.weight
|
||||
base_device = base_weight.device
|
||||
|
||||
if weight_name not in module_map:
|
||||
|
|
|
@ -13,8 +13,13 @@ from text_generation_server.layers.speculative import SpeculativeHead
|
|||
from text_generation_server.layers.layernorm import load_layer_norm
|
||||
from text_generation_server.layers.conv import load_conv2d
|
||||
|
||||
from text_generation_server.layers.lora import (
|
||||
LoraLinear,
|
||||
TensorParallelMultiAdapterLinear,
|
||||
TensorParallelAdapterRowLinear,
|
||||
)
|
||||
__all__ = [
|
||||
"TensorParallelColumnLinear",
|
||||
"TensorParallelRowLinear",
|
||||
"TensorParallelEmbedding",
|
||||
"get_linear",
|
||||
"FastLinear",
|
||||
"SpeculativeHead",
|
||||
"load_layer_norm",
|
||||
"load_conv2d",
|
||||
]
|
||||
|
|
|
@ -1,14 +1,11 @@
|
|||
import math
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, List
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
from accelerate import init_empty_weights
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from text_generation_server.adapters import AdapterBatchData
|
||||
from text_generation_server.adapters.lora import BatchLoraWeights
|
||||
from text_generation_server.utils.sgmv import (
|
||||
add_lora_a_bgmv,
|
||||
add_lora_b_bgmv,
|
||||
|
@ -18,269 +15,160 @@ from text_generation_server.utils.sgmv import (
|
|||
orient_for_rank,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from text_generation_server.adapters import AdapterBatchData
|
||||
from text_generation_server.adapters.lora import BatchLoraWeights
|
||||
|
||||
|
||||
class LoraLinear(nn.Module):
|
||||
def __init__(
|
||||
self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup
|
||||
):
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
self.layer_id = layer_id
|
||||
self.process_group = process_group
|
||||
|
||||
def forward_layer_type(
|
||||
self,
|
||||
result: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
adapter_data: "AdapterBatchData",
|
||||
layer_type: str,
|
||||
start_idx: int,
|
||||
end_idx: int,
|
||||
) -> torch.Tensor:
|
||||
if adapter_data is None:
|
||||
return result
|
||||
data = adapter_data.data.get(layer_type)
|
||||
data: Optional["BatchLoraWeights"] = (
|
||||
data.get("lora") if data is not None else None
|
||||
)
|
||||
|
||||
if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
|
||||
# In tensor-parallel configurations, each GPU processes a specific segment of the output.
|
||||
# The 'result' tensor represents the full output, which can vary in size based on
|
||||
# the layer type (e.g., attention vs. feed-forward layers). We define the current
|
||||
# segment using start_idx and end_idx. If the segment size doesn't match this GPU's
|
||||
# slice of 'result', we create a zero tensor of the correct size for LoRA computation.
|
||||
# This approach ensures accurate LoRA application across various layer sizes and
|
||||
# configurations, adapting to different model architectures and parallelization strategies.
|
||||
#
|
||||
# Example scenarios where this is necessary:
|
||||
# 1. The adapter's size doesn't evenly divide across GPUs.
|
||||
# 2. We're processing the last segment which might be smaller.
|
||||
# 3. Different projection layers (q, k, v) have different sizes.
|
||||
if end_idx - start_idx != result.shape[1]:
|
||||
proj = torch.zeros_like(result[:, start_idx:end_idx])
|
||||
else:
|
||||
proj = result
|
||||
|
||||
for r, rank_segments in data.rank_data.items():
|
||||
lora_a_ptr = rank_segments.lora_a_ptr
|
||||
lora_b_ptr = rank_segments.lora_b_ptr
|
||||
|
||||
if lora_a_ptr is None or lora_b_ptr is None:
|
||||
raise ValueError("LoRA data is missing")
|
||||
|
||||
if data.use_sgmv:
|
||||
# Use SGMV for prefill
|
||||
v = lora_a_sgmv_cutlass(
|
||||
input,
|
||||
rank_segments.tmp_shrink,
|
||||
lora_a_ptr,
|
||||
rank_segments.segment_starts,
|
||||
rank_segments.segment_ends,
|
||||
self.layer_id,
|
||||
r,
|
||||
)
|
||||
|
||||
if self.process_group.size() > 1:
|
||||
v = self.collect_lora_a(v)
|
||||
|
||||
lora_b_sgmv_cutlass(
|
||||
proj,
|
||||
v,
|
||||
rank_segments.tmp_expand,
|
||||
lora_b_ptr,
|
||||
rank_segments.segment_starts,
|
||||
rank_segments.segment_ends,
|
||||
self.layer_id,
|
||||
)
|
||||
else:
|
||||
# Use BGMV for decode
|
||||
v = torch.zeros(
|
||||
(input.size(0), r), dtype=input.dtype, device=input.device
|
||||
)
|
||||
# TODO: error with [-1, 0], but not [0, -1]
|
||||
add_lora_a_bgmv(
|
||||
v,
|
||||
input,
|
||||
lora_a_ptr,
|
||||
rank_segments.indices,
|
||||
self.layer_id,
|
||||
)
|
||||
|
||||
if self.process_group.size() > 1:
|
||||
v = self.collect_lora_a(v)
|
||||
|
||||
add_lora_b_bgmv(
|
||||
proj,
|
||||
v,
|
||||
lora_b_ptr,
|
||||
rank_segments.indices,
|
||||
self.layer_id,
|
||||
)
|
||||
|
||||
if end_idx - start_idx != result.shape[1]:
|
||||
result[:, start_idx:end_idx] += proj
|
||||
else:
|
||||
for adapter_index in adapter_data.meta.adapter_set:
|
||||
if data is not None and data.has_adapter(adapter_index):
|
||||
adapter_mask = (
|
||||
(adapter_data.meta.adapter_indices == adapter_index)
|
||||
.to(input.dtype)
|
||||
.view(-1, 1)
|
||||
)
|
||||
layer_result = self.forward_lora(
|
||||
input, data, adapter_index, adapter_mask
|
||||
)
|
||||
result[:, start_idx:end_idx] += layer_result
|
||||
|
||||
return result
|
||||
|
||||
def forward_lora(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
data: "BatchLoraWeights",
|
||||
adapter_index: int,
|
||||
adapter_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
|
||||
lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
|
||||
|
||||
lora_a = orient_for_rank(lora_a, lora_b.size(0))
|
||||
|
||||
a_out = input @ lora_a
|
||||
if self.process_group.size() > 1:
|
||||
a_out = self.collect_lora_a(a_out)
|
||||
|
||||
result = (a_out @ lora_b) * adapter_mask
|
||||
return result
|
||||
|
||||
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError("Implemented in subclasses")
|
||||
|
||||
|
||||
class TensorParallelMultiAdapterLinear(LoraLinear):
|
||||
def __init__(
|
||||
self,
|
||||
base_layer: nn.Module,
|
||||
layer_id: int,
|
||||
layer_names: List[str],
|
||||
sizes: List[int],
|
||||
process_group: ProcessGroup,
|
||||
):
|
||||
super().__init__(base_layer, layer_id, process_group)
|
||||
self.layer_names = layer_names
|
||||
self.sizes = sizes
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
base_layer: nn.Module,
|
||||
layer_id: int,
|
||||
layer_names: List[str],
|
||||
sizes: List[int],
|
||||
process_group: ProcessGroup,
|
||||
):
|
||||
return TensorParallelMultiAdapterLinear(
|
||||
base_layer, layer_id, layer_names, sizes, process_group
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, input: torch.Tensor, adapter_data: "AdapterBatchData"
|
||||
) -> torch.Tensor:
|
||||
result = self.base_layer(input)
|
||||
|
||||
# noop if no layer names are provided (e.g. for models without adapters)
|
||||
if self.layer_names is None:
|
||||
return result
|
||||
|
||||
# handle models like Bloom that have inputs of shape
|
||||
# (batch_size, sequence_length, hidden_size)
|
||||
# we need to reshape them to (batch_size * sequence_length, hidden_size)
|
||||
# for the LoRA computation, then reshape back
|
||||
prev_shape = result.shape
|
||||
is_3d = len(input.shape) >= 3
|
||||
if is_3d:
|
||||
input = input.reshape(-1, input.shape[-1])
|
||||
result = result.reshape(-1, result.shape[-1])
|
||||
|
||||
offset = 0
|
||||
for i, layer_name in enumerate(self.layer_names):
|
||||
start_idx = offset // self.process_group.size()
|
||||
# The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
|
||||
# projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
|
||||
# ensures correct slicing of the result tensor, accommodating variations like grouped-query
|
||||
# attention where k_proj and v_proj differ from q_proj. This allows precise application of
|
||||
# LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
|
||||
# different projection sizes across layers and model architectures.
|
||||
if self.sizes is not None:
|
||||
offset += self.sizes[i]
|
||||
end_idx = offset // self.process_group.size()
|
||||
else:
|
||||
end_idx = result.shape[1]
|
||||
|
||||
result = self.forward_layer_type(
|
||||
result, input, adapter_data, layer_name, start_idx, end_idx
|
||||
)
|
||||
|
||||
if is_3d:
|
||||
result = result.reshape(prev_shape)
|
||||
|
||||
return result
|
||||
|
||||
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
|
||||
def gather_lora_weights(
|
||||
process_group: ProcessGroup,
|
||||
weights: torch.Tensor,
|
||||
use_all_gather: bool = False,
|
||||
) -> BatchLoraWeights:
|
||||
if use_all_gather:
|
||||
# Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise.
|
||||
# We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks.
|
||||
#
|
||||
# TODO(travis): this is not very efficient as we do an all-gather for every adapter,
|
||||
# instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
|
||||
# rank, compute `a_out` on each, and then slice them into the buffer as shown here:
|
||||
# https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
|
||||
# TODO: this is not very efficient as we do an all-gather for every adapter,
|
||||
# instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
|
||||
# rank, compute `a_out` on each, and then slice them into the buffer as shown here:
|
||||
# https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
|
||||
gathered_tensors = [
|
||||
torch.empty_like(a_out) for _ in range(self.process_group.size())
|
||||
torch.empty_like(weights) for _ in range(process_group.size())
|
||||
]
|
||||
torch.distributed.all_gather(gathered_tensors, a_out)
|
||||
torch.distributed.all_gather(gathered_tensors, weights, group=process_group)
|
||||
return torch.cat(gathered_tensors, dim=1)
|
||||
else:
|
||||
torch.distributed.all_reduce(weights, group=process_group)
|
||||
return weights
|
||||
|
||||
|
||||
class TensorParallelAdapterRowLinear(LoraLinear):
|
||||
def __init__(self, base_layer, layer_id, layer_name, process_group):
|
||||
super().__init__(base_layer, layer_id, process_group)
|
||||
self.layer_name = layer_name
|
||||
|
||||
@classmethod
|
||||
def load(cls, base_layer, layer_id, layer_name, process_group):
|
||||
return cls(base_layer, layer_id, layer_name, process_group)
|
||||
|
||||
def forward(
|
||||
self, input: torch.Tensor, adapter_data: "AdapterBatchData"
|
||||
) -> torch.Tensor:
|
||||
result = self.base_layer(input)
|
||||
|
||||
if self.layer_name is None:
|
||||
return result
|
||||
|
||||
# Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285
|
||||
stride = result.shape[-1] // self.process_group.size()
|
||||
start_idx = self.process_group.rank() * stride
|
||||
end_idx = (self.process_group.rank() + 1) * stride
|
||||
|
||||
self.forward_layer_type(
|
||||
result, input, adapter_data, self.layer_name, start_idx, end_idx
|
||||
)
|
||||
|
||||
def forward_layer_type(
|
||||
process_group: ProcessGroup,
|
||||
layer_id: int,
|
||||
result: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
adapter_data: "AdapterBatchData",
|
||||
layer_type: str,
|
||||
start_idx: int,
|
||||
end_idx: int,
|
||||
use_all_gather: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if adapter_data is None:
|
||||
return result
|
||||
data = adapter_data.data.get(layer_type)
|
||||
data: Optional["BatchLoraWeights"] = data.get("lora") if data is not None else None
|
||||
|
||||
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
|
||||
# Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise.
|
||||
# We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks.
|
||||
if has_sgmv() and data is not None and data.can_vectorize(process_group):
|
||||
# In tensor-parallel configurations, each GPU processes a specific segment of the output.
|
||||
# The 'result' tensor represents the full output, which can vary in size based on
|
||||
# the layer type (e.g., attention vs. feed-forward layers). We define the current
|
||||
# segment using start_idx and end_idx. If the segment size doesn't match this GPU's
|
||||
# slice of 'result', we create a zero tensor of the correct size for LoRA computation.
|
||||
# This approach ensures accurate LoRA application across various layer sizes and
|
||||
# configurations, adapting to different model architectures and parallelization strategies.
|
||||
#
|
||||
# TODO(travis): this is not very efficient as we do an all-reduce for every adapter,
|
||||
# instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
|
||||
# rank, compute `a_out` on each, and then slice them into the buffer as shown here:
|
||||
# https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
|
||||
torch.distributed.all_reduce(a_out, group=self.process_group)
|
||||
return a_out
|
||||
# Example scenarios where this is necessary:
|
||||
# 1. The adapter's size doesn't evenly divide across GPUs.
|
||||
# 2. We're processing the last segment which might be smaller.
|
||||
# 3. Different projection layers (q, k, v) have different sizes.
|
||||
if end_idx - start_idx != result.shape[1]:
|
||||
proj = torch.zeros_like(result[:, start_idx:end_idx])
|
||||
else:
|
||||
proj = result
|
||||
|
||||
for r, rank_segments in data.rank_data.items():
|
||||
lora_a_ptr = rank_segments.lora_a_ptr
|
||||
lora_b_ptr = rank_segments.lora_b_ptr
|
||||
|
||||
if lora_a_ptr is None or lora_b_ptr is None:
|
||||
raise ValueError("LoRA data is missing")
|
||||
|
||||
if data.use_sgmv:
|
||||
# Use SGMV for prefill
|
||||
v = lora_a_sgmv_cutlass(
|
||||
input,
|
||||
rank_segments.tmp_shrink,
|
||||
lora_a_ptr,
|
||||
rank_segments.segment_starts,
|
||||
rank_segments.segment_ends,
|
||||
layer_id,
|
||||
r,
|
||||
)
|
||||
|
||||
if process_group.size() > 1:
|
||||
v = gather_lora_weights(process_group, v, use_all_gather)
|
||||
|
||||
lora_b_sgmv_cutlass(
|
||||
proj,
|
||||
v,
|
||||
rank_segments.tmp_expand,
|
||||
lora_b_ptr,
|
||||
rank_segments.segment_starts,
|
||||
rank_segments.segment_ends,
|
||||
layer_id,
|
||||
)
|
||||
else:
|
||||
# Use BGMV for decode
|
||||
v = torch.zeros(
|
||||
(input.size(0), r), dtype=input.dtype, device=input.device
|
||||
)
|
||||
# TODO: error with [-1, 0], but not [0, -1]
|
||||
add_lora_a_bgmv(
|
||||
v,
|
||||
input,
|
||||
lora_a_ptr,
|
||||
rank_segments.indices,
|
||||
layer_id,
|
||||
)
|
||||
|
||||
if process_group.size() > 1:
|
||||
v = gather_lora_weights(process_group, v, use_all_gather)
|
||||
|
||||
add_lora_b_bgmv(
|
||||
proj,
|
||||
v,
|
||||
lora_b_ptr,
|
||||
rank_segments.indices,
|
||||
layer_id,
|
||||
)
|
||||
|
||||
if end_idx - start_idx != result.shape[1]:
|
||||
result[:, start_idx:end_idx] += proj
|
||||
else:
|
||||
for adapter_index in adapter_data.meta.adapter_set:
|
||||
if data is not None and data.has_adapter(adapter_index):
|
||||
adapter_mask = (
|
||||
(adapter_data.meta.adapter_indices == adapter_index)
|
||||
.to(input.dtype)
|
||||
.view(-1, 1)
|
||||
)
|
||||
layer_result = forward_lora(
|
||||
process_group=process_group,
|
||||
layer_id=layer_id,
|
||||
input=input,
|
||||
data=data,
|
||||
adapter_index=adapter_index,
|
||||
adapter_mask=adapter_mask,
|
||||
use_all_gather=use_all_gather,
|
||||
)
|
||||
result[:, start_idx:end_idx] += layer_result
|
||||
return result
|
||||
|
||||
|
||||
def forward_lora(
|
||||
process_group: ProcessGroup,
|
||||
layer_id,
|
||||
input: torch.Tensor,
|
||||
data: "BatchLoraWeights",
|
||||
adapter_index: int,
|
||||
adapter_mask: torch.Tensor,
|
||||
use_all_gather: bool = False,
|
||||
) -> torch.Tensor:
|
||||
lora_a = data.lora_a[adapter_index][layer_id, :, :]
|
||||
lora_b = data.lora_b[adapter_index][layer_id, :, :]
|
||||
|
||||
lora_a = orient_for_rank(lora_a, lora_b.size(0))
|
||||
|
||||
a_out = input @ lora_a
|
||||
if process_group.size() > 1:
|
||||
a_out = gather_lora_weights(process_group, a_out, use_all_gather)
|
||||
|
||||
result = (a_out @ lora_b) * adapter_mask
|
||||
return result
|
||||
|
|
|
@ -4,6 +4,7 @@ from typing import Iterable, List
|
|||
from text_generation_server.layers.linear import get_linear, FastLinear
|
||||
from text_generation_server.layers.exl2 import Exl2Weight
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.layers.lora import forward_layer_type
|
||||
|
||||
if SYSTEM == "ipex":
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
@ -126,6 +127,20 @@ class TensorParallelHead(SuperLayer):
|
|||
|
||||
|
||||
class TensorParallelColumnLinear(SuperLayer):
|
||||
def __init__(
|
||||
self,
|
||||
linear,
|
||||
process_group: torch.distributed.ProcessGroup = None,
|
||||
layer_names: List[str] = None,
|
||||
sizes: List[int] = None,
|
||||
layer_id: str = None,
|
||||
):
|
||||
super().__init__(linear)
|
||||
self.process_group = process_group
|
||||
self.layer_names = layer_names
|
||||
self.sizes = sizes
|
||||
self.layer_id = layer_id
|
||||
|
||||
@classmethod
|
||||
def load_gate_up(cls, config, prefix: str, weights, bias: bool):
|
||||
"""Specific method when the QKV was joined after the fact"""
|
||||
|
@ -171,7 +186,18 @@ class TensorParallelColumnLinear(SuperLayer):
|
|||
return cls(linear)
|
||||
|
||||
@classmethod
|
||||
def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
|
||||
def load_multi(
|
||||
cls,
|
||||
config,
|
||||
prefixes: List[str],
|
||||
weights,
|
||||
bias: bool,
|
||||
dim: int,
|
||||
sizes: List[int] = None,
|
||||
layer_id: str = None,
|
||||
):
|
||||
# infer layer_names from prefixes
|
||||
layer_names = [prefix.split(".")[-1] for prefix in prefixes]
|
||||
if config.quantize == "exl2":
|
||||
linears = []
|
||||
for prefix in prefixes:
|
||||
|
@ -187,17 +213,75 @@ class TensorParallelColumnLinear(SuperLayer):
|
|||
else:
|
||||
bias = None
|
||||
linear = get_linear(weight, bias, config.quantize)
|
||||
return cls(linear)
|
||||
|
||||
return cls(
|
||||
linear,
|
||||
process_group=weights.process_group,
|
||||
layer_names=layer_names,
|
||||
sizes=sizes,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
def forward(self, input: torch.Tensor, adapter_data=None) -> torch.Tensor:
|
||||
result = super().forward(input)
|
||||
|
||||
# noop if no lora data is provided
|
||||
if adapter_data is None:
|
||||
return result
|
||||
|
||||
# handle models like Bloom that have inputs of shape
|
||||
# (batch_size, sequence_length, hidden_size)
|
||||
# we need to reshape them to (batch_size * sequence_length, hidden_size)
|
||||
# for the LoRA computation, then reshape back
|
||||
prev_shape = result.shape
|
||||
is_3d = len(input.shape) >= 3
|
||||
if is_3d:
|
||||
input = input.reshape(-1, input.shape[-1])
|
||||
result = result.reshape(-1, result.shape[-1])
|
||||
|
||||
offset = 0
|
||||
for i, layer_name in enumerate(self.layer_names):
|
||||
start_idx = offset // self.process_group.size()
|
||||
# The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
|
||||
# projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
|
||||
# ensures correct slicing of the result tensor, accommodating variations like grouped-query
|
||||
# attention where k_proj and v_proj differ from q_proj. This allows precise application of
|
||||
# LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
|
||||
# different projection sizes across layers and model architectures.
|
||||
if self.sizes is not None:
|
||||
offset += self.sizes[i]
|
||||
end_idx = offset // self.process_group.size()
|
||||
else:
|
||||
end_idx = result.shape[1]
|
||||
|
||||
result = forward_layer_type(
|
||||
process_group=self.process_group,
|
||||
layer_id=self.layer_id,
|
||||
result=result,
|
||||
input=input,
|
||||
adapter_data=adapter_data,
|
||||
layer_type=layer_name,
|
||||
start_idx=start_idx,
|
||||
end_idx=end_idx,
|
||||
use_all_gather=True,
|
||||
)
|
||||
|
||||
if is_3d:
|
||||
result = result.reshape(prev_shape)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class TensorParallelRowLinear(SuperLayer):
|
||||
def __init__(self, linear, process_group):
|
||||
def __init__(self, linear, process_group, layer_name):
|
||||
super().__init__(linear)
|
||||
self.process_group = process_group
|
||||
self.layer_name = layer_name
|
||||
|
||||
@classmethod
|
||||
def load(cls, config, prefix: str, weights, bias: bool):
|
||||
weight = weights.get_weights_row(prefix)
|
||||
layer_name = prefix.split(".")[-1]
|
||||
|
||||
if bias and weights.process_group.rank() == 0:
|
||||
# Rank is only on the first rank process
|
||||
|
@ -207,16 +291,41 @@ class TensorParallelRowLinear(SuperLayer):
|
|||
return cls(
|
||||
get_linear(weight, bias, config.quantize),
|
||||
process_group=weights.process_group,
|
||||
layer_name=layer_name,
|
||||
)
|
||||
|
||||
def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
|
||||
def forward(
|
||||
self, input: torch.Tensor, adapter_data=None, reduce: bool = True
|
||||
) -> torch.Tensor:
|
||||
out = super().forward(input)
|
||||
if self.process_group.size() > 1 and reduce:
|
||||
if SYSTEM == "ipex":
|
||||
ipex.distributed.all_reduce(out, group=self.process_group)
|
||||
else:
|
||||
torch.distributed.all_reduce(out, group=self.process_group)
|
||||
return out
|
||||
|
||||
# noop if no lora data is provided
|
||||
if adapter_data is None:
|
||||
return out
|
||||
|
||||
# Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285
|
||||
stride = out.shape[-1] // self.process_group.size()
|
||||
start_idx = self.process_group.rank() * stride
|
||||
end_idx = (self.process_group.rank() + 1) * stride
|
||||
|
||||
res = forward_layer_type(
|
||||
process_group=self.process_group,
|
||||
layer_id=self.layer_name,
|
||||
result=out,
|
||||
input=input,
|
||||
adapter_data=adapter_data,
|
||||
layer_type=self.layer_name,
|
||||
start_idx=start_idx,
|
||||
end_idx=end_idx,
|
||||
use_all_gather=False,
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class TensorParallelEmbedding(torch.nn.Module):
|
||||
|
|
|
@ -18,8 +18,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
|
@ -33,14 +31,11 @@ from text_generation_server.layers.attention import (
|
|||
attention,
|
||||
reshape_and_cache,
|
||||
)
|
||||
from text_generation_server.models.globals import FLASH_DECODING
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
SpeculativeHead,
|
||||
TensorParallelMultiAdapterLinear,
|
||||
TensorParallelAdapterRowLinear,
|
||||
)
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
|
@ -58,8 +53,6 @@ def load_attention(config, prefix: str, weights, layer_id):
|
|||
# Only defined in granite.
|
||||
bias = getattr(config, "attention_bias", False)
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
sizes = None
|
||||
prefixes = None
|
||||
|
||||
if config.model_type == "phi3":
|
||||
prefix = f"{prefix}.qkv_proj"
|
||||
|
@ -82,27 +75,21 @@ def load_attention(config, prefix: str, weights, layer_id):
|
|||
num_key_value_heads=config.num_key_value_heads,
|
||||
)
|
||||
else:
|
||||
prefixes = ["q_proj", "k_proj", "v_proj"]
|
||||
sizes = [
|
||||
head_size * config.num_attention_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
]
|
||||
|
||||
base_layer = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=bias,
|
||||
sizes=[
|
||||
head_size * config.num_attention_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
],
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
return TensorParallelMultiAdapterLinear.load(
|
||||
base_layer=base_layer,
|
||||
layer_id=layer_id,
|
||||
layer_names=prefixes,
|
||||
sizes=sizes,
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
return base_layer
|
||||
|
||||
|
||||
class FlashLlamaAttention(torch.nn.Module):
|
||||
|
@ -150,20 +137,13 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
self.query_key_value = load_attention(config, prefix, weights, index)
|
||||
self.index = index
|
||||
|
||||
o_proj = TensorParallelRowLinear.load(
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.o_proj = TensorParallelAdapterRowLinear.load(
|
||||
o_proj,
|
||||
index,
|
||||
"o_proj",
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
|
@ -247,54 +227,37 @@ class LlamaMLP(nn.Module):
|
|||
),
|
||||
)
|
||||
)
|
||||
prefixes = None
|
||||
sizes = None
|
||||
|
||||
# Fuse gate and up proj
|
||||
bias = getattr(config, "mlp_bias", False)
|
||||
if config.model_type == "phi3":
|
||||
gate_up_proj = TensorParallelColumnLinear.load_gate_up(
|
||||
self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
|
||||
config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
weights=weights,
|
||||
bias=bias,
|
||||
)
|
||||
else:
|
||||
prefixes = [f"gate_proj", f"up_proj"]
|
||||
sizes = [
|
||||
config.intermediate_size,
|
||||
config.intermediate_size,
|
||||
]
|
||||
gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||
weights=weights,
|
||||
dim=0,
|
||||
bias=bias,
|
||||
sizes=[
|
||||
config.intermediate_size,
|
||||
config.intermediate_size,
|
||||
],
|
||||
layer_id=index,
|
||||
)
|
||||
|
||||
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
|
||||
gate_up_proj,
|
||||
index,
|
||||
layer_names=prefixes,
|
||||
sizes=sizes,
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
down_proj = TensorParallelRowLinear.load(
|
||||
self.down_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
weights=weights,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.down_proj = TensorParallelAdapterRowLinear.load(
|
||||
down_proj,
|
||||
index,
|
||||
"down_proj",
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size // weights.process_group.size()
|
||||
)
|
||||
|
|
|
@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
|
|||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.layers.attention import (
|
||||
Seqlen,
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
|
@ -38,9 +37,6 @@ from text_generation_server.layers import (
|
|||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
TensorParallelMultiAdapterLinear,
|
||||
TensorParallelAdapterRowLinear,
|
||||
)
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
|
@ -138,39 +134,26 @@ class MistralAttention(torch.nn.Module):
|
|||
config.num_key_value_heads // weights.process_group.size()
|
||||
)
|
||||
|
||||
query_key_value = TensorParallelColumnLinear.load_multi(
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
self.query_key_value = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
self.query_key_value = TensorParallelMultiAdapterLinear.load(
|
||||
query_key_value,
|
||||
layer_id,
|
||||
["q_proj", "k_proj", "v_proj"],
|
||||
sizes=[
|
||||
head_size * config.num_attention_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
],
|
||||
process_group=weights.process_group,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
o_proj = TensorParallelRowLinear.load(
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.o_proj = TensorParallelAdapterRowLinear.load(
|
||||
o_proj,
|
||||
layer_id,
|
||||
"o_proj",
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
|
@ -264,37 +247,24 @@ class MistralMLP(nn.Module):
|
|||
)
|
||||
)
|
||||
# Fuse gate and up proj
|
||||
gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||
weights=weights,
|
||||
dim=0,
|
||||
bias=False,
|
||||
)
|
||||
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
|
||||
gate_up_proj,
|
||||
layer_id,
|
||||
["gate_proj", "up_proj"],
|
||||
sizes=[
|
||||
config.intermediate_size,
|
||||
config.intermediate_size,
|
||||
],
|
||||
process_group=weights.process_group,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
down_proj = TensorParallelRowLinear.load(
|
||||
self.down_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.down_proj = TensorParallelAdapterRowLinear.load(
|
||||
down_proj,
|
||||
layer_id,
|
||||
"down_proj",
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size // weights.process_group.size()
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue