feat: refactor lora linear and remove adapter layers
Parent: da82c63a4f
Commit: 82fc879e17
@@ -208,7 +208,7 @@ class LoraWeights(AdapterWeights):
         for layer_id in range(nlayers):
             key = (layer_id, layer_type)
             weight_name, layer = model.target_to_layer[key]
-            base_weight = layer.base_layer.linear.weight
+            base_weight = layer.linear.weight
             base_device = base_weight.device
 
             if weight_name not in module_map:
@@ -13,8 +13,13 @@ from text_generation_server.layers.speculative import SpeculativeHead
 from text_generation_server.layers.layernorm import load_layer_norm
 from text_generation_server.layers.conv import load_conv2d
 
-from text_generation_server.layers.lora import (
-    LoraLinear,
-    TensorParallelMultiAdapterLinear,
-    TensorParallelAdapterRowLinear,
-)
+__all__ = [
+    "TensorParallelColumnLinear",
+    "TensorParallelRowLinear",
+    "TensorParallelEmbedding",
+    "get_linear",
+    "FastLinear",
+    "SpeculativeHead",
+    "load_layer_norm",
+    "load_conv2d",
+]
@@ -1,14 +1,11 @@
-import math
-import os
-from typing import TYPE_CHECKING, Optional, Tuple, List
+from typing import Optional
 
 import torch
 import torch.distributed
-from accelerate import init_empty_weights
-from torch import nn
-from torch.nn import functional as F
 from torch.distributed import ProcessGroup
 
+from text_generation_server.adapters import AdapterBatchData
+from text_generation_server.adapters.lora import BatchLoraWeights
 from text_generation_server.utils.sgmv import (
     add_lora_a_bgmv,
     add_lora_b_bgmv,
@@ -18,37 +15,47 @@ from text_generation_server.utils.sgmv import (
     orient_for_rank,
 )
 
-if TYPE_CHECKING:
-    from text_generation_server.adapters import AdapterBatchData
-    from text_generation_server.adapters.lora import BatchLoraWeights
+
+def gather_lora_weights(
+    process_group: ProcessGroup,
+    weights: torch.Tensor,
+    use_all_gather: bool = False,
+) -> BatchLoraWeights:
+    if use_all_gather:
+        # Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise.
+        # We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks.
+        #
+        # TODO: this is not very efficient as we do an all-gather for every adapter,
+        # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
+        # rank, compute `a_out` on each, and then slice them into the buffer as shown here:
+        # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
+        gathered_tensors = [
+            torch.empty_like(weights) for _ in range(process_group.size())
+        ]
+        torch.distributed.all_gather(gathered_tensors, weights, group=process_group)
+        return torch.cat(gathered_tensors, dim=1)
+    else:
+        torch.distributed.all_reduce(weights, group=process_group)
+        return weights
 
 
-class LoraLinear(nn.Module):
-    def __init__(
-        self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup
-    ):
-        super().__init__()
-        self.base_layer = base_layer
-        self.layer_id = layer_id
-        self.process_group = process_group
-
-    def forward_layer_type(
-        self,
-        result: torch.Tensor,
-        input: torch.Tensor,
-        adapter_data: "AdapterBatchData",
-        layer_type: str,
-        start_idx: int,
-        end_idx: int,
-    ) -> torch.Tensor:
-        if adapter_data is None:
-            return result
-        data = adapter_data.data.get(layer_type)
-        data: Optional["BatchLoraWeights"] = (
-            data.get("lora") if data is not None else None
-        )
-
-        if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
+def forward_layer_type(
+    process_group: ProcessGroup,
+    layer_id: int,
+    result: torch.Tensor,
+    input: torch.Tensor,
+    adapter_data: "AdapterBatchData",
+    layer_type: str,
+    start_idx: int,
+    end_idx: int,
+    use_all_gather: bool = False,
+) -> torch.Tensor:
+    if adapter_data is None:
+        return result
+    data = adapter_data.data.get(layer_type)
+    data: Optional["BatchLoraWeights"] = data.get("lora") if data is not None else None
+
+    if has_sgmv() and data is not None and data.can_vectorize(process_group):
         # In tensor-parallel configurations, each GPU processes a specific segment of the output.
         # The 'result' tensor represents the full output, which can vary in size based on
         # the layer type (e.g., attention vs. feed-forward layers). We define the current
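Aside: a minimal, single-process sketch (hypothetical shapes, simulated world size of 2, not part of the diff) of why the column-sharded path in gather_lora_weights concatenates partial results with all_gather on dim=1, while the row-sharded path sums them with all_reduce:

# Illustrative only: simulate two tensor-parallel "ranks" by slicing tensors locally.
import torch

torch.manual_seed(0)
world_size = 2
x = torch.randn(4, 8)        # (tokens, hidden)
lora_a = torch.randn(8, 6)   # full LoRA A for one adapter (hypothetical rank 6)

# Column sharding: each rank holds a block of A's columns, so each partial
# X @ A_shard is a disjoint block of columns of X @ A -> concatenate on dim=1.
col_shards = lora_a.chunk(world_size, dim=1)
col_partials = [x @ shard for shard in col_shards]
assert torch.allclose(torch.cat(col_partials, dim=1), x @ lora_a, atol=1e-5)

# Row sharding: each rank holds a block of A's rows (and of x's features), so
# each partial product is a partial sum -> an all-reduce (sum) recovers X @ A.
row_shards_a = lora_a.chunk(world_size, dim=0)
row_shards_x = x.chunk(world_size, dim=1)
row_partials = [xs @ a for xs, a in zip(row_shards_x, row_shards_a)]
assert torch.allclose(sum(row_partials), x @ lora_a, atol=1e-5)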
@@ -81,12 +88,12 @@ class LoraLinear(nn.Module):
                 lora_a_ptr,
                 rank_segments.segment_starts,
                 rank_segments.segment_ends,
-                self.layer_id,
+                layer_id,
                 r,
             )
 
-            if self.process_group.size() > 1:
-                v = self.collect_lora_a(v)
+            if process_group.size() > 1:
+                v = gather_lora_weights(process_group, v, use_all_gather)
 
             lora_b_sgmv_cutlass(
                 proj,
@@ -95,7 +102,7 @@ class LoraLinear(nn.Module):
                 lora_b_ptr,
                 rank_segments.segment_starts,
                 rank_segments.segment_ends,
-                self.layer_id,
+                layer_id,
             )
         else:
             # Use BGMV for decode
@@ -108,18 +115,18 @@ class LoraLinear(nn.Module):
                 input,
                 lora_a_ptr,
                 rank_segments.indices,
-                self.layer_id,
+                layer_id,
             )
 
-            if self.process_group.size() > 1:
-                v = self.collect_lora_a(v)
+            if process_group.size() > 1:
+                v = gather_lora_weights(process_group, v, use_all_gather)
 
             add_lora_b_bgmv(
                 proj,
                 v,
                 lora_b_ptr,
                 rank_segments.indices,
-                self.layer_id,
+                layer_id,
             )
 
         if end_idx - start_idx != result.shape[1]:
@@ -132,155 +139,36 @@ class LoraLinear(nn.Module):
                 .to(input.dtype)
                 .view(-1, 1)
             )
-            layer_result = self.forward_lora(
-                input, data, adapter_index, adapter_mask
+            layer_result = forward_lora(
+                process_group=process_group,
+                layer_id=layer_id,
+                input=input,
+                data=data,
+                adapter_index=adapter_index,
+                adapter_mask=adapter_mask,
+                use_all_gather=use_all_gather,
             )
             result[:, start_idx:end_idx] += layer_result
 
     return result
 
-    def forward_lora(
-        self,
-        input: torch.Tensor,
-        data: "BatchLoraWeights",
-        adapter_index: int,
-        adapter_mask: torch.Tensor,
-    ) -> torch.Tensor:
-        lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
-        lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
-
-        lora_a = orient_for_rank(lora_a, lora_b.size(0))
-
-        a_out = input @ lora_a
-        if self.process_group.size() > 1:
-            a_out = self.collect_lora_a(a_out)
-
-        result = (a_out @ lora_b) * adapter_mask
-        return result
-
-    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
-        raise NotImplementedError("Implemented in subclasses")
+
+def forward_lora(
+    process_group: ProcessGroup,
+    layer_id,
+    input: torch.Tensor,
+    data: "BatchLoraWeights",
+    adapter_index: int,
+    adapter_mask: torch.Tensor,
+    use_all_gather: bool = False,
+) -> torch.Tensor:
+    lora_a = data.lora_a[adapter_index][layer_id, :, :]
+    lora_b = data.lora_b[adapter_index][layer_id, :, :]
+
+    lora_a = orient_for_rank(lora_a, lora_b.size(0))
+
+    a_out = input @ lora_a
+    if process_group.size() > 1:
+        a_out = gather_lora_weights(process_group, a_out, use_all_gather)
+
+    result = (a_out @ lora_b) * adapter_mask
+    return result
-
-
-class TensorParallelMultiAdapterLinear(LoraLinear):
-    def __init__(
-        self,
-        base_layer: nn.Module,
-        layer_id: int,
-        layer_names: List[str],
-        sizes: List[int],
-        process_group: ProcessGroup,
-    ):
-        super().__init__(base_layer, layer_id, process_group)
-        self.layer_names = layer_names
-        self.sizes = sizes
-
-    @classmethod
-    def load(
-        cls,
-        base_layer: nn.Module,
-        layer_id: int,
-        layer_names: List[str],
-        sizes: List[int],
-        process_group: ProcessGroup,
-    ):
-        return TensorParallelMultiAdapterLinear(
-            base_layer, layer_id, layer_names, sizes, process_group
-        )
-
-    def forward(
-        self, input: torch.Tensor, adapter_data: "AdapterBatchData"
-    ) -> torch.Tensor:
-        result = self.base_layer(input)
-
-        # noop if no layer names are provided (e.g. for models without adapters)
-        if self.layer_names is None:
-            return result
-
-        # handle models like Bloom that have inputs of shape
-        # (batch_size, sequence_length, hidden_size)
-        # we need to reshape them to (batch_size * sequence_length, hidden_size)
-        # for the LoRA computation, then reshape back
-        prev_shape = result.shape
-        is_3d = len(input.shape) >= 3
-        if is_3d:
-            input = input.reshape(-1, input.shape[-1])
-            result = result.reshape(-1, result.shape[-1])
-
-        offset = 0
-        for i, layer_name in enumerate(self.layer_names):
-            start_idx = offset // self.process_group.size()
-            # The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
-            # projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
-            # ensures correct slicing of the result tensor, accommodating variations like grouped-query
-            # attention where k_proj and v_proj differ from q_proj. This allows precise application of
-            # LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
-            # different projection sizes across layers and model architectures.
-            if self.sizes is not None:
-                offset += self.sizes[i]
-                end_idx = offset // self.process_group.size()
-            else:
-                end_idx = result.shape[1]
-
-            result = self.forward_layer_type(
-                result, input, adapter_data, layer_name, start_idx, end_idx
-            )
-
-        if is_3d:
-            result = result.reshape(prev_shape)
-
-        return result
-
-    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
-        # Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise.
-        # We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks.
-        #
-        # TODO(travis): this is not very efficient as we do an all-gather for every adapter,
-        # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
-        # rank, compute `a_out` on each, and then slice them into the buffer as shown here:
-        # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
-        gathered_tensors = [
-            torch.empty_like(a_out) for _ in range(self.process_group.size())
-        ]
-        torch.distributed.all_gather(gathered_tensors, a_out)
-        return torch.cat(gathered_tensors, dim=1)
-
-
-class TensorParallelAdapterRowLinear(LoraLinear):
-    def __init__(self, base_layer, layer_id, layer_name, process_group):
-        super().__init__(base_layer, layer_id, process_group)
-        self.layer_name = layer_name
-
-    @classmethod
-    def load(cls, base_layer, layer_id, layer_name, process_group):
-        return cls(base_layer, layer_id, layer_name, process_group)
-
-    def forward(
-        self, input: torch.Tensor, adapter_data: "AdapterBatchData"
-    ) -> torch.Tensor:
-        result = self.base_layer(input)
-
-        if self.layer_name is None:
-            return result
-
-        # Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285
-        stride = result.shape[-1] // self.process_group.size()
-        start_idx = self.process_group.rank() * stride
-        end_idx = (self.process_group.rank() + 1) * stride
-
-        self.forward_layer_type(
-            result, input, adapter_data, self.layer_name, start_idx, end_idx
-        )
-
-        return result
-
-    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
-        # Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise.
-        # We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks.
-        #
-        # TODO(travis): this is not very efficient as we do an all-reduce for every adapter,
-        # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
-        # rank, compute `a_out` on each, and then slice them into the buffer as shown here:
-        # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
-        torch.distributed.all_reduce(a_out, group=self.process_group)
-        return a_out
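Aside: the non-SGMV fallback above boils down to a masked low-rank update per adapter. A small self-contained sketch of the math forward_lora applies, using hypothetical shapes (one adapter, LoRA rank 4) rather than the server's real batch metadata:

# Illustrative only: result[:, start:end] += ((x @ A) @ B) * adapter_mask
import torch

torch.manual_seed(0)
tokens, hidden, rank, out_dim = 5, 16, 4, 16
x = torch.randn(tokens, hidden)
lora_a = torch.randn(hidden, rank)   # LoRA A for one adapter / one layer
lora_b = torch.randn(rank, out_dim)  # LoRA B for one adapter / one layer

# 1.0 for tokens routed to this adapter, 0.0 otherwise (shape: tokens x 1).
adapter_mask = torch.tensor([1.0, 1.0, 0.0, 1.0, 0.0]).view(-1, 1)

layer_result = (x @ lora_a @ lora_b) * adapter_mask
result = torch.zeros(tokens, out_dim)
result[:, 0:out_dim] += layer_result  # slice bounds come from start_idx/end_idx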
@@ -4,6 +4,7 @@ from typing import Iterable, List
 from text_generation_server.layers.linear import get_linear, FastLinear
 from text_generation_server.layers.exl2 import Exl2Weight
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.layers.lora import forward_layer_type
 
 if SYSTEM == "ipex":
     import intel_extension_for_pytorch as ipex
@@ -126,6 +127,20 @@ class TensorParallelHead(SuperLayer):
 
 
 class TensorParallelColumnLinear(SuperLayer):
+    def __init__(
+        self,
+        linear,
+        process_group: torch.distributed.ProcessGroup = None,
+        layer_names: List[str] = None,
+        sizes: List[int] = None,
+        layer_id: str = None,
+    ):
+        super().__init__(linear)
+        self.process_group = process_group
+        self.layer_names = layer_names
+        self.sizes = sizes
+        self.layer_id = layer_id
+
     @classmethod
     def load_gate_up(cls, config, prefix: str, weights, bias: bool):
         """Specific method when the QKV was joined after the fact"""
@@ -171,7 +186,18 @@ class TensorParallelColumnLinear(SuperLayer):
         return cls(linear)
 
     @classmethod
-    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
+    def load_multi(
+        cls,
+        config,
+        prefixes: List[str],
+        weights,
+        bias: bool,
+        dim: int,
+        sizes: List[int] = None,
+        layer_id: str = None,
+    ):
+        # infer layer_names from prefixes
+        layer_names = [prefix.split(".")[-1] for prefix in prefixes]
         if config.quantize == "exl2":
             linears = []
             for prefix in prefixes:
@@ -187,17 +213,75 @@ class TensorParallelColumnLinear(SuperLayer):
             else:
                 bias = None
             linear = get_linear(weight, bias, config.quantize)
-        return cls(linear)
+        return cls(
+            linear,
+            process_group=weights.process_group,
+            layer_names=layer_names,
+            sizes=sizes,
+            layer_id=layer_id,
+        )
+
+    def forward(self, input: torch.Tensor, adapter_data=None) -> torch.Tensor:
+        result = super().forward(input)
+
+        # noop if no lora data is provided
+        if adapter_data is None:
+            return result
+
+        # handle models like Bloom that have inputs of shape
+        # (batch_size, sequence_length, hidden_size)
+        # we need to reshape them to (batch_size * sequence_length, hidden_size)
+        # for the LoRA computation, then reshape back
+        prev_shape = result.shape
+        is_3d = len(input.shape) >= 3
+        if is_3d:
+            input = input.reshape(-1, input.shape[-1])
+            result = result.reshape(-1, result.shape[-1])
+
+        offset = 0
+        for i, layer_name in enumerate(self.layer_names):
+            start_idx = offset // self.process_group.size()
+            # The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
+            # projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
+            # ensures correct slicing of the result tensor, accommodating variations like grouped-query
+            # attention where k_proj and v_proj differ from q_proj. This allows precise application of
+            # LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
+            # different projection sizes across layers and model architectures.
+            if self.sizes is not None:
+                offset += self.sizes[i]
+                end_idx = offset // self.process_group.size()
+            else:
+                end_idx = result.shape[1]
+
+            result = forward_layer_type(
+                process_group=self.process_group,
+                layer_id=self.layer_id,
+                result=result,
+                input=input,
+                adapter_data=adapter_data,
+                layer_type=layer_name,
+                start_idx=start_idx,
+                end_idx=end_idx,
+                use_all_gather=True,
+            )
+
+        if is_3d:
+            result = result.reshape(prev_shape)
+
+        return result
 
 
 class TensorParallelRowLinear(SuperLayer):
-    def __init__(self, linear, process_group):
+    def __init__(self, linear, process_group, layer_name):
         super().__init__(linear)
         self.process_group = process_group
+        self.layer_name = layer_name
 
     @classmethod
     def load(cls, config, prefix: str, weights, bias: bool):
         weight = weights.get_weights_row(prefix)
+        layer_name = prefix.split(".")[-1]
 
         if bias and weights.process_group.rank() == 0:
             # Rank is only on the first rank process
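Aside: a small sketch of how the sizes list drives the per-projection slices in the new TensorParallelColumnLinear.forward. The head counts and tensor-parallel degree below are hypothetical, chosen to show the grouped-query-attention case where k_proj/v_proj are smaller than q_proj:

# Illustrative only: each fused projection gets a slice of the sharded output.
head_size = 128
num_attention_heads = 32
num_key_value_heads = 8        # grouped-query attention: fewer k/v heads
world_size = 4                 # hypothetical tensor-parallel degree

layer_names = ["q_proj", "k_proj", "v_proj"]
sizes = [
    head_size * num_attention_heads,   # 4096
    head_size * num_key_value_heads,   # 1024
    head_size * num_key_value_heads,   # 1024
]

offset = 0
for name, size in zip(layer_names, sizes):
    start_idx = offset // world_size
    offset += size
    end_idx = offset // world_size
    print(f"{name}: result[:, {start_idx}:{end_idx}]")
# q_proj: result[:, 0:1024], k_proj: result[:, 1024:1280], v_proj: result[:, 1280:1536]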
@@ -207,17 +291,42 @@ class TensorParallelRowLinear(SuperLayer):
         return cls(
             get_linear(weight, bias, config.quantize),
             process_group=weights.process_group,
+            layer_name=layer_name,
         )
 
-    def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
+    def forward(
+        self, input: torch.Tensor, adapter_data=None, reduce: bool = True
+    ) -> torch.Tensor:
         out = super().forward(input)
         if self.process_group.size() > 1 and reduce:
             if SYSTEM == "ipex":
                 ipex.distributed.all_reduce(out, group=self.process_group)
             else:
                 torch.distributed.all_reduce(out, group=self.process_group)
 
-        return out
+        # noop if no lora data is provided
+        if adapter_data is None:
+            return out
+
+        # Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285
+        stride = out.shape[-1] // self.process_group.size()
+        start_idx = self.process_group.rank() * stride
+        end_idx = (self.process_group.rank() + 1) * stride
+
+        res = forward_layer_type(
+            process_group=self.process_group,
+            layer_id=self.layer_name,
+            result=out,
+            input=input,
+            adapter_data=adapter_data,
+            layer_type=self.layer_name,
+            start_idx=start_idx,
+            end_idx=end_idx,
+            use_all_gather=False,
+        )
+
+        return res
 
 
 class TensorParallelEmbedding(torch.nn.Module):
     def __init__(self, prefix: str, weights, reduce=True):
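Aside: in the row-parallel path each rank applies LoRA only to its own stride-sized slice of the already-reduced output (the S-LoRA fused all-gather + all-reduce trick referenced above). A sketch with a hypothetical hidden size and world size showing that the per-rank slices tile the full output without overlap:

# Illustrative only: per-rank slice bounds used by the row-parallel LoRA path.
hidden_size = 4096
world_size = 4                 # hypothetical tensor-parallel degree

stride = hidden_size // world_size
slices = [(rank * stride, (rank + 1) * stride) for rank in range(world_size)]

print(slices)                  # [(0, 1024), (1024, 2048), (2048, 3072), (3072, 4096)]
assert sum(end - start for start, end in slices) == hidden_size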
@@ -18,8 +18,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional, Tuple
-
 import torch
 import torch.distributed
 
@@ -33,14 +31,11 @@ from text_generation_server.layers.attention import (
     attention,
     reshape_and_cache,
 )
-from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     SpeculativeHead,
-    TensorParallelMultiAdapterLinear,
-    TensorParallelAdapterRowLinear,
 )
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
@@ -58,8 +53,6 @@ def load_attention(config, prefix: str, weights, layer_id):
     # Only defined in granite.
     bias = getattr(config, "attention_bias", False)
     head_size = config.hidden_size // config.num_attention_heads
-    sizes = None
-    prefixes = None
 
     if config.model_type == "phi3":
         prefix = f"{prefix}.qkv_proj"
@@ -82,27 +75,21 @@ def load_attention(config, prefix: str, weights, layer_id):
             num_key_value_heads=config.num_key_value_heads,
         )
     else:
-        prefixes = ["q_proj", "k_proj", "v_proj"]
-        sizes = [
-            head_size * config.num_attention_heads,
-            head_size * config.num_key_value_heads,
-            head_size * config.num_key_value_heads,
-        ]
         base_layer = TensorParallelColumnLinear.load_multi(
             config,
             prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
             dim=0,
             weights=weights,
             bias=bias,
-        )
-
-    return TensorParallelMultiAdapterLinear.load(
-        base_layer=base_layer,
-        layer_id=layer_id,
-        layer_names=prefixes,
-        sizes=sizes,
-        process_group=weights.process_group,
-    )
+            sizes=[
+                head_size * config.num_attention_heads,
+                head_size * config.num_key_value_heads,
+                head_size * config.num_key_value_heads,
+            ],
+            layer_id=layer_id,
+        )
+
+    return base_layer
 
 
 class FlashLlamaAttention(torch.nn.Module):
@@ -150,20 +137,13 @@ class FlashLlamaAttention(torch.nn.Module):
         self.query_key_value = load_attention(config, prefix, weights, index)
         self.index = index
 
-        o_proj = TensorParallelRowLinear.load(
+        self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
             bias=False,
         )
 
-        self.o_proj = TensorParallelAdapterRowLinear.load(
-            o_proj,
-            index,
-            "o_proj",
-            process_group=weights.process_group,
-        )
-
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@@ -247,54 +227,37 @@ class LlamaMLP(nn.Module):
                 ),
             )
         )
-        prefixes = None
-        sizes = None
 
         # Fuse gate and up proj
         bias = getattr(config, "mlp_bias", False)
         if config.model_type == "phi3":
-            gate_up_proj = TensorParallelColumnLinear.load_gate_up(
+            self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
                 config,
                 prefix=f"{prefix}.gate_up_proj",
                 weights=weights,
                 bias=bias,
             )
         else:
-            prefixes = [f"gate_proj", f"up_proj"]
-            sizes = [
-                config.intermediate_size,
-                config.intermediate_size,
-            ]
-            gate_up_proj = TensorParallelColumnLinear.load_multi(
+            self.gate_up_proj = TensorParallelColumnLinear.load_multi(
                 config,
                 prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
                 weights=weights,
                 dim=0,
                 bias=bias,
+                sizes=[
+                    config.intermediate_size,
+                    config.intermediate_size,
+                ],
+                layer_id=index,
             )
 
-        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
-            gate_up_proj,
-            index,
-            layer_names=prefixes,
-            sizes=sizes,
-            process_group=weights.process_group,
-        )
-
-        down_proj = TensorParallelRowLinear.load(
+        self.down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
             bias=bias,
        )
-
-        self.down_proj = TensorParallelAdapterRowLinear.load(
-            down_proj,
-            index,
-            "down_proj",
-            process_group=weights.process_group,
-        )
 
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
         )
@@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
 
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
-    Seqlen,
     paged_attention,
     attention,
     reshape_and_cache,
@@ -38,9 +37,6 @@ from text_generation_server.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     SpeculativeHead,
-    get_linear,
-    TensorParallelMultiAdapterLinear,
-    TensorParallelAdapterRowLinear,
 )
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
@@ -138,39 +134,26 @@ class MistralAttention(torch.nn.Module):
             config.num_key_value_heads // weights.process_group.size()
         )
 
-        query_key_value = TensorParallelColumnLinear.load_multi(
+        head_size = config.hidden_size // config.num_attention_heads
+        self.query_key_value = TensorParallelColumnLinear.load_multi(
             config,
             prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
             dim=0,
             weights=weights,
             bias=False,
-        )
-
-        head_size = config.hidden_size // config.num_attention_heads
-        self.query_key_value = TensorParallelMultiAdapterLinear.load(
-            query_key_value,
-            layer_id,
-            ["q_proj", "k_proj", "v_proj"],
             sizes=[
                 head_size * config.num_attention_heads,
                 head_size * config.num_key_value_heads,
                 head_size * config.num_key_value_heads,
             ],
-            process_group=weights.process_group,
+            layer_id=layer_id,
         )
-
-        o_proj = TensorParallelRowLinear.load(
+        self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
             bias=False,
         )
-        self.o_proj = TensorParallelAdapterRowLinear.load(
-            o_proj,
-            layer_id,
-            "o_proj",
-            process_group=weights.process_group,
-        )
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@@ -264,37 +247,24 @@ class MistralMLP(nn.Module):
             )
         )
         # Fuse gate and up proj
-        gate_up_proj = TensorParallelColumnLinear.load_multi(
+        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
             config,
             prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
             weights=weights,
             dim=0,
             bias=False,
-        )
-        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
-            gate_up_proj,
-            layer_id,
-            ["gate_proj", "up_proj"],
             sizes=[
                 config.intermediate_size,
                 config.intermediate_size,
             ],
-            process_group=weights.process_group,
+            layer_id=layer_id,
        )
-
-        down_proj = TensorParallelRowLinear.load(
+        self.down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
             bias=False,
         )
-
-        self.down_proj = TensorParallelAdapterRowLinear.load(
-            down_proj,
-            layer_id,
-            "down_proj",
-            process_group=weights.process_group,
-        )
 
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
         )