From 91f407226d8fff463686b5c14437ee2d955e146f Mon Sep 17 00:00:00 2001
From: drbh
Date: Fri, 7 Jun 2024 02:21:06 +0000
Subject: [PATCH] feat: support if vlm models

---
 server/text_generation_server/layers/lora.py         |  2 ++
 .../text_generation_server/models/flash_llama.py     | 16 +++++++++++++---
 .../models/flash_mistral.py                          | 14 ++++++++++++--
 server/text_generation_server/models/mamba.py        |  1 +
 .../models/vlm_causal_lm.py                          |  4 +++-
 5 files changed, 31 insertions(+), 6 deletions(-)

diff --git a/server/text_generation_server/layers/lora.py b/server/text_generation_server/layers/lora.py
index 30070287..7adfbb29 100644
--- a/server/text_generation_server/layers/lora.py
+++ b/server/text_generation_server/layers/lora.py
@@ -41,6 +41,8 @@ class LoraLinear(nn.Module):
         start_idx: int,
         end_idx: int,
     ) -> torch.Tensor:
+        if adapter_data is None:
+            return result
         data = adapter_data.data.get(layer_type)
         data: Optional["BatchLoraWeights"] = (
             data.get(LORA) if data is not None else None
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 327e4a6f..1e626768 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -108,7 +108,17 @@ class FlashLlama(FlashCausalLM):
         layer_weights = {}
 
         prefix = "model.layers"
-        for i, layer in enumerate(self.model.model.layers):
+
+        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
+        # that have a language_model inside of the larger model.
+        if hasattr(self.model, "language_model"):
+            _model = self.model.language_model
+        elif hasattr(self.model, "text_model"):
+            _model = self.model.text_model
+        else:
+            _model = self.model
+
+        for i, layer in enumerate(_model.model.layers):
             layer_weights[(i, "q_proj")] = (
                 f"{prefix}.{i}.self_attn.q_proj",
                 layer.self_attn.query_key_value,
@@ -139,7 +149,7 @@ class FlashLlama(FlashCausalLM):
                 layer.mlp.down_proj,
             )
 
-        layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head)
+        layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
         return layer_weights
 
     @property
@@ -151,7 +161,7 @@ class FlashLlama(FlashCausalLM):
         return ["q_proj", "v_proj"]
 
     def get_num_layers_for_type(self, layer_type: str) -> int:
-        return 1 if layer_type == LM_HEAD else len(self.model.model.layers)
+        return 1 if layer_type == "lm_head" else len(self.model.model.layers)
 
     def is_row_parallel(self, layer_type: str) -> bool:
         return layer_type in ROW_PARALLEL
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 90a95c41..bdeeef36 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -119,7 +119,17 @@ class BaseFlashMistral(FlashCausalLM):
         layer_weights = {}
 
         prefix = "model.layers"
-        for i, layer in enumerate(self.model.model.layers):
+
+        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
+        # that have a language_model inside of the larger model.
+        if hasattr(self.model, "language_model"):
+            _model = self.model.language_model
+        elif hasattr(self.model, "text_model"):
+            _model = self.model.text_model
+        else:
+            _model = self.model
+
+        for i, layer in enumerate(_model.model.layers):
             layer_weights[(i, "q_proj")] = (
                 f"{prefix}.{i}.self_attn.q_proj",
                 layer.self_attn.query_key_value,
@@ -150,7 +160,7 @@ class BaseFlashMistral(FlashCausalLM):
                 layer.mlp.down_proj,
             )
 
-        layer_weights[(0, "lm_head")] = ("lm_head", self.model.lm_head)
+        layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
         return layer_weights
 
     @property
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index 3133a137..9189b45c 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -453,6 +453,7 @@ class Mamba(Model):
         model = MambaModel(config, weights)
         torch.distributed.barrier(group=self.process_group)
         super(Mamba, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 59a6fab1..b8b0f207 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -219,7 +219,9 @@ class VlmCausalLM(BaseFlashMistral):
         return VlmCausalLMBatch
 
     def forward(
-        self, batch: VlmCausalLMBatch
+        self,
+        batch: VlmCausalLMBatch,
+        adapter_data: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
         if batch.speculative_ids is not None:
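
Below is a minimal, standalone sketch (not part of the patch) of the model-resolution pattern the diff adds to flash_llama.py and flash_mistral.py: VLM wrappers such as LlavaNext or Idefics2 nest the text decoder under language_model (or text_model), while plain causal LMs expose .model.layers directly, so the adapter-layer mapping first resolves the inner model and then enumerates its layers. The helper name resolve_text_model and the SimpleNamespace stand-ins are hypothetical and used only for illustration; they are not text-generation-server APIs.

from types import SimpleNamespace


def resolve_text_model(model):
    # Mirrors the patch: prefer a nested language_model (VLMs such as
    # LlavaNext/Idefics2), then text_model, and fall back to the model itself.
    if hasattr(model, "language_model"):
        return model.language_model
    elif hasattr(model, "text_model"):
        return model.text_model
    return model


# Hypothetical stand-ins for a plain causal LM and a VLM-style wrapper.
plain_lm = SimpleNamespace(
    model=SimpleNamespace(layers=[object(), object()]),
    lm_head=object(),
)
vlm = SimpleNamespace(language_model=plain_lm)

for candidate in (plain_lm, vlm):
    _model = resolve_text_model(candidate)
    # Both cases expose the attributes the adapter mapping relies on:
    # _model.model.layers for per-layer weights and _model.lm_head for the head.
    assert len(_model.model.layers) == 2
    assert _model.lm_head is not None

The guard added to LoraLinear.forward_layer_type follows the same defensive idea: when adapter_data is None (e.g. no LoRA adapters are active for the request), the base result is returned unchanged instead of indexing into adapter weights.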