924 lines
40 KiB
Python
924 lines
40 KiB
Python
# coding=utf-8
|
|
# Copyright 2022 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
""" PyTorch CLIP model."""
|
|
|
|
import math
|
|
from dataclasses import dataclass
|
|
from typing import Any, Optional, Tuple, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.utils.checkpoint
|
|
from torch import nn
|
|
|
|
import tqdm
|
|
from diffusers import (
|
|
ClassifierFreeGuidanceScheduler,
|
|
DiffusionPipeline,
|
|
GlideDDIMScheduler,
|
|
GLIDESuperResUNetModel,
|
|
GLIDETextToImageUNetModel,
|
|
)
|
|
from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
|
|
from transformers.activations import ACT2FN
|
|
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
|
from transformers.modeling_utils import PreTrainedModel
|
|
from transformers.utils import (
|
|
ModelOutput,
|
|
add_start_docstrings,
|
|
add_start_docstrings_to_model_forward,
|
|
logging,
|
|
replace_return_docstrings,
|
|
)
|
|
|
|
|
|
#####################
|
|
# START OF THE CLIP MODEL COPY-PASTE (with a modified attention module)
|
|
#####################
|
|
|
|
# Module-level logger, obtained from the logging helper imported from `transformers.utils`.
logger = logging.get_logger(__name__)

# Checkpoint identifier kept for documentation purposes.
_CHECKPOINT_FOR_DOC = "fusing/glide-base"

# Known pretrained checkpoints for this model family.
CLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "fusing/glide-base",
    # See all CLIP models at https://huggingface.co/models?filter=clip
]
|
|
|
|
|
|
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
|
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
|
"""
|
|
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
|
"""
|
|
bsz, src_len = mask.size()
|
|
tgt_len = tgt_len if tgt_len is not None else src_len
|
|
|
|
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
|
|
|
inverted_mask = 1.0 - expanded_mask
|
|
|
|
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
|
|
|
|
|
# contrastive loss function, adapted from
|
|
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
|
|
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    """Cross-entropy of `logits` against the targets `0, 1, ..., len(logits) - 1` (row `i` is labelled `i`)."""
    row_count = len(logits)
    targets = torch.arange(row_count, device=logits.device)
    return nn.functional.cross_entropy(logits, targets)
|
|
|
|
|
|
def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
    """Return the mean of `contrastive_loss(similarity)` and `contrastive_loss(similarity.T)`."""
    # The first term scores the matrix as given, the second scores its transpose.
    symmetric_total = contrastive_loss(similarity) + contrastive_loss(similarity.T)
    return symmetric_total / 2.0
|
|
|
|
|
|
@dataclass
class CLIPOutput(ModelOutput):
    """
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Contrastive loss for image-text similarity.
        logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
            similarity scores.
        logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
            similarity scores.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
            The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`].
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
            The image embeddings obtained by applying the projection layer to the pooled output of
            [`CLIPVisionModel`].
        text_model_output (`BaseModelOutputWithPooling`):
            The output of the [`CLIPTextModel`].
        vision_model_output (`BaseModelOutputWithPooling`):
            The output of the [`CLIPVisionModel`].
    """

    loss: Optional[torch.FloatTensor] = None
    logits_per_image: torch.FloatTensor = None
    logits_per_text: torch.FloatTensor = None
    text_embeds: torch.FloatTensor = None
    image_embeds: torch.FloatTensor = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> Tuple[Any]:
        """Flatten the output into a tuple, converting the two nested model outputs with their own `to_tuple`."""
        nested_keys = ("text_model_output", "vision_model_output")
        flattened = []
        for key in self.keys():
            if key in nested_keys:
                # Nested outputs are themselves output objects; recurse one level.
                flattened.append(getattr(self, key).to_tuple())
            else:
                flattened.append(self[key])
        return tuple(flattened)
