feat: refactor lora linear and remove adapter layers

drbh 2024-07-18 19:58:55 +00:00
parent da82c63a4f
commit 82fc879e17
6 changed files with 297 additions and 362 deletions

View File

@@ -208,7 +208,7 @@ class LoraWeights(AdapterWeights):
         for layer_id in range(nlayers):
             key = (layer_id, layer_type)
             weight_name, layer = model.target_to_layer[key]
-            base_weight = layer.base_layer.linear.weight
+            base_weight = layer.linear.weight
             base_device = base_weight.device

             if weight_name not in module_map:

View File

@@ -13,8 +13,13 @@ from text_generation_server.layers.speculative import SpeculativeHead
 from text_generation_server.layers.layernorm import load_layer_norm
 from text_generation_server.layers.conv import load_conv2d
-from text_generation_server.layers.lora import (
-    LoraLinear,
-    TensorParallelMultiAdapterLinear,
-    TensorParallelAdapterRowLinear,
-)
+
+__all__ = [
+    "TensorParallelColumnLinear",
+    "TensorParallelRowLinear",
+    "TensorParallelEmbedding",
+    "get_linear",
+    "FastLinear",
+    "SpeculativeHead",
+    "load_layer_norm",
+    "load_conv2d",
+]

View File

@@ -1,14 +1,11 @@
-import math
-import os
-from typing import TYPE_CHECKING, Optional, Tuple, List
+from typing import Optional

 import torch
 import torch.distributed
-from accelerate import init_empty_weights
-from torch import nn
-from torch.nn import functional as F
 from torch.distributed import ProcessGroup

-from text_generation_server.adapters import AdapterBatchData
-from text_generation_server.adapters.lora import BatchLoraWeights
 from text_generation_server.utils.sgmv import (
     add_lora_a_bgmv,
     add_lora_b_bgmv,
@@ -18,37 +15,47 @@ from text_generation_server.utils.sgmv import (
     orient_for_rank,
 )

-if TYPE_CHECKING:
-    from text_generation_server.adapters import AdapterBatchData
-    from text_generation_server.adapters.lora import BatchLoraWeights
-
-
-class LoraLinear(nn.Module):
-    def __init__(
-        self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup
-    ):
-        super().__init__()
-        self.base_layer = base_layer
-        self.layer_id = layer_id
-        self.process_group = process_group
-
-    def forward_layer_type(
-        self,
+
+def gather_lora_weights(
+    process_group: ProcessGroup,
+    weights: torch.Tensor,
+    use_all_gather: bool = False,
+) -> torch.Tensor:
+    if use_all_gather:
+        # Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise.
+        # We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks.
+        #
+        # TODO: this is not very efficient as we do an all-gather for every adapter,
+        # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
+        # rank, compute `a_out` on each, and then slice them into the buffer as shown here:
+        # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
+        gathered_tensors = [
+            torch.empty_like(weights) for _ in range(process_group.size())
+        ]
+        torch.distributed.all_gather(gathered_tensors, weights, group=process_group)
+        return torch.cat(gathered_tensors, dim=1)
+    else:
+        torch.distributed.all_reduce(weights, group=process_group)
+        return weights
+
+
+def forward_layer_type(
+    process_group: ProcessGroup,
+    layer_id: int,
     result: torch.Tensor,
     input: torch.Tensor,
     adapter_data: "AdapterBatchData",
     layer_type: str,
     start_idx: int,
     end_idx: int,
-    ) -> torch.Tensor:
+    use_all_gather: bool = False,
+) -> torch.Tensor:
     if adapter_data is None:
         return result
     data = adapter_data.data.get(layer_type)
-        data: Optional["BatchLoraWeights"] = (
-            data.get("lora") if data is not None else None
-        )
+    data: Optional["BatchLoraWeights"] = data.get("lora") if data is not None else None

-        if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
+    if has_sgmv() and data is not None and data.can_vectorize(process_group):
         # In tensor-parallel configurations, each GPU processes a specific segment of the output.
         # The 'result' tensor represents the full output, which can vary in size based on
         # the layer type (e.g., attention vs. feed-forward layers). We define the current
@@ -81,12 +88,12 @@ class LoraLinear(nn.Module):
                     lora_a_ptr,
                     rank_segments.segment_starts,
                     rank_segments.segment_ends,
-                    self.layer_id,
+                    layer_id,
                     r,
                 )

-                if self.process_group.size() > 1:
-                    v = self.collect_lora_a(v)
+                if process_group.size() > 1:
+                    v = gather_lora_weights(process_group, v, use_all_gather)

                 lora_b_sgmv_cutlass(
                     proj,
@@ -95,7 +102,7 @@ class LoraLinear(nn.Module):
                     lora_b_ptr,
                     rank_segments.segment_starts,
                     rank_segments.segment_ends,
-                    self.layer_id,
+                    layer_id,
                 )
             else:
                 # Use BGMV for decode
@@ -108,18 +115,18 @@ class LoraLinear(nn.Module):
                     input,
                     lora_a_ptr,
                     rank_segments.indices,
-                    self.layer_id,
+                    layer_id,
                 )

-                if self.process_group.size() > 1:
-                    v = self.collect_lora_a(v)
+                if process_group.size() > 1:
+                    v = gather_lora_weights(process_group, v, use_all_gather)

                 add_lora_b_bgmv(
                     proj,
                     v,
                     lora_b_ptr,
                     rank_segments.indices,
-                    self.layer_id,
+                    layer_id,
                 )

         if end_idx - start_idx != result.shape[1]:
@@ -132,155 +139,36 @@ class LoraLinear(nn.Module):
                         .to(input.dtype)
                         .view(-1, 1)
                     )
-                    layer_result = self.forward_lora(
-                        input, data, adapter_index, adapter_mask
+                    layer_result = forward_lora(
+                        process_group=process_group,
+                        layer_id=layer_id,
+                        input=input,
+                        data=data,
+                        adapter_index=adapter_index,
+                        adapter_mask=adapter_mask,
+                        use_all_gather=use_all_gather,
                     )
                     result[:, start_idx:end_idx] += layer_result

     return result

-    def forward_lora(
-        self,
+
+def forward_lora(
+    process_group: ProcessGroup,
+    layer_id,
     input: torch.Tensor,
     data: "BatchLoraWeights",
     adapter_index: int,
     adapter_mask: torch.Tensor,
-    ) -> torch.Tensor:
-        lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
-        lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
+    use_all_gather: bool = False,
+) -> torch.Tensor:
+    lora_a = data.lora_a[adapter_index][layer_id, :, :]
+    lora_b = data.lora_b[adapter_index][layer_id, :, :]

     lora_a = orient_for_rank(lora_a, lora_b.size(0))

     a_out = input @ lora_a
-        if self.process_group.size() > 1:
-            a_out = self.collect_lora_a(a_out)
+    if process_group.size() > 1:
+        a_out = gather_lora_weights(process_group, a_out, use_all_gather)

     result = (a_out @ lora_b) * adapter_mask
     return result
-
-    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
-        raise NotImplementedError("Implemented in subclasses")
-
-
-class TensorParallelMultiAdapterLinear(LoraLinear):
-    def __init__(
-        self,
-        base_layer: nn.Module,
-        layer_id: int,
-        layer_names: List[str],
-        sizes: List[int],
-        process_group: ProcessGroup,
-    ):
-        super().__init__(base_layer, layer_id, process_group)
-        self.layer_names = layer_names
-        self.sizes = sizes
-
-    @classmethod
-    def load(
-        cls,
-        base_layer: nn.Module,
-        layer_id: int,
-        layer_names: List[str],
-        sizes: List[int],
-        process_group: ProcessGroup,
-    ):
-        return TensorParallelMultiAdapterLinear(
-            base_layer, layer_id, layer_names, sizes, process_group
-        )
-
-    def forward(
-        self, input: torch.Tensor, adapter_data: "AdapterBatchData"
-    ) -> torch.Tensor:
-        result = self.base_layer(input)
-
-        # noop if no layer names are provided (e.g. for models without adapters)
-        if self.layer_names is None:
-            return result
-
-        # handle models like Bloom that have inputs of shape
-        # (batch_size, sequence_length, hidden_size)
-        # we need to reshape them to (batch_size * sequence_length, hidden_size)
-        # for the LoRA computation, then reshape back
-        prev_shape = result.shape
-        is_3d = len(input.shape) >= 3
-        if is_3d:
-            input = input.reshape(-1, input.shape[-1])
-            result = result.reshape(-1, result.shape[-1])
-
-        offset = 0
-        for i, layer_name in enumerate(self.layer_names):
-            start_idx = offset // self.process_group.size()
-            # The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
-            # projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
-            # ensures correct slicing of the result tensor, accommodating variations like grouped-query
-            # attention where k_proj and v_proj differ from q_proj. This allows precise application of
-            # LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
-            # different projection sizes across layers and model architectures.
-            if self.sizes is not None:
-                offset += self.sizes[i]
-                end_idx = offset // self.process_group.size()
-            else:
-                end_idx = result.shape[1]
-
-            result = self.forward_layer_type(
-                result, input, adapter_data, layer_name, start_idx, end_idx
-            )
-
-        if is_3d:
-            result = result.reshape(prev_shape)
-
-        return result
-
-    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
-        # Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise.
-        # We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks.
-        #
-        # TODO(travis): this is not very efficient as we do an all-gather for every adapter,
-        # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
-        # rank, compute `a_out` on each, and then slice them into the buffer as shown here:
-        # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
-        gathered_tensors = [
-            torch.empty_like(a_out) for _ in range(self.process_group.size())
-        ]
-        torch.distributed.all_gather(gathered_tensors, a_out)
-        return torch.cat(gathered_tensors, dim=1)
-
-
-class TensorParallelAdapterRowLinear(LoraLinear):
-    def __init__(self, base_layer, layer_id, layer_name, process_group):
-        super().__init__(base_layer, layer_id, process_group)
-        self.layer_name = layer_name
-
-    @classmethod
-    def load(cls, base_layer, layer_id, layer_name, process_group):
-        return cls(base_layer, layer_id, layer_name, process_group)
-
-    def forward(
-        self, input: torch.Tensor, adapter_data: "AdapterBatchData"
-    ) -> torch.Tensor:
-        result = self.base_layer(input)
-
-        if self.layer_name is None:
-            return result
-
-        # Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285
-        stride = result.shape[-1] // self.process_group.size()
-        start_idx = self.process_group.rank() * stride
-        end_idx = (self.process_group.rank() + 1) * stride
-
-        self.forward_layer_type(
-            result, input, adapter_data, self.layer_name, start_idx, end_idx
-        )
-
-        return result
-
-    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
-        # Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise.
-        # We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks.
-        #
-        # TODO(travis): this is not very efficient as we do an all-reduce for every adapter,
-        # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
-        # rank, compute `a_out` on each, and then slice them into the buffer as shown here:
-        # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
-        torch.distributed.all_reduce(a_out, group=self.process_group)
-        return a_out
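
Note: the removed TensorParallelMultiAdapterLinear / TensorParallelAdapterRowLinear wrappers are replaced by the module-level helpers above. A rough usage sketch, not taken from this commit, of how a column-sharded projection can route its LoRA contribution through them; base_linear, adapter_batch, group, layer_names and sizes are illustrative placeholders:

from text_generation_server.layers.lora import forward_layer_type


def apply_lora_column(base_linear, input, adapter_batch, group, layer_id, layer_names, sizes):
    # Run the sharded base projection first, then add each adapter's contribution on top.
    result = base_linear(input)
    offset = 0
    for i, layer_name in enumerate(layer_names):
        # Slice boundaries are expressed in the local (per-rank) output width.
        start_idx = offset // group.size()
        offset += sizes[i]
        end_idx = offset // group.size()
        # Column-sharded projections all-gather the X @ lora_A intermediate across ranks.
        result = forward_layer_type(
            process_group=group,
            layer_id=layer_id,
            result=result,
            input=input,
            adapter_data=adapter_batch,
            layer_type=layer_name,
            start_idx=start_idx,
            end_idx=end_idx,
            use_all_gather=True,
        )
    return result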

View File

@@ -4,6 +4,7 @@ from typing import Iterable, List
 from text_generation_server.layers.linear import get_linear, FastLinear
 from text_generation_server.layers.exl2 import Exl2Weight
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.layers.lora import forward_layer_type

 if SYSTEM == "ipex":
     import intel_extension_for_pytorch as ipex
@@ -126,6 +127,20 @@ class TensorParallelHead(SuperLayer):
 class TensorParallelColumnLinear(SuperLayer):
+    def __init__(
+        self,
+        linear,
+        process_group: torch.distributed.ProcessGroup = None,
+        layer_names: List[str] = None,
+        sizes: List[int] = None,
+        layer_id: str = None,
+    ):
+        super().__init__(linear)
+        self.process_group = process_group
+        self.layer_names = layer_names
+        self.sizes = sizes
+        self.layer_id = layer_id
+
     @classmethod
     def load_gate_up(cls, config, prefix: str, weights, bias: bool):
         """Specific method when the QKV was joined after the fact"""
@@ -171,7 +186,18 @@ class TensorParallelColumnLinear(SuperLayer):
         return cls(linear)

     @classmethod
-    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
+    def load_multi(
+        cls,
+        config,
+        prefixes: List[str],
+        weights,
+        bias: bool,
+        dim: int,
+        sizes: List[int] = None,
+        layer_id: str = None,
+    ):
+        # infer layer_names from prefixes
+        layer_names = [prefix.split(".")[-1] for prefix in prefixes]
         if config.quantize == "exl2":
             linears = []
             for prefix in prefixes:
@@ -187,17 +213,75 @@ class TensorParallelColumnLinear(SuperLayer):
         else:
             bias = None
         linear = get_linear(weight, bias, config.quantize)
-        return cls(linear)
+        return cls(
+            linear,
+            process_group=weights.process_group,
+            layer_names=layer_names,
+            sizes=sizes,
+            layer_id=layer_id,
+        )
+
+    def forward(self, input: torch.Tensor, adapter_data=None) -> torch.Tensor:
+        result = super().forward(input)
+
+        # noop if no lora data is provided
+        if adapter_data is None:
+            return result
+
+        # handle models like Bloom that have inputs of shape
+        # (batch_size, sequence_length, hidden_size)
+        # we need to reshape them to (batch_size * sequence_length, hidden_size)
+        # for the LoRA computation, then reshape back
+        prev_shape = result.shape
+        is_3d = len(input.shape) >= 3
+        if is_3d:
+            input = input.reshape(-1, input.shape[-1])
+            result = result.reshape(-1, result.shape[-1])
+
+        offset = 0
+        for i, layer_name in enumerate(self.layer_names):
+            start_idx = offset // self.process_group.size()
+            # The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
+            # projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
+            # ensures correct slicing of the result tensor, accommodating variations like grouped-query
+            # attention where k_proj and v_proj differ from q_proj. This allows precise application of
+            # LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
+            # different projection sizes across layers and model architectures.
+            if self.sizes is not None:
+                offset += self.sizes[i]
+                end_idx = offset // self.process_group.size()
+            else:
+                end_idx = result.shape[1]
+
+            result = forward_layer_type(
+                process_group=self.process_group,
+                layer_id=self.layer_id,
+                result=result,
+                input=input,
+                adapter_data=adapter_data,
+                layer_type=layer_name,
+                start_idx=start_idx,
+                end_idx=end_idx,
+                use_all_gather=True,
+            )
+
+        if is_3d:
+            result = result.reshape(prev_shape)
+
+        return result


 class TensorParallelRowLinear(SuperLayer):
-    def __init__(self, linear, process_group):
+    def __init__(self, linear, process_group, layer_name):
         super().__init__(linear)
         self.process_group = process_group
+        self.layer_name = layer_name

     @classmethod
     def load(cls, config, prefix: str, weights, bias: bool):
         weight = weights.get_weights_row(prefix)
+        layer_name = prefix.split(".")[-1]

         if bias and weights.process_group.rank() == 0:
             # Rank is only on the first rank process
@@ -207,17 +291,42 @@ class TensorParallelRowLinear(SuperLayer):
         return cls(
             get_linear(weight, bias, config.quantize),
             process_group=weights.process_group,
+            layer_name=layer_name,
         )

-    def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
+    def forward(
+        self, input: torch.Tensor, adapter_data=None, reduce: bool = True
+    ) -> torch.Tensor:
         out = super().forward(input)
         if self.process_group.size() > 1 and reduce:
             if SYSTEM == "ipex":
                 ipex.distributed.all_reduce(out, group=self.process_group)
             else:
                 torch.distributed.all_reduce(out, group=self.process_group)

+        # noop if no lora data is provided
+        if adapter_data is None:
             return out
+
+        # Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285
+        stride = out.shape[-1] // self.process_group.size()
+        start_idx = self.process_group.rank() * stride
+        end_idx = (self.process_group.rank() + 1) * stride
+
+        res = forward_layer_type(
+            process_group=self.process_group,
+            layer_id=self.layer_name,
+            result=out,
+            input=input,
+            adapter_data=adapter_data,
+            layer_type=self.layer_name,
+            start_idx=start_idx,
+            end_idx=end_idx,
+            use_all_gather=False,
+        )
+        return res


 class TensorParallelEmbedding(torch.nn.Module):
     def __init__(self, prefix: str, weights, reduce=True):
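
For reference, a minimal sketch, assuming the usual loader objects (config, weights) and runtime tensors (hidden_states, adapter_data) of the surrounding model code, of how the extended TensorParallelColumnLinear is loaded and called after this change; q_size and kv_size stand for per-projection output widths and are placeholders:

# Sketch only: load a fused q/k/v projection with the new sizes/layer_id arguments.
qkv = TensorParallelColumnLinear.load_multi(
    config,
    prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
    dim=0,
    weights=weights,
    bias=False,
    sizes=[q_size, kv_size, kv_size],  # per-projection output dimensions
    layer_id=layer_id,
)
# Passing adapter_data=None skips the LoRA path entirely.
qkv_out = qkv(hidden_states, adapter_data)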

View File

@@ -18,8 +18,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional, Tuple
-
 import torch
 import torch.distributed
@@ -33,14 +31,11 @@ from text_generation_server.layers.attention import (
     attention,
     reshape_and_cache,
 )
-from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     SpeculativeHead,
-    TensorParallelMultiAdapterLinear,
-    TensorParallelAdapterRowLinear,
 )
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
@@ -58,8 +53,6 @@ def load_attention(config, prefix: str, weights, layer_id):
     # Only defined in granite.
     bias = getattr(config, "attention_bias", False)
     head_size = config.hidden_size // config.num_attention_heads
-    sizes = None
-    prefixes = None

     if config.model_type == "phi3":
         prefix = f"{prefix}.qkv_proj"
@@ -82,27 +75,21 @@ def load_attention(config, prefix: str, weights, layer_id):
             num_key_value_heads=config.num_key_value_heads,
         )
     else:
-        prefixes = ["q_proj", "k_proj", "v_proj"]
-        sizes = [
-            head_size * config.num_attention_heads,
-            head_size * config.num_key_value_heads,
-            head_size * config.num_key_value_heads,
-        ]
         base_layer = TensorParallelColumnLinear.load_multi(
             config,
             prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
             dim=0,
             weights=weights,
             bias=bias,
-        )
-
-    return TensorParallelMultiAdapterLinear.load(
-        base_layer=base_layer,
+            sizes=[
+                head_size * config.num_attention_heads,
+                head_size * config.num_key_value_heads,
+                head_size * config.num_key_value_heads,
+            ],
             layer_id=layer_id,
-        layer_names=prefixes,
-        sizes=sizes,
-        process_group=weights.process_group,
         )
+    return base_layer


 class FlashLlamaAttention(torch.nn.Module):
@@ -150,20 +137,13 @@ class FlashLlamaAttention(torch.nn.Module):
         self.query_key_value = load_attention(config, prefix, weights, index)
         self.index = index

-        o_proj = TensorParallelRowLinear.load(
+        self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
             bias=False,
         )
-        self.o_proj = TensorParallelAdapterRowLinear.load(
-            o_proj,
-            index,
-            "o_proj",
-            process_group=weights.process_group,
-        )

         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@@ -247,54 +227,37 @@ class LlamaMLP(nn.Module):
                 ),
             )
         )
-        prefixes = None
-        sizes = None

         # Fuse gate and up proj
         bias = getattr(config, "mlp_bias", False)
         if config.model_type == "phi3":
-            gate_up_proj = TensorParallelColumnLinear.load_gate_up(
+            self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
                 config,
                 prefix=f"{prefix}.gate_up_proj",
                 weights=weights,
                 bias=bias,
             )
         else:
-            prefixes = [f"gate_proj", f"up_proj"]
-            sizes = [
-                config.intermediate_size,
-                config.intermediate_size,
-            ]
-            gate_up_proj = TensorParallelColumnLinear.load_multi(
+            self.gate_up_proj = TensorParallelColumnLinear.load_multi(
                 config,
                 prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
                 weights=weights,
                 dim=0,
                 bias=bias,
+                sizes=[
+                    config.intermediate_size,
+                    config.intermediate_size,
+                ],
+                layer_id=index,
             )

-        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
-            gate_up_proj,
-            index,
-            layer_names=prefixes,
-            sizes=sizes,
-            process_group=weights.process_group,
-        )
-
-        down_proj = TensorParallelRowLinear.load(
+        self.down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
             bias=bias,
         )
-        self.down_proj = TensorParallelAdapterRowLinear.load(
-            down_proj,
-            index,
-            "down_proj",
-            process_group=weights.process_group,
-        )

         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
         )

View File

@@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
-    Seqlen,
     paged_attention,
     attention,
     reshape_and_cache,
@@ -38,9 +37,6 @@ from text_generation_server.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     SpeculativeHead,
-    get_linear,
-    TensorParallelMultiAdapterLinear,
-    TensorParallelAdapterRowLinear,
 )
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
@@ -138,39 +134,26 @@ class MistralAttention(torch.nn.Module):
             config.num_key_value_heads // weights.process_group.size()
         )

-        query_key_value = TensorParallelColumnLinear.load_multi(
+        head_size = config.hidden_size // config.num_attention_heads
+        self.query_key_value = TensorParallelColumnLinear.load_multi(
             config,
             prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
             dim=0,
             weights=weights,
             bias=False,
-        )
-
-        head_size = config.hidden_size // config.num_attention_heads
-        self.query_key_value = TensorParallelMultiAdapterLinear.load(
-            query_key_value,
-            layer_id,
-            ["q_proj", "k_proj", "v_proj"],
             sizes=[
                 head_size * config.num_attention_heads,
                 head_size * config.num_key_value_heads,
                 head_size * config.num_key_value_heads,
             ],
-            process_group=weights.process_group,
+            layer_id=layer_id,
         )
-
-        o_proj = TensorParallelRowLinear.load(
+        self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
             bias=False,
         )
-        self.o_proj = TensorParallelAdapterRowLinear.load(
-            o_proj,
-            layer_id,
-            "o_proj",
-            process_group=weights.process_group,
-        )

         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@@ -264,37 +247,24 @@
             )
         )
         # Fuse gate and up proj
-        gate_up_proj = TensorParallelColumnLinear.load_multi(
+        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
             config,
             prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
             weights=weights,
             dim=0,
             bias=False,
-        )
-        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
-            gate_up_proj,
-            layer_id,
-            ["gate_proj", "up_proj"],
             sizes=[
                 config.intermediate_size,
                 config.intermediate_size,
             ],
-            process_group=weights.process_group,
+            layer_id=layer_id,
         )
-
-        down_proj = TensorParallelRowLinear.load(
+        self.down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
             bias=False,
         )
-        self.down_proj = TensorParallelAdapterRowLinear.load(
-            down_proj,
-            layer_id,
-            "down_proj",
-            process_group=weights.process_group,
-        )

         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
         )