From b1169273fdab165b976d62cf25a5fab56fce47a3 Mon Sep 17 00:00:00 2001
From: drbh
Date: Fri, 7 Jun 2024 03:03:15 +0000
Subject: [PATCH] fix: add adapter_data param and avoid missing layers

---
 .../custom_modeling/flash_rw_modeling.py      |  1 +
 .../flash_santacoder_modeling.py              |  1 +
 .../flash_starcoder2_modeling.py              |  1 +
 .../models/flash_mistral.py                   | 27 ++++++++++---------
 .../models/idefics_causal_lm.py               |  1 +
 5 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 7d3c72a7..04d4ba51 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -672,6 +672,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(
             input_ids,
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 4fa6516e..ec43d641 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -488,6 +488,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(
             input_ids,
diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
index 37486e9d..49596372 100644
--- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
@@ -525,6 +525,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         true_max_s = max_s
         if prefill_cache_indices is not None:
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index bdeeef36..6333c67f 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -147,18 +147,21 @@ class BaseFlashMistral(FlashCausalLM):
                 layer.self_attn.o_proj,
             )
 
-            layer_weights[(i, "gate_proj")] = (
-                f"{prefix}.{i}.mlp.gate_proj",
-                layer.mlp.gate_up_proj,
-            )
-            layer_weights[(i, "up_proj")] = (
-                f"{prefix}.{i}.mlp.up_proj",
-                layer.mlp.gate_up_proj,
-            )
-            layer_weights[(i, "down_proj")] = (
-                f"{prefix}.{i}.mlp.down_proj",
-                layer.mlp.down_proj,
-            )
+            # TODO: this is a hack to avoid the gate_proj for
+            # FlashStarcoder2 that doesnt have these layers
+            if hasattr(layer.mlp, "gate_up_proj"):
+                layer_weights[(i, "gate_proj")] = (
+                    f"{prefix}.{i}.mlp.gate_proj",
+                    layer.mlp.gate_up_proj,
+                )
+                layer_weights[(i, "up_proj")] = (
+                    f"{prefix}.{i}.mlp.up_proj",
+                    layer.mlp.gate_up_proj,
+                )
+                layer_weights[(i, "down_proj")] = (
+                    f"{prefix}.{i}.mlp.down_proj",
+                    layer.mlp.down_proj,
+                )
 
         layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
         return layer_weights
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index f507d669..6c562980 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -634,6 +634,7 @@ class IdeficsCausalLM(Model):
         tokenizer.add_special_tokens({"pad_token": ""})
 
         super(IdeficsCausalLM, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,
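
Note: the sketch below is illustrative only and is not part of the patch. It shows, under simplified assumptions, the guard pattern used in the flash_mistral.py hunk: gate/up/down projections are registered as adapter targets only when the block actually defines gate_up_proj, which the FlashStarcoder2 MLP does not. SimpleBlock and collect_adapter_weights are hypothetical names, not text_generation_server APIs.

# Illustrative sketch (hypothetical names, not repository code): skip adapter
# target registration for blocks whose MLP lacks gate_up_proj.
from typing import Dict, Tuple

import torch.nn as nn


class SimpleBlock(nn.Module):
    """Toy transformer block; some architectures omit the fused gate/up projection."""

    def __init__(self, hidden: int, gated: bool = True) -> None:
        super().__init__()
        self.mlp = nn.Module()
        self.mlp.down_proj = nn.Linear(hidden, hidden)
        if gated:
            self.mlp.gate_up_proj = nn.Linear(hidden, 2 * hidden)


def collect_adapter_weights(layers, prefix: str = "model.layers"):
    # Map (layer_index, proj_name) -> (weight_name, module), mirroring the
    # hasattr check added in the flash_mistral.py hunk.
    weights: Dict[Tuple[int, str], Tuple[str, nn.Module]] = {}
    for i, layer in enumerate(layers):
        if hasattr(layer.mlp, "gate_up_proj"):
            weights[(i, "gate_proj")] = (f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj)
            weights[(i, "up_proj")] = (f"{prefix}.{i}.mlp.up_proj", layer.mlp.gate_up_proj)
            weights[(i, "down_proj")] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj)
    return weights


if __name__ == "__main__":
    blocks = [SimpleBlock(8, gated=True), SimpleBlock(8, gated=False)]
    # Only the gated block contributes entries; the ungated block is skipped,
    # which is the behavior the hasattr guard provides for FlashStarcoder2.
    print(sorted(collect_adapter_weights(blocks)))

In the patch itself the same guard sits inside the adapter layer mapping of BaseFlashMistral, presumably so that architectures reusing that code path without a gated MLP no longer hit an AttributeError on layer.mlp.gate_up_proj.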