|
|
|
|
|
|
class CLIPVisionEmbeddings(nn.Module):
    """Embed pixel values as a class token followed by patch tokens, plus learned position embeddings."""

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Learned vector prepended to every sequence of patch tokens.
        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        # Kernel size equals stride, so each patch is projected exactly once (no overlap).
        self.patch_embedding = nn.Conv2d(
            in_channels=3,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side**2
        self.num_positions = self.num_patches + 1  # +1 for the class token
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Return the sum of token embeddings and position embeddings for `pixel_values`."""
        num_images = pixel_values.shape[0]

        # Conv output is [batch, embed_dim, grid, grid]; flatten the grid and move channels last.
        patch_tokens = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)

        class_tokens = self.class_embedding.expand(num_images, 1, -1)
        tokens = torch.cat([class_tokens, patch_tokens], dim=1)
        return tokens + self.position_embedding(self.position_ids)
|
|
|
|
|
|
class CLIPTextEmbeddings(nn.Module):
    """Token plus position embeddings, with optional learned embeddings substituted at masked positions."""

    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        hidden = config.hidden_size
        max_positions = config.max_position_embeddings

        self.token_embedding = nn.Embedding(config.vocab_size, hidden)
        self.position_embedding = nn.Embedding(max_positions, hidden)
        self.use_padding_embeddings = config.use_padding_embeddings
        if self.use_padding_embeddings:
            # One learned vector per position, used in place of the token embedding where the mask is 0.
            self.padding_embedding = nn.Embedding(max_positions, hidden)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(max_positions).expand((1, -1)))

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Embed `input_ids` (or use `inputs_embeds` directly) and add position embeddings.

        When the module was built with `use_padding_embeddings` and an `attention_mask` is given, positions whose
        mask value is falsy are replaced by the padding embedding of that position.
        """
        if input_ids is not None:
            seq_length = input_ids.shape[-1]
        else:
            seq_length = inputs_embeds.shape[-2]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]
        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        embeddings = inputs_embeds + self.position_embedding(position_ids)

        if not self.use_padding_embeddings or attention_mask is None:
            return embeddings

        keep = attention_mask.bool().unsqueeze(-1)
        return torch.where(keep, embeddings, self.padding_embedding(position_ids))
|
|
|
|
|
|
class CLIPAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper, using a single fused query/key/value projection.

    Note: `forward` accepts `attention_mask`, `causal_attention_mask` and `output_attentions` so that it can be
    called like a regular CLIP attention layer, but this implementation does not use them: attention is always
    unmasked and the attention weights are always returned.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        # head_dim ** -0.25; applied to both queries and keys, so their product is scaled by head_dim ** -0.5.
        self.scale = 1 / math.sqrt(math.sqrt(self.head_dim))

        self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Input shape: Batch x Time x Channel

        Args:
            hidden_states: tensor of shape `(batch, seq_len, embed_dim)`.
            attention_mask: accepted for interface compatibility; ignored.
            causal_attention_mask: accepted for interface compatibility; ignored.
            output_attentions: accepted for interface compatibility; ignored.

        Returns:
            A 2-tuple `(attn_output, attn_weights)` where `attn_output` has shape `(batch, seq_len, embed_dim)` and
            `attn_weights` has shape `(batch, num_heads, seq_len, seq_len)`.
        """
        bsz, tgt_len, _ = hidden_states.size()

        # One projection produces q, k and v; each head holds its own [q | k | v] slice of width 3 * head_dim.
        qkv_states = self.qkv_proj(hidden_states)
        qkv_states = qkv_states.view(bsz, tgt_len, self.num_heads, -1)
        query_states, key_states, value_states = torch.split(qkv_states, self.head_dim, dim=-1)

        attn_weights = torch.einsum("bthc,bshc->bhts", query_states * self.scale, key_states * self.scale)

        # Softmax is computed in float32 and cast back to the working dtype.
        wdtype = attn_weights.dtype
        attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).type(wdtype)

        attn_output = torch.einsum("bhts,bshc->bthc", attn_weights, value_states)
        attn_output = attn_output.reshape(bsz, tgt_len, -1)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
|
|
|
|
|
|
class CLIPMLP(nn.Module):
    """Feed-forward block: `fc1`, the activation named by `config.hidden_act`, then `fc2`."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        hidden, inner = config.hidden_size, config.intermediate_size
        self.fc1 = nn.Linear(hidden, inner)
        self.fc2 = nn.Linear(inner, hidden)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project up with `fc1`, apply the activation, and project back with `fc2`."""
        expanded = self.activation_fn(self.fc1(hidden_states))
        return self.fc2(expanded)
|
|
|
|
|
|
class CLIPEncoderLayer(nn.Module):
    """Pre-layer-norm transformer layer: self-attention and an MLP, each wrapped in a residual connection."""

    def __init__(self, config: CLIPConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = CLIPAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim)
        self.mlp = CLIPMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                Forwarded unchanged to the attention module.
            causal_attention_mask (`torch.FloatTensor`): forwarded unchanged to the attention module.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.

        Returns:
            `(hidden_states,)`, or `(hidden_states, attn_weights)` when `output_attentions` is truthy.
        """
        # Attention sub-block: normalise, attend, add back the un-normalised input.
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block, same residual pattern.
        hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
