Fix: don't apply post layernorm in SiglipVisionTransformer (#2459)

* Fix: don't apply post layernorm in SiglipVisionTransformer

This fixes a bug with LLaVA Next when using Siglip as the vision model. LLaVA Next expects the output of the vision model to be the encoder outputs before layernorm (see original transformers implementation here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L813).

This also makes Siglip consistent with the existing Clip implementation:

https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/models/custom_modeling/clip.py#L613

* fix: adjust pali gemma for post layer norm and small refactors

---------

Co-authored-by: Travis Addair <tgaddair@gmail.com>
This commit is contained in:
drbh 2024-08-26 17:04:46 -04:00 committed by GitHub
parent f3c5d7d92f
commit 30be188400
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 13 deletions

View File

@ -34,6 +34,11 @@ class PaliGemmaForConditionalGeneration(nn.Module):
config=config.vision_config, config=config.vision_config,
weights=weights, weights=weights,
) )
self.post_vision_tower_layernorm = nn.LayerNorm.load(
prefix="vision_tower.vision_model.post_layernorm",
weights=weights,
eps=config.vision_config.layer_norm_eps,
)
self.multi_modal_projector = TensorParallelColumnLinear.load( self.multi_modal_projector = TensorParallelColumnLinear.load(
config, config,
@ -84,7 +89,10 @@ class PaliGemmaForConditionalGeneration(nn.Module):
if pixel_values is not None: if pixel_values is not None:
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype) pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
image_outputs = self.vision_tower(pixel_values) image_outputs = self.vision_tower(pixel_values)
image_features = self.multi_modal_projector(image_outputs.last_hidden_state) last_hidden_state = self.post_vision_tower_layernorm(
image_outputs.last_hidden_state
)
image_features = self.multi_modal_projector(last_hidden_state)
# mask where image or padding tokens # mask where image or padding tokens
mask = input_ids == self.config.image_token_index mask = input_ids == self.config.image_token_index

View File

@ -364,7 +364,6 @@ class SiglipEncoder(nn.Module):
inputs_embeds, inputs_embeds,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
): ):
hidden_states = inputs_embeds hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers): for idx, encoder_layer in enumerate(self.layers):
hidden_states, _ = encoder_layer( hidden_states, _ = encoder_layer(
@ -386,20 +385,11 @@ class SiglipVisionTransformer(nn.Module):
self.encoder = SiglipEncoder( self.encoder = SiglipEncoder(
prefix=f"{prefix}.encoder", config=config, weights=weights prefix=f"{prefix}.encoder", config=config, weights=weights
) )
self.post_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.post_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
def forward( def forward(
self, self,
pixel_values: Optional[torch.FloatTensor] = None, pixel_values: Optional[torch.FloatTensor] = None,
): ):
r"""
Returns:
"""
if pixel_values is None: if pixel_values is None:
raise ValueError("You have to specify pixel_values") raise ValueError("You have to specify pixel_values")
@ -412,10 +402,9 @@ class SiglipVisionTransformer(nn.Module):
inputs_embeds=hidden_states, inputs_embeds=hidden_states,
) )
last_hidden_state = encoder_outputs last_hidden_state = encoder_outputs
post_last_hidden_state = self.post_layernorm(last_hidden_state)
return BaseModelOutputWithPooling( return BaseModelOutputWithPooling(
last_hidden_state=post_last_hidden_state, last_hidden_state=last_hidden_state,
# pooler_output=pooled_output, # pooler_output=pooled_output,
# hidden_states=encoder_outputs, # hidden_states=encoder_outputs,
) )