From de56a81c5c8e774e2a4022a304679bb09f6c1010 Mon Sep 17 00:00:00 2001
From: drbh
Date: Thu, 6 Jun 2024 22:44:58 +0000
Subject: [PATCH] feat: add lora support to mistral and refactors

---
 server/text_generation_server/cli.py          |  9 +-
 .../custom_modeling/flash_mistral_modeling.py | 93 ++++++++++++++-----
 .../models/flash_mistral.py                   | 84 ++++++++++++++++-
 .../models/flash_mixtral.py                   |  1 -
 4 files changed, 160 insertions(+), 27 deletions(-)

diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 87721097..45c2fab9 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -168,9 +168,12 @@ def download_weights(
     except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
         pass
     else:
-        utils.peft.download_peft(
-            model_id, revision, trust_remote_code=trust_remote_code
-        )
+        try:
+            utils.peft.download_peft(
+                model_id, revision, trust_remote_code=trust_remote_code
+            )
+        except Exception:
+            pass
 
     try:
         import json
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 77a8a384..d1ba5564 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -38,6 +38,8 @@ from text_generation_server.layers import (
     TensorParallelEmbedding,
     SpeculativeHead,
     get_linear,
+    TensorParallelMultiAdapterLinear,
+    TensorParallelAdapterRowLinear,
 )
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
@@ -107,12 +109,7 @@ class MistralConfig(PretrainedConfig):
 
 class MistralAttention(torch.nn.Module):
-    def __init__(
-        self,
-        prefix: str,
-        config,
-        weights,
-    ):
+    def __init__(self, prefix: str, config, weights, layer_id):
         super().__init__()
         self.max_past = (
             config.sliding_window if config.sliding_window is not None else -1
         )
@@ -140,7 +137,7 @@ class MistralAttention(torch.nn.Module):
             config.num_key_value_heads // weights.process_group.size()
         )
 
-        self.query_key_value = TensorParallelColumnLinear.load_multi(
+        query_key_value = TensorParallelColumnLinear.load_multi(
             config,
             prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
             dim=0,
             weights=weights,
             bias=False,
         )
@@ -148,12 +145,31 @@ class MistralAttention(torch.nn.Module):
-        self.o_proj = TensorParallelRowLinear.load(
+        head_size = config.hidden_size // config.num_attention_heads
+        self.query_key_value = TensorParallelMultiAdapterLinear.load(
+            query_key_value,
+            layer_id,
+            ["q_proj", "k_proj", "v_proj"],
+            sizes=[
+                head_size * config.num_attention_heads,
+                head_size * config.num_key_value_heads,
+                head_size * config.num_key_value_heads,
+            ],
+            process_group=weights.process_group,
+        )
+
+        o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
             bias=False,
         )
+        self.o_proj = TensorParallelAdapterRowLinear.load(
+            o_proj,
+            layer_id,
+            "o_proj",
+            process_group=weights.process_group,
+        )
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@@ -171,8 +187,9 @@ class MistralAttention(torch.nn.Module):
         input_lengths,
         max_s,
         prefill_cache_indices,
+        adapter_data,
     ):
-        qkv = self.query_key_value(hidden_states)
+        qkv = self.query_key_value(hidden_states, adapter_data)
         query, kv = qkv.split(
             [
                 self.head_size * self.num_heads,
@@ -224,11 +241,13 @@ class MistralAttention(torch.nn.Module):
             max_s,
         )
 
-        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+        return self.o_proj(
+            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+        )
 
 
 class MistralMLP(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, layer_id):
         super().__init__()
         self.hidden_act = config.hidden_act
         self.act = (
@@ -244,19 +263,37 @@ class MistralMLP(nn.Module):
             )
         )
         # Fuse gate and up proj
-        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+        gate_up_proj = TensorParallelColumnLinear.load_multi(
             config,
             prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
             weights=weights,
             dim=0,
             bias=False,
         )
-        self.down_proj = TensorParallelRowLinear.load(
+        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+            gate_up_proj,
+            layer_id,
+            ["gate_proj", "up_proj"],
+            sizes=[
+                config.intermediate_size,
+                config.intermediate_size,
+            ],
+            process_group=weights.process_group,
+        )
+
+        down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
             bias=False,
         )
+
+        self.down_proj = TensorParallelAdapterRowLinear.load(
+            down_proj,
+            layer_id,
+            "down_proj",
+            process_group=weights.process_group,
+        )
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
         )
@@ -264,7 +301,7 @@ class MistralMLP(nn.Module):
         # TODO: This is a hotfix to be removed & properly refactored.
         self.quantize = config.quantize
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, adapter_data):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
@@ -278,20 +315,27 @@ class MistralMLP(nn.Module):
                 device="cuda",
             )
             _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
-            return self.down_proj(out)
+            return self.down_proj(out, adapter_data)
         else:
-            gate_up_states = self.gate_up_proj(hidden_states)
+            gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
             gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-            return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+            return self.down_proj(
+                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+            )
 
 
 class MistralLayer(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, layer_id):
         super().__init__()
         self.self_attn = MistralAttention(
-            prefix=f"{prefix}.self_attn", config=config, weights=weights
+            prefix=f"{prefix}.self_attn",
+            config=config,
+            weights=weights,
+            layer_id=layer_id,
+        )
+        self.mlp = MistralMLP(
+            prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
         )
-        self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
 
         self.input_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@@ -315,6 +359,7 @@ class MistralLayer(nn.Module):
         input_lengths,
         max_s,
         prefill_cache_indices,
+        adapter_data,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -330,6 +375,7 @@ class MistralLayer(nn.Module):
             input_lengths,
             max_s,
             prefill_cache_indices,
+            adapter_data,
         )
 
         # faster post attention rms norm
@@ -337,7 +383,7 @@ class MistralLayer(nn.Module):
             attn_output, res
         )
 
-        mlp_output = self.mlp(normed_attn_res_output)
+        mlp_output = self.mlp(normed_attn_res_output, adapter_data)
 
         return mlp_output, attn_res
@@ -355,6 +401,7 @@ class MistralModel(torch.nn.Module):
                     prefix=f"{prefix}.layers.{layer_id}",
                     config=config,
                     weights=weights,
+                    layer_id=layer_id,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
@@ -381,6 +428,7 @@ class MistralModel(torch.nn.Module):
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
+        adapter_data: Optional[torch.Tensor] = None,
     ):
         hidden_states = inputs_embeds
         # Get rotary cos and sin for this forward
@@ -403,6 +451,7 @@ class MistralModel(torch.nn.Module):
                 input_lengths,
                 max_s,
                 prefill_cache_indices,
+                adapter_data,
             )
 
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -454,6 +503,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         true_max_s = max_s
         if prefill_cache_indices is not None:
@@ -476,6 +526,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
             max_s,
             true_max_s,
             prefill_cache_indices,
+            adapter_data,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index e66a1c3d..37cc0235 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -3,7 +3,7 @@ import torch.distributed
 
 from opentelemetry import trace
 from transformers import AutoTokenizer, AutoConfig
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Dict, List
 
 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.flash_causal_lm import set_sliding_window
@@ -21,6 +21,31 @@ from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
+Q_PROJ = "q_proj"
+K_PROJ = "k_proj"
+V_PROJ = "v_proj"
+O_PROJ = "o_proj"
+
+GATE_PROJ = "gate_proj"
+UP_PROJ = "up_proj"
+DOWN_PROJ = "down_proj"
+
+LM_HEAD = "lm_head"
+
+
+# TODO(travis): re-enable LM_HEAD after resolving issues with outputs
+ADAPTER_LAYERS = [
+    Q_PROJ,
+    K_PROJ,
+    V_PROJ,
+    O_PROJ,
+    GATE_PROJ,
+    UP_PROJ,
+    DOWN_PROJ,
+]  # LM_HEAD
+ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD}
+
+
 class BaseFlashMistral(FlashCausalLM):
     def __init__(
         self,
@@ -99,6 +124,62 @@ class BaseFlashMistral(FlashCausalLM):
             model.model.head_size,
         )
 
+    @property
+    def supports_adapter_loading(self) -> bool:
+        return True
+
+    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
+        layer_weights = {}
+
+        prefix = "model.layers"
+        for i, layer in enumerate(self.model.model.layers):
+            layer_weights[(i, Q_PROJ)] = (
+                f"{prefix}.{i}.self_attn.q_proj",
+                layer.self_attn.query_key_value,
+            )
+            layer_weights[(i, K_PROJ)] = (
+                f"{prefix}.{i}.self_attn.k_proj",
+                layer.self_attn.query_key_value,
+            )
+            layer_weights[(i, V_PROJ)] = (
+                f"{prefix}.{i}.self_attn.v_proj",
+                layer.self_attn.query_key_value,
+            )
+            layer_weights[(i, O_PROJ)] = (
+                f"{prefix}.{i}.self_attn.o_proj",
+                layer.self_attn.o_proj,
+            )
+
+            layer_weights[(i, GATE_PROJ)] = (
+                f"{prefix}.{i}.mlp.gate_proj",
+                layer.mlp.gate_up_proj,
+            )
+            layer_weights[(i, UP_PROJ)] = (
+                f"{prefix}.{i}.mlp.up_proj",
+                layer.mlp.gate_up_proj,
+            )
+            layer_weights[(i, DOWN_PROJ)] = (
+                f"{prefix}.{i}.mlp.down_proj",
+                layer.mlp.down_proj,
+            )
+
+        layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head)
+        return layer_weights
+
+    @property
+    def adapter_layers(self) -> List[str]:
+        return ADAPTER_LAYERS
+
+    @property
+    def default_traced_adapter_layers(self) -> List[str]:
+        return [Q_PROJ, V_PROJ]
+
+    def get_num_layers_for_type(self, layer_type: str) -> int:
+        return 1 if layer_type == LM_HEAD else len(self.model.model.layers)
+
+    def is_row_parallel(self, layer_type: str) -> bool:
+        return layer_type in ROW_PARALLEL
+
 
 class FlashMistral(BaseFlashMistral):
     def __init__(
         self,
@@ -111,7 +192,6 @@ class FlashMistral(BaseFlashMistral):
         trust_remote_code: bool = False,
     ):
         super(FlashMistral, self).__init__(
-            model_id=model_id,
             config_cls=MistralConfig,
             model_cls=FlashMistralForCausalLM,
             model_id=model_id,
diff --git a/server/text_generation_server/models/flash_mixtral.py b/server/text_generation_server/models/flash_mixtral.py
index 216279a8..587d423f 100644
--- a/server/text_generation_server/models/flash_mixtral.py
+++ b/server/text_generation_server/models/flash_mixtral.py
@@ -20,7 +20,6 @@ class FlashMixtral(BaseFlashMistral):
         trust_remote_code: bool = False,
     ):
         super(FlashMixtral, self).__init__(
-            model_id=model_id,
             config_cls=MixtralConfig,
             model_cls=FlashMixtralForCausalLM,
             model_id=model_id,
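
For readers unfamiliar with the pattern the patch relies on, here is a minimal, self-contained sketch of the "base weight plus LoRA delta" composition that the adapter-aware linear wrappers above implement. The names below (SimpleAdapterLinear, rank, scaling) are illustrative only and not part of the patch; the real TensorParallelMultiAdapterLinear / TensorParallelAdapterRowLinear additionally handle tensor parallelism, splitting of the fused q/k/v output via `sizes`, and per-request adapter selection through the `adapter_data` argument that the patch threads through every forward call.

import torch
import torch.nn as nn


class SimpleAdapterLinear(nn.Module):
    # Illustrative stand-in for an adapter-aware linear layer:
    # output = base(x) + scaling * B(A(x)) when an adapter is active.
    def __init__(self, base: nn.Linear, rank: int = 8, scaling: float = 1.0):
        super().__init__()
        self.base = base
        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.lora_b.weight)  # adapter starts as a no-op
        self.scaling = scaling

    def forward(self, hidden_states: torch.Tensor, adapter_data=None) -> torch.Tensor:
        out = self.base(hidden_states)
        if adapter_data is not None:  # an adapter was requested for this batch
            out = out + self.scaling * self.lora_b(self.lora_a(hidden_states))
        return out


if __name__ == "__main__":
    layer = SimpleAdapterLinear(nn.Linear(16, 32, bias=False))
    x = torch.randn(4, 16)
    # With lora_b zero-initialized, the adapter path contributes nothing yet,
    # so both calls produce identical outputs.
    assert torch.allclose(layer(x, adapter_data=None), layer(x, adapter_data={}))

Passing `adapter_data` down through MistralAttention, MistralMLP, and MistralLayer, as the patch does, is what lets each wrapped projection decide per batch whether (and with which adapter) to apply such a delta.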