|
|
|
|
|
|
class CLIPPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = CLIPConfig
    base_model_prefix = "clip"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights of `module`, choosing the scheme from the module's type."""
        scale = self.config.initializer_factor

        if isinstance(module, CLIPTextEmbeddings):
            module.token_embedding.weight.data.normal_(mean=0.0, std=scale * 0.02)
            module.position_embedding.weight.data.normal_(mean=0.0, std=scale * 0.02)
            # The padding embedding only exists when the module was built with `use_padding_embeddings`.
            if hasattr(module, "padding_embedding"):
                module.padding_embedding.weight.data.normal_(mean=0.0, std=scale * 0.02)
        elif isinstance(module, CLIPVisionEmbeddings):
            range_std = module.config.initializer_range * scale
            nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * scale)
            nn.init.normal_(module.patch_embedding.weight, std=range_std)
            nn.init.normal_(module.position_embedding.weight, std=range_std)
        elif isinstance(module, CLIPAttention):
            # The input projection is additionally scaled down by the depth of the network.
            qkv_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * scale
            out_std = (module.embed_dim**-0.5) * scale
            nn.init.normal_(module.qkv_proj.weight, std=qkv_std)
            nn.init.normal_(module.out_proj.weight, std=out_std)
        elif isinstance(module, CLIPMLP):
            fc2_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * scale
            fc1_std = (2 * module.config.hidden_size) ** -0.5 * scale
            nn.init.normal_(module.fc1.weight, std=fc1_std)
            nn.init.normal_(module.fc2.weight, std=fc2_std)
        elif isinstance(module, CLIPModel):
            nn.init.normal_(
                module.text_projection.weight,
                std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
            )
            nn.init.normal_(
                module.visual_projection.weight,
                std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
            )

        # These two checks are independent of the chain above and apply to every module visited.
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        """Set the `gradient_checkpointing` flag on `module` when it is a `CLIPEncoder`; otherwise do nothing."""
        if not isinstance(module, CLIPEncoder):
            return
        module.gradient_checkpointing = value
|
|
|
|
|
|
# Shared docstring fragments. `CLIP_TEXT_INPUTS_DOCSTRING` is attached to the text models' `forward` methods through
# `add_start_docstrings_to_model_forward`.
# NOTE(review): the indentation inside these strings was reconstructed from a whitespace-mangled copy -- confirm
# against the upstream file before relying on the exact string contents.
CLIP_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

CLIP_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

CLIP_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

CLIP_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
|
|
|
|
|
|
class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`CLIPEncoderLayer`].

    Args:
        config: CLIPConfig
    """

    def __init__(self, config: CLIPConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Run `inputs_embeds` through every layer in order.

        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded representation of the input sequence; used as the hidden state fed to the first layer.
            attention_mask (`torch.Tensor`, *optional*):
                Padding mask, passed unchanged to every layer.
            causal_attention_mask (`torch.Tensor`, *optional*):
                Causal mask, passed unchanged to every layer.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. Defaults to
                `config.output_attentions`.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers (the input embeddings plus the output of each
                layer). Defaults to `config.output_hidden_states`.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Defaults to
                `config.use_return_dict`.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        def checkpointable(layer):
            # `checkpoint` forwards positional tensors only, so `output_attentions` is bound through the closure.
            def run_layer(*inputs):
                return layer(*inputs, output_attentions)

            return run_layer

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            if output_hidden_states:
                # Record the input of this layer (for the first layer: the embeddings themselves).
                encoder_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    checkpointable(encoder_layer),
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions += (layer_outputs[1],)

        if output_hidden_states:
            # Finally record the output of the last layer.
            encoder_states += (hidden_states,)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
            )
        return tuple(v for v in (hidden_states, encoder_states, all_attentions) if v is not None)
|
|
|
|
|
|
class CLIPTextTransformer(nn.Module):
    """Text encoder: embeddings, a stack of encoder layers, a final layer norm, and pooling of one token per row."""

    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = CLIPTextEmbeddings(config)
        self.encoder = CLIPEncoder(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim)

    @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        # Collapse any leading dimensions so that everything below works on `[batch, seq_len]`.
        input_ids = input_ids.view(-1, input_ids.size(-1))

        # `attention_mask` is only consumed here, by the embeddings (padding-embedding substitution).
        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask)

        # The encoder is deliberately called without a padding mask and without a causal mask, so neither mask is
        # built here. `_build_causal_attention_mask` is kept on the class for callers that still want one.
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=None,
            causal_attention_mask=None,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = self.final_layer_norm(encoder_outputs[0])

        # last_hidden_state.shape = [batch_size, sequence_length, hidden_size]
        # Pool by taking, for every row, the hidden state at the position of the largest token id
        # (the original comment states that the eot token is the highest number in each sequence).
        # The row index is created on the same device as the tensor it indexes.
        batch_index = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device)
        pooled_output = last_hidden_state[batch_index, input_ids.argmax(dim=-1)]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def _build_causal_attention_mask(self, bsz, seq_len):
        """
        Build an additive causal mask of shape `(bsz, 1, seq_len, seq_len)`.

        Entries strictly above the main diagonal are `-inf`; all other entries are 0.
        """
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.full((bsz, seq_len, seq_len), float("-inf"))
        mask.triu_(1)  # keep -inf strictly above the diagonal, zero the diagonal and everything below it
        return mask.unsqueeze(1)  # add the singleton head axis
|
|
|
|
|
|
class CLIPTextModel(CLIPPreTrainedModel):
    """Pretrained-model wrapper around `CLIPTextTransformer`; `forward` delegates to the wrapped transformer."""

    config_class = CLIPTextConfig

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)
        self.text_model = CLIPTextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the token embedding module of the wrapped text transformer."""
        embeddings = self.text_model.embeddings
        return embeddings.token_embedding

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped text transformer with `value`."""
        embeddings = self.text_model.embeddings
        embeddings.token_embedding = value

    @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import CLIPTokenizer, CLIPTextModel

        >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```"""
        # Pure delegation: every argument is handed to the wrapped transformer unchanged.
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return text_outputs
|
|
|
|
|
|
#####################
|
|
# END OF THE CLIP MODEL COPY-PASTE
|
|
#####################
|
|
|
|
|
|
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
|
"""
|
|
Extract values from a 1-D numpy array for a batch of indices.
|
|
|
|
:param arr: the 1-D numpy array.
|
|
:param timesteps: a tensor of indices into the array to extract.
|
|
:param broadcast_shape: a larger shape of K dimensions with the batch
|
|
dimension equal to the length of timesteps.
|
|
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
|
"""
|
|
res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
|
|
while len(res.shape) < len(broadcast_shape):
|
|
res = res[..., None]
|
|
return res + torch.zeros(broadcast_shape, device=timesteps.device)
|
|
|
|
|
|
class GLIDE(DiffusionPipeline):
|
|
def __init__(
    self,
    text_unet: GLIDETextToImageUNetModel,
    text_noise_scheduler: ClassifierFreeGuidanceScheduler,
    text_encoder: CLIPTextModel,
    tokenizer: GPT2Tokenizer,
    upscale_unet: GLIDESuperResUNetModel,
    upscale_noise_scheduler: GlideDDIMScheduler,
):
    """Register the six pipeline components (two UNets, two schedulers, text encoder, tokenizer)."""
    super().__init__()
    # Keyword names are the names the components are registered under; the order is preserved.
    components = dict(
        text_unet=text_unet,
        text_noise_scheduler=text_noise_scheduler,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        upscale_unet=upscale_unet,
        upscale_noise_scheduler=upscale_noise_scheduler,
    )
    self.register_modules(**components)
|
|
|
|
def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
    """
    Compute the mean and variance of the diffusion posterior:

        q(x_{t-1} | x_t, x_0)

    :param scheduler: object exposing the arrays `posterior_mean_coef1`, `posterior_mean_coef2`,
                      `posterior_variance` and `posterior_log_variance_clipped`.
    :param x_start: the prediction of x_0; must have the same shape as `x_t`.
    :param x_t: the tensor at time t.
    :param t: a tensor of timestep indices, one per batch element.
    :return: a tuple (posterior_mean, posterior_variance, posterior_log_variance_clipped), each broadcast to the
             shape of `x_t`.
    """
    assert x_start.shape == x_t.shape
    shape = x_t.shape

    # The mean is a per-timestep weighted sum of x_0 and x_t.
    weight_start = _extract_into_tensor(scheduler.posterior_mean_coef1, t, shape)
    weight_current = _extract_into_tensor(scheduler.posterior_mean_coef2, t, shape)
    posterior_mean = weight_start * x_start + weight_current * x_t

    posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, shape)
    posterior_log_variance_clipped = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, shape)

    batch = x_start.shape[0]
    assert posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == batch
    return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
|
|
|
def p_mean_variance(self, model, scheduler, x, t, transformer_out=None, low_res=None, clip_denoised=True):
    """
    Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
    the initial x, x_0.

    :param model: the model, called as `model(x, t, conditioning)`; its output must have
                  shape [N x 2C x ...] (C channels of predicted noise followed by C
                  channels of variance interpolation values).
    :param scheduler: object exposing `betas`, `posterior_log_variance_clipped` and the
                      coefficient arrays read by the helper methods called here.
    :param x: the [N x C x ...] tensor at time t.
    :param t: a 1-D Tensor of timesteps with shape (N,).
    :param transformer_out: conditioning for the text2image model. When None, the
                            super-res path is taken and `low_res` is passed instead.
    :param low_res: conditioning for the super-res model; only used when
                    `transformer_out` is None.
    :param clip_denoised: if True, clip the denoised signal into [-1, 1].
    :return: a tuple `(model_mean, model_variance, model_log_variance, pred_xstart)`:
             - model_mean: the model mean output.
             - model_variance: the model variance output.
             - model_log_variance: the log of model_variance.
             - pred_xstart: the prediction for x_0.
    """
    B, C = x.shape[:2]
    assert t.shape == (B,)

    # super-res model when no text conditioning is given, text2image model otherwise
    conditioning = low_res if transformer_out is None else transformer_out
    model_output = model(x, t, conditioning)

    assert model_output.shape == (B, C * 2, *x.shape[2:])
    model_output, model_var_values = torch.split(model_output, C, dim=1)

    min_log = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x.shape)
    max_log = _extract_into_tensor(np.log(scheduler.betas), t, x.shape)
    # The model_var_values is [-1, 1] for [min_var, max_var].
    frac = (model_var_values + 1) / 2
    model_log_variance = frac * max_log + (1 - frac) * min_log
    model_variance = torch.exp(model_log_variance)

    pred_xstart = self._predict_xstart_from_eps(scheduler, x_t=x, t=t, eps=model_output)
    if clip_denoised:
        pred_xstart = pred_xstart.clamp(-1, 1)

    # Only the posterior mean is needed; the variance comes from the model output above.
    model_mean, _, _ = self.q_posterior_mean_variance(scheduler, x_start=pred_xstart, x_t=x, t=t)

    assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
    return model_mean, model_variance, model_log_variance, pred_xstart
|
|
|
|
def _predict_xstart_from_eps(self, scheduler, x_t, t, eps):
    """
    Compute a prediction of x_0 from the noisy sample ``x_t`` and the noise
    estimate ``eps``.

    The result is ``a * x_t - b * eps`` where ``a`` and ``b`` are the entries
    of ``scheduler.sqrt_recip_alphas_cumprod`` and
    ``scheduler.sqrt_recipm1_alphas_cumprod`` selected for ``t`` via
    ``_extract_into_tensor``.

    :param scheduler: object providing the two coefficient tables above.
    :param x_t: the sample at timestep t.
    :param t: timesteps forwarded to ``_extract_into_tensor``.
    :param eps: noise estimate; must have the same shape as ``x_t``.
    :return: the predicted x_0.
    """
    assert x_t.shape == eps.shape
    sample_coeff = _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape)
    noise_coeff = _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
    return sample_coeff * x_t - noise_coeff * eps
|
|
|
|
def _predict_eps_from_xstart(self, scheduler, x_t, t, pred_xstart):
    """
    Compute the noise estimate implied by ``x_t`` and a prediction of x_0.

    The result is ``(a * x_t - pred_xstart) / b`` where ``a`` and ``b`` are the
    entries of ``scheduler.sqrt_recip_alphas_cumprod`` and
    ``scheduler.sqrt_recipm1_alphas_cumprod`` selected for ``t`` via
    ``_extract_into_tensor``.

    :param scheduler: object providing the two coefficient tables above.
    :param x_t: the sample at timestep t.
    :param t: timesteps forwarded to ``_extract_into_tensor``.
    :param pred_xstart: the prediction for x_0.
    :return: the implied noise estimate.
    """
    sample_coeff = _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape)
    residual = sample_coeff * x_t - pred_xstart
    noise_coeff = _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
    return residual / noise_coeff
