Fix: don't apply post layernorm in SiglipVisionTransformer (#2459)
* Fix: don't apply post layernorm in SiglipVisionTransformer This fixes a bug with LLaVA Next when using Siglip as the vision model. LLaVA Next expects the output of the vision model to be the encoder outputs before layernorm (see original transformers implementation here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L813). This also makes Siglip consistent with the existing Clip implementation: https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/models/custom_modeling/clip.py#L613 * fix: adjust pali gemma for post layer norm and small refactors --------- Co-authored-by: Travis Addair <tgaddair@gmail.com>
This commit is contained in:
parent
f3c5d7d92f
commit
30be188400
|
@ -34,6 +34,11 @@ class PaliGemmaForConditionalGeneration(nn.Module):
|
||||||
config=config.vision_config,
|
config=config.vision_config,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
|
self.post_vision_tower_layernorm = nn.LayerNorm.load(
|
||||||
|
prefix="vision_tower.vision_model.post_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.vision_config.layer_norm_eps,
|
||||||
|
)
|
||||||
|
|
||||||
self.multi_modal_projector = TensorParallelColumnLinear.load(
|
self.multi_modal_projector = TensorParallelColumnLinear.load(
|
||||||
config,
|
config,
|
||||||
|
@ -84,7 +89,10 @@ class PaliGemmaForConditionalGeneration(nn.Module):
|
||||||
if pixel_values is not None:
|
if pixel_values is not None:
|
||||||
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
|
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
|
||||||
image_outputs = self.vision_tower(pixel_values)
|
image_outputs = self.vision_tower(pixel_values)
|
||||||
image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
|
last_hidden_state = self.post_vision_tower_layernorm(
|
||||||
|
image_outputs.last_hidden_state
|
||||||
|
)
|
||||||
|
image_features = self.multi_modal_projector(last_hidden_state)
|
||||||
|
|
||||||
# mask where image or padding tokens
|
# mask where image or padding tokens
|
||||||
mask = input_ids == self.config.image_token_index
|
mask = input_ids == self.config.image_token_index
|
||||||
|
|
|
@ -364,7 +364,6 @@ class SiglipEncoder(nn.Module):
|
||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
for idx, encoder_layer in enumerate(self.layers):
|
for idx, encoder_layer in enumerate(self.layers):
|
||||||
hidden_states, _ = encoder_layer(
|
hidden_states, _ = encoder_layer(
|
||||||
|
@ -386,20 +385,11 @@ class SiglipVisionTransformer(nn.Module):
|
||||||
self.encoder = SiglipEncoder(
|
self.encoder = SiglipEncoder(
|
||||||
prefix=f"{prefix}.encoder", config=config, weights=weights
|
prefix=f"{prefix}.encoder", config=config, weights=weights
|
||||||
)
|
)
|
||||||
self.post_layernorm = nn.LayerNorm.load(
|
|
||||||
prefix=f"{prefix}.post_layernorm",
|
|
||||||
weights=weights,
|
|
||||||
eps=config.layer_norm_eps,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
pixel_values: Optional[torch.FloatTensor] = None,
|
pixel_values: Optional[torch.FloatTensor] = None,
|
||||||
):
|
):
|
||||||
r"""
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
"""
|
|
||||||
if pixel_values is None:
|
if pixel_values is None:
|
||||||
raise ValueError("You have to specify pixel_values")
|
raise ValueError("You have to specify pixel_values")
|
||||||
|
|
||||||
|
@ -412,10 +402,9 @@ class SiglipVisionTransformer(nn.Module):
|
||||||
inputs_embeds=hidden_states,
|
inputs_embeds=hidden_states,
|
||||||
)
|
)
|
||||||
last_hidden_state = encoder_outputs
|
last_hidden_state = encoder_outputs
|
||||||
post_last_hidden_state = self.post_layernorm(last_hidden_state)
|
|
||||||
|
|
||||||
return BaseModelOutputWithPooling(
|
return BaseModelOutputWithPooling(
|
||||||
last_hidden_state=post_last_hidden_state,
|
last_hidden_state=last_hidden_state,
|
||||||
# pooler_output=pooled_output,
|
# pooler_output=pooled_output,
|
||||||
# hidden_states=encoder_outputs,
|
# hidden_states=encoder_outputs,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue