diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
index d10efb41..e08a2aad 100644
--- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
@@ -34,6 +34,11 @@ class PaliGemmaForConditionalGeneration(nn.Module):
             config=config.vision_config,
             weights=weights,
         )
+        self.post_vision_tower_layernorm = nn.LayerNorm.load(
+            prefix="vision_tower.vision_model.post_layernorm",
+            weights=weights,
+            eps=config.vision_config.layer_norm_eps,
+        )
 
         self.multi_modal_projector = TensorParallelColumnLinear.load(
             config,
@@ -84,7 +89,10 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         if pixel_values is not None:
             pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
             image_outputs = self.vision_tower(pixel_values)
-            image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
+            last_hidden_state = self.post_vision_tower_layernorm(
+                image_outputs.last_hidden_state
+            )
+            image_features = self.multi_modal_projector(last_hidden_state)
 
             # mask where image or padding tokens
             mask = input_ids == self.config.image_token_index
diff --git a/server/text_generation_server/models/custom_modeling/siglip.py b/server/text_generation_server/models/custom_modeling/siglip.py
index 480d0f9f..95ac9ede 100644
--- a/server/text_generation_server/models/custom_modeling/siglip.py
+++ b/server/text_generation_server/models/custom_modeling/siglip.py
@@ -364,7 +364,6 @@ class SiglipEncoder(nn.Module):
         inputs_embeds,
         attention_mask: Optional[torch.Tensor] = None,
     ):
-
         hidden_states = inputs_embeds
         for idx, encoder_layer in enumerate(self.layers):
             hidden_states, _ = encoder_layer(
@@ -386,20 +385,11 @@ class SiglipVisionTransformer(nn.Module):
         self.encoder = SiglipEncoder(
             prefix=f"{prefix}.encoder", config=config, weights=weights
         )
-        self.post_layernorm = nn.LayerNorm.load(
-            prefix=f"{prefix}.post_layernorm",
-            weights=weights,
-            eps=config.layer_norm_eps,
-        )
 
     def forward(
         self,
         pixel_values: Optional[torch.FloatTensor] = None,
     ):
-        r"""
-        Returns:
-
-        """
         if pixel_values is None:
             raise ValueError("You have to specify pixel_values")
 
@@ -412,10 +402,9 @@ class SiglipVisionTransformer(nn.Module):
             inputs_embeds=hidden_states,
         )
         last_hidden_state = encoder_outputs
-        post_last_hidden_state = self.post_layernorm(last_hidden_state)
 
         return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
             # pooler_output=pooled_output,
             # hidden_states=encoder_outputs,
         )