|
|
|
|
@torch.no_grad()
def __call__(self, prompt, generator=None, torch_device=None):
    """
    Generate an image for ``prompt``: a 64x64 text-conditioned sampling pass
    with classifier-free guidance, followed by a 256x256 upscaling pass.

    :param prompt: text prompt handed to the tokenizer.
    :param generator: optional generator forwarded to the schedulers'
                      ``sample_noise`` / ``sample_variance`` calls.
    :param torch_device: device to run on. When None, "cuda" is used if
                         available, otherwise "cpu".
    :return: the upscaled image tensor permuted to channels-last order
             (``permute(0, 2, 3, 1)``).
    """
    # Honor an explicitly requested device; only auto-detect when none is given.
    if torch_device is None:
        torch_device = "cuda" if torch.cuda.is_available() else "cpu"

    self.text_unet.to(torch_device)
    self.text_encoder.to(torch_device)
    self.upscale_unet.to(torch_device)

    # Create a classifier-free guidance sampling function
    guidance_scale = 3.0

    def text_model_fn(x_t, ts, transformer_out, **kwargs):
        # Both halves of the batch share the same noisy sample; they differ
        # only in the conditioning (prompt vs. empty prompt).
        half = x_t[: len(x_t) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.text_unet(combined, ts, transformer_out, **kwargs)
        # First 3 channels are the noise prediction, the rest are passed through.
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)

    # 1. Sample gaussian noise
    batch_size = 2  # second image is empty for classifier-free guidance
    image = self.text_noise_scheduler.sample_noise(
        (batch_size, self.text_unet.in_channels, 64, 64), device=torch_device, generator=generator
    )

    # 2. Encode tokens
    # The empty prompt supplies the unconditional branch that
    # classifier-free guidance steers away from.
    inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt")
    input_ids = inputs["input_ids"].to(torch_device)
    attention_mask = inputs["attention_mask"].to(torch_device)
    transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state

    # 3. Run the text2image generation step
    num_timesteps = len(self.text_noise_scheduler)
    for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
        t = torch.tensor([i] * image.shape[0], device=torch_device)
        mean, _, log_variance, _ = self.p_mean_variance(
            text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
        )
        noise = self.text_noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))  # no noise when t == 0
        image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise

    # 4. Run the upscaling step
    batch_size = 1
    image = image[:1]  # keep only the prompt-conditioned sample
    # Round to the nearest of 256 evenly spaced levels in [-1, 1].
    low_res = ((image + 1) * 127.5).round() / 127.5 - 1

    # Tune this parameter to control the sharpness of 256x256 images.
    # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
    upsample_temp = 0.997

    image = (
        self.upscale_noise_scheduler.sample_noise(
            (batch_size, 3, 256, 256), device=torch_device, generator=generator
        )
        * upsample_temp
    )

    num_timesteps = len(self.upscale_noise_scheduler)
    for t in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
        # i) define coefficients for time step t
        # Look up the cumulative alpha products once per step instead of
        # once per coefficient.
        alpha_prod_t = self.upscale_noise_scheduler.get_alpha_prod(t)
        alpha_prod_t_prev = self.upscale_noise_scheduler.get_alpha_prod(t - 1)

        clipped_image_coeff = 1 / torch.sqrt(alpha_prod_t)
        clipped_noise_coeff = torch.sqrt(1 / alpha_prod_t - 1)
        image_coeff = (
            (1 - alpha_prod_t_prev)
            * torch.sqrt(self.upscale_noise_scheduler.get_alpha(t))
            / (1 - alpha_prod_t)
        )
        clipped_coeff = (
            torch.sqrt(alpha_prod_t_prev)
            * self.upscale_noise_scheduler.get_beta(t)
            / (1 - alpha_prod_t)
        )

        # ii) predict noise residual
        time_input = torch.tensor([t] * image.shape[0], device=torch_device)
        model_output = self.upscale_unet(image, time_input, low_res)
        # The second chunk (the model's variance output) is not used here.
        noise_residual, _ = torch.split(model_output, 3, dim=1)

        # iii) compute predicted image from residual
        # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
        pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
        pred_mean = torch.clamp(pred_mean, -1, 1)
        prev_image = clipped_coeff * pred_mean + image_coeff * image

        # iv) sample variance
        prev_variance = self.upscale_noise_scheduler.sample_variance(
            t, prev_image.shape, device=torch_device, generator=generator
        )

        # v) sample x_{t-1} ~ N(prev_image, prev_variance)
        image = prev_image + prev_variance

    # Move the channel axis last before returning.
    image = image.permute(0, 2, 3, 1)

    return image
|