From 82fc879e175a6da381f23a4ebbc27c68c93d810b Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 18 Jul 2024 19:58:55 +0000 Subject: [PATCH] feat: refactor lora linear and remove adapter layers --- .../text_generation_server/adapters/lora.py | 2 +- .../text_generation_server/layers/__init__.py | 15 +- server/text_generation_server/layers/lora.py | 408 +++++++----------- .../layers/tensor_parallel.py | 119 ++++- .../custom_modeling/flash_llama_modeling.py | 71 +-- .../custom_modeling/flash_mistral_modeling.py | 44 +- 6 files changed, 297 insertions(+), 362 deletions(-) diff --git a/server/text_generation_server/adapters/lora.py b/server/text_generation_server/adapters/lora.py index 87543be2..a5840ade 100644 --- a/server/text_generation_server/adapters/lora.py +++ b/server/text_generation_server/adapters/lora.py @@ -208,7 +208,7 @@ class LoraWeights(AdapterWeights): for layer_id in range(nlayers): key = (layer_id, layer_type) weight_name, layer = model.target_to_layer[key] - base_weight = layer.base_layer.linear.weight + base_weight = layer.linear.weight base_device = base_weight.device if weight_name not in module_map: diff --git a/server/text_generation_server/layers/__init__.py b/server/text_generation_server/layers/__init__.py index 32c8d121..f430dcff 100644 --- a/server/text_generation_server/layers/__init__.py +++ b/server/text_generation_server/layers/__init__.py @@ -13,8 +13,13 @@ from text_generation_server.layers.speculative import SpeculativeHead from text_generation_server.layers.layernorm import load_layer_norm from text_generation_server.layers.conv import load_conv2d -from text_generation_server.layers.lora import ( - LoraLinear, - TensorParallelMultiAdapterLinear, - TensorParallelAdapterRowLinear, -) +__all__ = [ + "TensorParallelColumnLinear", + "TensorParallelRowLinear", + "TensorParallelEmbedding", + "get_linear", + "FastLinear", + "SpeculativeHead", + "load_layer_norm", + "load_conv2d", +] diff --git a/server/text_generation_server/layers/lora.py b/server/text_generation_server/layers/lora.py index 0bb6db41..c68d5c37 100644 --- a/server/text_generation_server/layers/lora.py +++ b/server/text_generation_server/layers/lora.py @@ -1,14 +1,11 @@ -import math -import os -from typing import TYPE_CHECKING, Optional, Tuple, List +from typing import Optional import torch import torch.distributed -from accelerate import init_empty_weights -from torch import nn -from torch.nn import functional as F from torch.distributed import ProcessGroup +from text_generation_server.adapters import AdapterBatchData +from text_generation_server.adapters.lora import BatchLoraWeights from text_generation_server.utils.sgmv import ( add_lora_a_bgmv, add_lora_b_bgmv, @@ -18,269 +15,160 @@ from text_generation_server.utils.sgmv import ( orient_for_rank, ) -if TYPE_CHECKING: - from text_generation_server.adapters import AdapterBatchData - from text_generation_server.adapters.lora import BatchLoraWeights - -class LoraLinear(nn.Module): - def __init__( - self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup - ): - super().__init__() - self.base_layer = base_layer - self.layer_id = layer_id - self.process_group = process_group - - def forward_layer_type( - self, - result: torch.Tensor, - input: torch.Tensor, - adapter_data: "AdapterBatchData", - layer_type: str, - start_idx: int, - end_idx: int, - ) -> torch.Tensor: - if adapter_data is None: - return result - data = adapter_data.data.get(layer_type) - data: Optional["BatchLoraWeights"] = ( - data.get("lora") if data is not None else None - ) 
- - if has_sgmv() and data is not None and data.can_vectorize(self.process_group): - # In tensor-parallel configurations, each GPU processes a specific segment of the output. - # The 'result' tensor represents the full output, which can vary in size based on - # the layer type (e.g., attention vs. feed-forward layers). We define the current - # segment using start_idx and end_idx. If the segment size doesn't match this GPU's - # slice of 'result', we create a zero tensor of the correct size for LoRA computation. - # This approach ensures accurate LoRA application across various layer sizes and - # configurations, adapting to different model architectures and parallelization strategies. - # - # Example scenarios where this is necessary: - # 1. The adapter's size doesn't evenly divide across GPUs. - # 2. We're processing the last segment which might be smaller. - # 3. Different projection layers (q, k, v) have different sizes. - if end_idx - start_idx != result.shape[1]: - proj = torch.zeros_like(result[:, start_idx:end_idx]) - else: - proj = result - - for r, rank_segments in data.rank_data.items(): - lora_a_ptr = rank_segments.lora_a_ptr - lora_b_ptr = rank_segments.lora_b_ptr - - if lora_a_ptr is None or lora_b_ptr is None: - raise ValueError("LoRA data is missing") - - if data.use_sgmv: - # Use SGMV for prefill - v = lora_a_sgmv_cutlass( - input, - rank_segments.tmp_shrink, - lora_a_ptr, - rank_segments.segment_starts, - rank_segments.segment_ends, - self.layer_id, - r, - ) - - if self.process_group.size() > 1: - v = self.collect_lora_a(v) - - lora_b_sgmv_cutlass( - proj, - v, - rank_segments.tmp_expand, - lora_b_ptr, - rank_segments.segment_starts, - rank_segments.segment_ends, - self.layer_id, - ) - else: - # Use BGMV for decode - v = torch.zeros( - (input.size(0), r), dtype=input.dtype, device=input.device - ) - # TODO: error with [-1, 0], but not [0, -1] - add_lora_a_bgmv( - v, - input, - lora_a_ptr, - rank_segments.indices, - self.layer_id, - ) - - if self.process_group.size() > 1: - v = self.collect_lora_a(v) - - add_lora_b_bgmv( - proj, - v, - lora_b_ptr, - rank_segments.indices, - self.layer_id, - ) - - if end_idx - start_idx != result.shape[1]: - result[:, start_idx:end_idx] += proj - else: - for adapter_index in adapter_data.meta.adapter_set: - if data is not None and data.has_adapter(adapter_index): - adapter_mask = ( - (adapter_data.meta.adapter_indices == adapter_index) - .to(input.dtype) - .view(-1, 1) - ) - layer_result = self.forward_lora( - input, data, adapter_index, adapter_mask - ) - result[:, start_idx:end_idx] += layer_result - - return result - - def forward_lora( - self, - input: torch.Tensor, - data: "BatchLoraWeights", - adapter_index: int, - adapter_mask: torch.Tensor, - ) -> torch.Tensor: - lora_a = data.lora_a[adapter_index][self.layer_id, :, :] - lora_b = data.lora_b[adapter_index][self.layer_id, :, :] - - lora_a = orient_for_rank(lora_a, lora_b.size(0)) - - a_out = input @ lora_a - if self.process_group.size() > 1: - a_out = self.collect_lora_a(a_out) - - result = (a_out @ lora_b) * adapter_mask - return result - - def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: - raise NotImplementedError("Implemented in subclasses") - - -class TensorParallelMultiAdapterLinear(LoraLinear): - def __init__( - self, - base_layer: nn.Module, - layer_id: int, - layer_names: List[str], - sizes: List[int], - process_group: ProcessGroup, - ): - super().__init__(base_layer, layer_id, process_group) - self.layer_names = layer_names - self.sizes = sizes - - 
@classmethod
-    def load(
-        cls,
-        base_layer: nn.Module,
-        layer_id: int,
-        layer_names: List[str],
-        sizes: List[int],
-        process_group: ProcessGroup,
-    ):
-        return TensorParallelMultiAdapterLinear(
-            base_layer, layer_id, layer_names, sizes, process_group
-        )
-
-    def forward(
-        self, input: torch.Tensor, adapter_data: "AdapterBatchData"
-    ) -> torch.Tensor:
-        result = self.base_layer(input)
-
-        # noop if no layer names are provided (e.g. for models without adapters)
-        if self.layer_names is None:
-            return result
-
-        # handle models like Bloom that have inputs of shape
-        # (batch_size, sequence_length, hidden_size)
-        # we need to reshape them to (batch_size * sequence_length, hidden_size)
-        # for the LoRA computation, then reshape back
-        prev_shape = result.shape
-        is_3d = len(input.shape) >= 3
-        if is_3d:
-            input = input.reshape(-1, input.shape[-1])
-            result = result.reshape(-1, result.shape[-1])
-
-        offset = 0
-        for i, layer_name in enumerate(self.layer_names):
-            start_idx = offset // self.process_group.size()
-            # The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
-            # projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
-            # ensures correct slicing of the result tensor, accommodating variations like grouped-query
-            # attention where k_proj and v_proj differ from q_proj. This allows precise application of
-            # LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
-            # different projection sizes across layers and model architectures.
-            if self.sizes is not None:
-                offset += self.sizes[i]
-                end_idx = offset // self.process_group.size()
-            else:
-                end_idx = result.shape[1]
-
-            result = self.forward_layer_type(
-                result, input, adapter_data, layer_name, start_idx, end_idx
-            )
-
-        if is_3d:
-            result = result.reshape(prev_shape)
-
-        return result
-
-    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
+def gather_lora_weights(
+    process_group: ProcessGroup,
+    weights: torch.Tensor,
+    use_all_gather: bool = False,
+) -> torch.Tensor:
+    if use_all_gather:
         # Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise.
         # We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks.
# - # TODO(travis): this is not very efficient as we do an all-gather for every adapter, - # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same - # rank, compute `a_out` on each, and then slice them into the buffer as shown here: - # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609 + # TODO: this is not very efficient as we do an all-gather for every adapter, + # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same + # rank, compute `a_out` on each, and then slice them into the buffer as shown here: + # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609 gathered_tensors = [ - torch.empty_like(a_out) for _ in range(self.process_group.size()) + torch.empty_like(weights) for _ in range(process_group.size()) ] - torch.distributed.all_gather(gathered_tensors, a_out) + torch.distributed.all_gather(gathered_tensors, weights, group=process_group) return torch.cat(gathered_tensors, dim=1) + else: + torch.distributed.all_reduce(weights, group=process_group) + return weights -class TensorParallelAdapterRowLinear(LoraLinear): - def __init__(self, base_layer, layer_id, layer_name, process_group): - super().__init__(base_layer, layer_id, process_group) - self.layer_name = layer_name - - @classmethod - def load(cls, base_layer, layer_id, layer_name, process_group): - return cls(base_layer, layer_id, layer_name, process_group) - - def forward( - self, input: torch.Tensor, adapter_data: "AdapterBatchData" - ) -> torch.Tensor: - result = self.base_layer(input) - - if self.layer_name is None: - return result - - # Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285 - stride = result.shape[-1] // self.process_group.size() - start_idx = self.process_group.rank() * stride - end_idx = (self.process_group.rank() + 1) * stride - - self.forward_layer_type( - result, input, adapter_data, self.layer_name, start_idx, end_idx - ) - +def forward_layer_type( + process_group: ProcessGroup, + layer_id: int, + result: torch.Tensor, + input: torch.Tensor, + adapter_data: "AdapterBatchData", + layer_type: str, + start_idx: int, + end_idx: int, + use_all_gather: bool = False, +) -> torch.Tensor: + if adapter_data is None: return result + data = adapter_data.data.get(layer_type) + data: Optional["BatchLoraWeights"] = data.get("lora") if data is not None else None - def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: - # Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise. - # We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks. + if has_sgmv() and data is not None and data.can_vectorize(process_group): + # In tensor-parallel configurations, each GPU processes a specific segment of the output. + # The 'result' tensor represents the full output, which can vary in size based on + # the layer type (e.g., attention vs. feed-forward layers). We define the current + # segment using start_idx and end_idx. If the segment size doesn't match this GPU's + # slice of 'result', we create a zero tensor of the correct size for LoRA computation. + # This approach ensures accurate LoRA application across various layer sizes and + # configurations, adapting to different model architectures and parallelization strategies. 
# - # TODO(travis): this is not very efficient as we do an all-reduce for every adapter, - # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same - # rank, compute `a_out` on each, and then slice them into the buffer as shown here: - # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609 - torch.distributed.all_reduce(a_out, group=self.process_group) - return a_out + # Example scenarios where this is necessary: + # 1. The adapter's size doesn't evenly divide across GPUs. + # 2. We're processing the last segment which might be smaller. + # 3. Different projection layers (q, k, v) have different sizes. + if end_idx - start_idx != result.shape[1]: + proj = torch.zeros_like(result[:, start_idx:end_idx]) + else: + proj = result + + for r, rank_segments in data.rank_data.items(): + lora_a_ptr = rank_segments.lora_a_ptr + lora_b_ptr = rank_segments.lora_b_ptr + + if lora_a_ptr is None or lora_b_ptr is None: + raise ValueError("LoRA data is missing") + + if data.use_sgmv: + # Use SGMV for prefill + v = lora_a_sgmv_cutlass( + input, + rank_segments.tmp_shrink, + lora_a_ptr, + rank_segments.segment_starts, + rank_segments.segment_ends, + layer_id, + r, + ) + + if process_group.size() > 1: + v = gather_lora_weights(process_group, v, use_all_gather) + + lora_b_sgmv_cutlass( + proj, + v, + rank_segments.tmp_expand, + lora_b_ptr, + rank_segments.segment_starts, + rank_segments.segment_ends, + layer_id, + ) + else: + # Use BGMV for decode + v = torch.zeros( + (input.size(0), r), dtype=input.dtype, device=input.device + ) + # TODO: error with [-1, 0], but not [0, -1] + add_lora_a_bgmv( + v, + input, + lora_a_ptr, + rank_segments.indices, + layer_id, + ) + + if process_group.size() > 1: + v = gather_lora_weights(process_group, v, use_all_gather) + + add_lora_b_bgmv( + proj, + v, + lora_b_ptr, + rank_segments.indices, + layer_id, + ) + + if end_idx - start_idx != result.shape[1]: + result[:, start_idx:end_idx] += proj + else: + for adapter_index in adapter_data.meta.adapter_set: + if data is not None and data.has_adapter(adapter_index): + adapter_mask = ( + (adapter_data.meta.adapter_indices == adapter_index) + .to(input.dtype) + .view(-1, 1) + ) + layer_result = forward_lora( + process_group=process_group, + layer_id=layer_id, + input=input, + data=data, + adapter_index=adapter_index, + adapter_mask=adapter_mask, + use_all_gather=use_all_gather, + ) + result[:, start_idx:end_idx] += layer_result + return result + + +def forward_lora( + process_group: ProcessGroup, + layer_id, + input: torch.Tensor, + data: "BatchLoraWeights", + adapter_index: int, + adapter_mask: torch.Tensor, + use_all_gather: bool = False, +) -> torch.Tensor: + lora_a = data.lora_a[adapter_index][layer_id, :, :] + lora_b = data.lora_b[adapter_index][layer_id, :, :] + + lora_a = orient_for_rank(lora_a, lora_b.size(0)) + + a_out = input @ lora_a + if process_group.size() > 1: + a_out = gather_lora_weights(process_group, a_out, use_all_gather) + + result = (a_out @ lora_b) * adapter_mask + return result diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py index 011f105b..9b641d11 100644 --- a/server/text_generation_server/layers/tensor_parallel.py +++ b/server/text_generation_server/layers/tensor_parallel.py @@ -4,6 +4,7 @@ from typing import Iterable, List from text_generation_server.layers.linear import get_linear, FastLinear from text_generation_server.layers.exl2 import Exl2Weight from 
text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.layers.lora import forward_layer_type
 
 if SYSTEM == "ipex":
     import intel_extension_for_pytorch as ipex
@@ -126,6 +127,20 @@ class TensorParallelHead(SuperLayer):
 
 
 class TensorParallelColumnLinear(SuperLayer):
+    def __init__(
+        self,
+        linear,
+        process_group: torch.distributed.ProcessGroup = None,
+        layer_names: List[str] = None,
+        sizes: List[int] = None,
+        layer_id: int = None,
+    ):
+        super().__init__(linear)
+        self.process_group = process_group
+        self.layer_names = layer_names
+        self.sizes = sizes
+        self.layer_id = layer_id
+
     @classmethod
     def load_gate_up(cls, config, prefix: str, weights, bias: bool):
         """Specific method when the QKV was joined after the fact"""
@@ -171,7 +186,18 @@ class TensorParallelColumnLinear(SuperLayer):
         return cls(linear)
 
     @classmethod
-    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
+    def load_multi(
+        cls,
+        config,
+        prefixes: List[str],
+        weights,
+        bias: bool,
+        dim: int,
+        sizes: List[int] = None,
+        layer_id: int = None,
+    ):
+        # infer layer_names from prefixes
+        layer_names = [prefix.split(".")[-1] for prefix in prefixes]
         if config.quantize == "exl2":
             linears = []
             for prefix in prefixes:
@@ -187,17 +213,75 @@ class TensorParallelColumnLinear(SuperLayer):
         else:
             bias = None
         linear = get_linear(weight, bias, config.quantize)
-        return cls(linear)
+
+        return cls(
+            linear,
+            process_group=weights.process_group,
+            layer_names=layer_names,
+            sizes=sizes,
+            layer_id=layer_id,
+        )
+
+    def forward(self, input: torch.Tensor, adapter_data=None) -> torch.Tensor:
+        result = super().forward(input)
+
+        # noop if no lora data is provided
+        if adapter_data is None:
+            return result
+
+        # handle models like Bloom that have inputs of shape
+        # (batch_size, sequence_length, hidden_size)
+        # we need to reshape them to (batch_size * sequence_length, hidden_size)
+        # for the LoRA computation, then reshape back
+        prev_shape = result.shape
+        is_3d = len(input.shape) >= 3
+        if is_3d:
+            input = input.reshape(-1, input.shape[-1])
+            result = result.reshape(-1, result.shape[-1])
+
+        offset = 0
+        for i, layer_name in enumerate(self.layer_names):
+            start_idx = offset // self.process_group.size()
+            # The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
+            # projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
+            # ensures correct slicing of the result tensor, accommodating variations like grouped-query
+            # attention where k_proj and v_proj differ from q_proj. This allows precise application of
+            # LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
+            # different projection sizes across layers and model architectures.
+ if self.sizes is not None: + offset += self.sizes[i] + end_idx = offset // self.process_group.size() + else: + end_idx = result.shape[1] + + result = forward_layer_type( + process_group=self.process_group, + layer_id=self.layer_id, + result=result, + input=input, + adapter_data=adapter_data, + layer_type=layer_name, + start_idx=start_idx, + end_idx=end_idx, + use_all_gather=True, + ) + + if is_3d: + result = result.reshape(prev_shape) + + return result class TensorParallelRowLinear(SuperLayer): - def __init__(self, linear, process_group): + def __init__(self, linear, process_group, layer_name): super().__init__(linear) self.process_group = process_group + self.layer_name = layer_name @classmethod def load(cls, config, prefix: str, weights, bias: bool): weight = weights.get_weights_row(prefix) + layer_name = prefix.split(".")[-1] if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process @@ -207,16 +291,41 @@ class TensorParallelRowLinear(SuperLayer): return cls( get_linear(weight, bias, config.quantize), process_group=weights.process_group, + layer_name=layer_name, ) - def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor: + def forward( + self, input: torch.Tensor, adapter_data=None, reduce: bool = True + ) -> torch.Tensor: out = super().forward(input) if self.process_group.size() > 1 and reduce: if SYSTEM == "ipex": ipex.distributed.all_reduce(out, group=self.process_group) else: torch.distributed.all_reduce(out, group=self.process_group) - return out + + # noop if no lora data is provided + if adapter_data is None: + return out + + # Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285 + stride = out.shape[-1] // self.process_group.size() + start_idx = self.process_group.rank() * stride + end_idx = (self.process_group.rank() + 1) * stride + + res = forward_layer_type( + process_group=self.process_group, + layer_id=self.layer_name, + result=out, + input=input, + adapter_data=adapter_data, + layer_type=self.layer_name, + start_idx=start_idx, + end_idx=end_idx, + use_all_gather=False, + ) + + return res class TensorParallelEmbedding(torch.nn.Module): diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 78832341..c6a394b4 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -18,8 +18,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple - import torch import torch.distributed @@ -33,14 +31,11 @@ from text_generation_server.layers.attention import ( attention, reshape_and_cache, ) -from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, SpeculativeHead, - TensorParallelMultiAdapterLinear, - TensorParallelAdapterRowLinear, ) from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( @@ -58,8 +53,6 @@ def load_attention(config, prefix: str, weights, layer_id): # Only defined in granite. 
bias = getattr(config, "attention_bias", False) head_size = config.hidden_size // config.num_attention_heads - sizes = None - prefixes = None if config.model_type == "phi3": prefix = f"{prefix}.qkv_proj" @@ -82,27 +75,21 @@ def load_attention(config, prefix: str, weights, layer_id): num_key_value_heads=config.num_key_value_heads, ) else: - prefixes = ["q_proj", "k_proj", "v_proj"] - sizes = [ - head_size * config.num_attention_heads, - head_size * config.num_key_value_heads, - head_size * config.num_key_value_heads, - ] + base_layer = TensorParallelColumnLinear.load_multi( config, prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], dim=0, weights=weights, bias=bias, + sizes=[ + head_size * config.num_attention_heads, + head_size * config.num_key_value_heads, + head_size * config.num_key_value_heads, + ], + layer_id=layer_id, ) - - return TensorParallelMultiAdapterLinear.load( - base_layer=base_layer, - layer_id=layer_id, - layer_names=prefixes, - sizes=sizes, - process_group=weights.process_group, - ) + return base_layer class FlashLlamaAttention(torch.nn.Module): @@ -150,20 +137,13 @@ class FlashLlamaAttention(torch.nn.Module): self.query_key_value = load_attention(config, prefix, weights, index) self.index = index - o_proj = TensorParallelRowLinear.load( + self.o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", weights=weights, bias=False, ) - self.o_proj = TensorParallelAdapterRowLinear.load( - o_proj, - index, - "o_proj", - process_group=weights.process_group, - ) - self.num_groups = self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device @@ -247,54 +227,37 @@ class LlamaMLP(nn.Module): ), ) ) - prefixes = None - sizes = None # Fuse gate and up proj bias = getattr(config, "mlp_bias", False) if config.model_type == "phi3": - gate_up_proj = TensorParallelColumnLinear.load_gate_up( + self.gate_up_proj = TensorParallelColumnLinear.load_gate_up( config, prefix=f"{prefix}.gate_up_proj", weights=weights, bias=bias, ) else: - prefixes = [f"gate_proj", f"up_proj"] - sizes = [ - config.intermediate_size, - config.intermediate_size, - ] - gate_up_proj = TensorParallelColumnLinear.load_multi( + self.gate_up_proj = TensorParallelColumnLinear.load_multi( config, prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], weights=weights, dim=0, bias=bias, + sizes=[ + config.intermediate_size, + config.intermediate_size, + ], + layer_id=index, ) - self.gate_up_proj = TensorParallelMultiAdapterLinear.load( - gate_up_proj, - index, - layer_names=prefixes, - sizes=sizes, - process_group=weights.process_group, - ) - - down_proj = TensorParallelRowLinear.load( + self.down_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.down_proj", weights=weights, bias=bias, ) - self.down_proj = TensorParallelAdapterRowLinear.load( - down_proj, - index, - "down_proj", - process_group=weights.process_group, - ) - self.intermediate_size = ( config.intermediate_size // weights.process_group.size() ) diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 8028dbe8..6bc7a073 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -28,7 +28,6 @@ from typing import Optional, List, Tuple from text_generation_server.utils.import_utils import 
SYSTEM from text_generation_server.layers.attention import ( - Seqlen, paged_attention, attention, reshape_and_cache, @@ -38,9 +37,6 @@ from text_generation_server.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, SpeculativeHead, - get_linear, - TensorParallelMultiAdapterLinear, - TensorParallelAdapterRowLinear, ) from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( @@ -138,39 +134,26 @@ class MistralAttention(torch.nn.Module): config.num_key_value_heads // weights.process_group.size() ) - query_key_value = TensorParallelColumnLinear.load_multi( + head_size = config.hidden_size // config.num_attention_heads + self.query_key_value = TensorParallelColumnLinear.load_multi( config, prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], dim=0, weights=weights, bias=False, - ) - - head_size = config.hidden_size // config.num_attention_heads - self.query_key_value = TensorParallelMultiAdapterLinear.load( - query_key_value, - layer_id, - ["q_proj", "k_proj", "v_proj"], sizes=[ head_size * config.num_attention_heads, head_size * config.num_key_value_heads, head_size * config.num_key_value_heads, ], - process_group=weights.process_group, + layer_id=layer_id, ) - - o_proj = TensorParallelRowLinear.load( + self.o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", weights=weights, bias=False, ) - self.o_proj = TensorParallelAdapterRowLinear.load( - o_proj, - layer_id, - "o_proj", - process_group=weights.process_group, - ) self.num_groups = self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device @@ -264,37 +247,24 @@ class MistralMLP(nn.Module): ) ) # Fuse gate and up proj - gate_up_proj = TensorParallelColumnLinear.load_multi( + self.gate_up_proj = TensorParallelColumnLinear.load_multi( config, prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], weights=weights, dim=0, bias=False, - ) - self.gate_up_proj = TensorParallelMultiAdapterLinear.load( - gate_up_proj, - layer_id, - ["gate_proj", "up_proj"], sizes=[ config.intermediate_size, config.intermediate_size, ], - process_group=weights.process_group, + layer_id=layer_id, ) - - down_proj = TensorParallelRowLinear.load( + self.down_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.down_proj", weights=weights, bias=False, ) - - self.down_proj = TensorParallelAdapterRowLinear.load( - down_proj, - layer_id, - "down_proj", - process_group=weights.process_group, - ) self.intermediate_size = ( config.intermediate_size // weights.process_group.size() )
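
For reference, the per-rank slice arithmetic that the new TensorParallelColumnLinear.forward derives from `sizes` (the start_idx/end_idx logic described in the comments above) can be reproduced in isolation. This is a minimal sketch only: the helper name `lora_segments`, the head counts, the head size, and the world size are made-up illustration values, not anything taken from this patch.

# Recompute the (start, end) column ranges that each fused sub-projection
# (q_proj / k_proj / v_proj) occupies on one tensor-parallel rank.
def lora_segments(sizes, world_size):
    segments = []
    offset = 0
    for size in sizes:
        start_idx = offset // world_size
        offset += size
        end_idx = offset // world_size
        segments.append((start_idx, end_idx))
    return segments

head_size = 128
num_heads = 32      # q_proj heads (example value)
num_kv_heads = 8    # k_proj / v_proj heads, grouped-query attention (example value)
sizes = [
    head_size * num_heads,
    head_size * num_kv_heads,
    head_size * num_kv_heads,
]

print(lora_segments(sizes, world_size=2))
# [(0, 2048), (2048, 2560), (2560, 3072)]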
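
Likewise, when the SGMV kernels are unavailable, forward_layer_type falls back to a per-adapter masked X @ A @ B update (forward_lora). The CPU-only torch sketch below shows just that math: the shapes, LoRA weights, and per-token adapter routing are made up, and the start_idx:end_idx slicing plus the tensor-parallel gather are deliberately omitted; in the server the routing comes from `adapter_data.meta.adapter_indices`.

import torch

torch.manual_seed(0)

tokens, hidden, rank, out_dim = 6, 16, 4, 16
x = torch.randn(tokens, hidden)
base = torch.nn.Linear(hidden, out_dim, bias=False)
result = base(x)

# Two LoRA adapters; each token is routed to one adapter (-1 means "no adapter").
adapter_indices = torch.tensor([0, 0, 1, 1, -1, 0])
lora_a = {i: torch.randn(hidden, rank) * 0.01 for i in (0, 1)}
lora_b = {i: torch.randn(rank, out_dim) * 0.01 for i in (0, 1)}

for adapter_index in (0, 1):
    # Row mask selecting the tokens that belong to this adapter.
    adapter_mask = (adapter_indices == adapter_index).to(x.dtype).view(-1, 1)
    # Same masked X @ A @ B correction as forward_lora, added to the base output.
    result = result + ((x @ lora_a[adapter_index]) @ lora_b[adapter_index]) * adapter_mask

print(result.shape)  # torch.Size([6, 16])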