From db2a1077c0bf55d1600b26ad9e2659c8a06f048b Mon Sep 17 00:00:00 2001 From: anton-l Date: Tue, 7 Jun 2022 19:01:58 +0200 Subject: [PATCH 1/5] Add glide text encoder --- models/vision/glide/convert_weights.py | 17 +- models/vision/glide/modelling_text_encoder.py | 687 ++++++++++++++++++ 2 files changed, 693 insertions(+), 11 deletions(-) create mode 100644 models/vision/glide/modelling_text_encoder.py diff --git a/models/vision/glide/convert_weights.py b/models/vision/glide/convert_weights.py index 6792cdef..7ec1b924 100644 --- a/models/vision/glide/convert_weights.py +++ b/models/vision/glide/convert_weights.py @@ -3,7 +3,8 @@ import argparse import torch from torch import nn -from transformers import CLIPTextConfig, CLIPTextModel, GPT2Tokenizer +from transformers import CLIPTextConfig, GPT2Tokenizer +from modelling_text_encoder import CLIPTextModel # wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt state_dict = torch.load("base.pt", map_location="cpu") @@ -13,7 +14,8 @@ config = CLIPTextConfig( intermediate_size=2048, num_hidden_layers=16, num_attention_heads=8, - max_position_embeddings=128 + max_position_embeddings=128, + use_padding_embeddings=True, ) model = CLIPTextModel(config).eval() tokenizer = GPT2Tokenizer("./glide-base/vocab.json", "./glide-base/merges.txt", pad_token="<|endoftext|>") @@ -30,15 +32,8 @@ hf_encoder.final_layer_norm.bias = state_dict["final_ln.bias"] for layer_idx in range(config.num_hidden_layers): hf_layer = hf_encoder.encoder.layers[layer_idx] - q_proj, k_proj, v_proj = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.weight"].chunk(3, dim=0) - q_proj_bias, k_proj_bias, v_proj_bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.bias"].chunk(3, dim=0) - - hf_layer.self_attn.q_proj.weight.data = q_proj - hf_layer.self_attn.q_proj.bias.data = q_proj_bias - hf_layer.self_attn.k_proj.weight.data = k_proj - hf_layer.self_attn.k_proj.bias.data = k_proj_bias - hf_layer.self_attn.v_proj.weight.data = v_proj - hf_layer.self_attn.v_proj.bias.data = v_proj_bias + hf_layer.self_attn.qkv_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.weight"] + hf_layer.self_attn.qkv_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.bias"] hf_layer.self_attn.out_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.weight"] hf_layer.self_attn.out_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.bias"] diff --git a/models/vision/glide/modelling_text_encoder.py b/models/vision/glide/modelling_text_encoder.py new file mode 100644 index 00000000..ba25cf3a --- /dev/null +++ b/models/vision/glide/modelling_text_encoder.py @@ -0,0 +1,687 @@ +# coding=utf-8 +# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch CLIP model.""" + +from dataclasses import dataclass +import math +from typing import Any, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers import CLIPModel, CLIPConfig, CLIPVisionConfig, CLIPTextConfig + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "openai/clip-vit-base-patch32" + +CLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "openai/clip-vit-base-patch32", + # See all CLIP models at https://huggingface.co/models?filter=clip +] + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +# contrastive loss function, adapted from +# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html +def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: + return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) + + +def clip_loss(similarity: torch.Tensor) -> torch.Tensor: + caption_loss = contrastive_loss(similarity) + image_loss = contrastive_loss(similarity.T) + return (caption_loss + image_loss) / 2.0 + + +@dataclass +class CLIPOutput(ModelOutput): + """ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`]. + image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`]. + text_model_output(`BaseModelOutputWithPooling`): + The output of the [`CLIPTextModel`]. + vision_model_output(`BaseModelOutputWithPooling`): + The output of the [`CLIPVisionModel`]. + """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: torch.FloatTensor = None + logits_per_text: torch.FloatTensor = None + text_embeds: torch.FloatTensor = None + image_embeds: torch.FloatTensor = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +class CLIPVisionEmbeddings(nn.Module): + def __init__(self, config: CLIPVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1))) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +class CLIPTextEmbeddings(nn.Module): + def __init__(self, config: CLIPTextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) + self.use_padding_embeddings = config.use_padding_embeddings + if self.use_padding_embeddings: + self.padding_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + if self.use_padding_embeddings and attention_mask is not None: + padding_embeddings = self.padding_embedding(position_ids) + embeddings = torch.where(attention_mask.bool().unsqueeze(-1), embeddings, padding_embeddings) + + return embeddings + + +class CLIPAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = 1 / math.sqrt(math.sqrt(self.head_dim)) + + self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim*3) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = hidden_states.size() + + qkv_states = self.qkv_proj(hidden_states) + qkv_states = qkv_states.view(bsz, tgt_len, self.num_heads, -1) + query_states, key_states, value_states = torch.split(qkv_states, self.head_dim, dim=-1) + + attn_weights = torch.einsum( + "bthc,bshc->bhts", query_states * self.scale, key_states * self.scale + ) + + wdtype = attn_weights.dtype + attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).type(wdtype) + + attn_output = torch.einsum("bhts,bshc->bthc", attn_weights, value_states) + attn_output = attn_output.reshape(bsz, tgt_len, -1) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class CLIPMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class CLIPEncoderLayer(nn.Module): + def __init__(self, config: CLIPConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = CLIPAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim) + self.mlp = CLIPMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(config.encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class CLIPPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = CLIPConfig + base_model_prefix = "clip" + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor + if isinstance(module, CLIPTextEmbeddings): + module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + if hasattr(module, "padding_embedding"): + module.padding_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + elif isinstance(module, CLIPVisionEmbeddings): + factor = self.config.initializer_factor + nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim ** -0.5 * factor) + nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) + elif isinstance(module, CLIPAttention): + factor = self.config.initializer_factor + in_proj_std = (module.embed_dim ** -0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + out_proj_std = (module.embed_dim ** -0.5) * factor + nn.init.normal_(module.qkv_proj.weight, std=in_proj_std) + nn.init.normal_(module.out_proj.weight, std=out_proj_std) + elif isinstance(module, CLIPMLP): + factor = self.config.initializer_factor + in_proj_std = ( + (module.config.hidden_size ** -0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + ) + fc_std = (2 * module.config.hidden_size) ** -0.5 * factor + nn.init.normal_(module.fc1.weight, std=fc_std) + nn.init.normal_(module.fc2.weight, std=in_proj_std) + elif isinstance(module, CLIPModel): + nn.init.normal_( + module.text_projection.weight, + std=module.text_embed_dim ** -0.5 * self.config.initializer_factor, + ) + nn.init.normal_( + module.visual_projection.weight, + std=module.vision_embed_dim ** -0.5 * self.config.initializer_factor, + ) + + if isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, CLIPEncoder): + module.gradient_checkpointing = value + + +CLIP_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`CLIPConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +CLIP_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +CLIP_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +CLIP_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class CLIPEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`CLIPEncoderLayer`]. + + Args: + config: CLIPConfig + """ + + def __init__(self, config: CLIPConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + causal_attention_mask, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class CLIPTextTransformer(nn.Module): + def __init__(self, config: CLIPTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = CLIPTextEmbeddings(config) + self.encoder = CLIPEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim) + + @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify either input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask) + + bsz, seq_len = input_shape + # CLIP's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len).to(hidden_states.device) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=None, + causal_attention_mask=None, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # text_embeds.shape = [batch_size, sequence_length, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)] + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def _build_causal_attention_mask(self, bsz, seq_len): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(bsz, seq_len, seq_len) + mask.fill_(torch.tensor(float("-inf"))) + mask.triu_(1) # zero out the lower diagonal + mask = mask.unsqueeze(1) # expand mask + return mask + + +class CLIPTextModel(CLIPPreTrainedModel): + config_class = CLIPTextConfig + + def __init__(self, config: CLIPTextConfig): + super().__init__(config) + self.text_model = CLIPTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import CLIPTokenizer, CLIPTextModel + + >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) \ No newline at end of file From 4d53a521508955e47b8bdac2f76891136135ad16 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 8 Jun 2022 11:44:27 +0200 Subject: [PATCH 2/5] add unet ldm in init --- src/diffusers/__init__.py | 1 + src/diffusers/models/__init__.py | 1 + src/diffusers/models/unet_ldm.py | 4 ++-- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 3ce4142f..8feb9e81 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -7,5 +7,6 @@ __version__ = "0.0.1" from .modeling_utils import ModelMixin from .models.unet import UNetModel from .models.unet_glide import UNetGLIDEModel +from .models.unet_ldm import UNetLDMModel from .pipeline_utils import DiffusionPipeline from .schedulers.gaussian_ddpm import GaussianDDPMScheduler diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 85f1cc03..6d6c4d3d 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -18,3 +18,4 @@ from .unet import UNetModel from .unet_glide import UNetGLIDEModel +from .unet_ldm import UNetLDMModel diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index 465c168c..57dec0b6 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -830,7 +830,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin): self.conv_resample = conv_resample self.num_classes = num_classes self.use_checkpoint = use_checkpoint - self.dtype = torch.float16 if use_fp16 else torch.float32 + self.dtype_ = torch.float16 if use_fp16 else torch.float32 self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample @@ -1060,7 +1060,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin): assert y.shape == (x.shape[0],) emb = emb + self.label_emb(y) - h = x.type(self.dtype) + h = x.type(self.dtype_) for module in self.input_blocks: h = module(h, emb, context) hs.append(h) From 1e21f061601dda0aa9740e88bfce68bf4aac4acd Mon Sep 17 00:00:00 2001 From: anton-l Date: Wed, 8 Jun 2022 11:47:47 +0200 Subject: [PATCH 3/5] Classifier-free guidance scheduler + GLIDe pipeline --- models/vision/glide/README.md | 4 + models/vision/glide/convert_weights.py | 46 ++++-- models/vision/glide/modeling_glide.py | 144 ++++++++++++++---- models/vision/glide/run_glide.py | 9 +- src/diffusers/__init__.py | 2 + src/diffusers/models/__init__.py | 1 + .../diffusers/models/clip_text_transformer.py | 0 src/diffusers/models/unet_glide.py | 10 +- src/diffusers/pipeline_utils.py | 4 +- src/diffusers/schedulers/__init__.py | 1 + .../schedulers/classifier_free_guidance.py | 102 +++++++++++++ 11 files changed, 275 insertions(+), 48 deletions(-) rename models/vision/glide/modelling_text_encoder.py => src/diffusers/models/clip_text_transformer.py (100%) create mode 100644 src/diffusers/schedulers/classifier_free_guidance.py diff --git a/models/vision/glide/README.md b/models/vision/glide/README.md index e69de29b..743c9bb6 100644 --- a/models/vision/glide/README.md +++ b/models/vision/glide/README.md @@ -0,0 +1,4 @@ +# References + +[GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models](https://arxiv.org/pdf/2112.10741.pdf) +[Diffusion Models Beat GANs on Image Synthesis](https://arxiv.org/pdf/2105.05233.pdf) \ No newline at end of file diff --git a/models/vision/glide/convert_weights.py b/models/vision/glide/convert_weights.py index 7ec1b924..4f3320d7 100644 --- a/models/vision/glide/convert_weights.py +++ b/models/vision/glide/convert_weights.py @@ -1,25 +1,28 @@ -import argparse - import torch from torch import nn from transformers import CLIPTextConfig, GPT2Tokenizer -from modelling_text_encoder import CLIPTextModel +from diffusers import UNetGLIDEModel, ClassifierFreeGuidanceScheduler, CLIPTextModel +from modeling_glide import GLIDE # wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt state_dict = torch.load("base.pt", map_location="cpu") state_dict = {k: nn.Parameter(v) for k, v in state_dict.items()} + +### Convert the text encoder + config = CLIPTextConfig( + vocab_size=50257, + max_position_embeddings=128, hidden_size=512, intermediate_size=2048, num_hidden_layers=16, num_attention_heads=8, - max_position_embeddings=128, use_padding_embeddings=True, ) model = CLIPTextModel(config).eval() tokenizer = GPT2Tokenizer("./glide-base/vocab.json", "./glide-base/merges.txt", pad_token="<|endoftext|>") -tokenizer.save_pretrained("./glide-base") +#tokenizer.save_pretrained("./glide-base") hf_encoder = model.text_model @@ -48,8 +51,33 @@ for layer_idx in range(config.num_hidden_layers): hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"] hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"] -inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt") -with torch.no_grad(): - outputs = model(**inputs) +#inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt") +#with torch.no_grad(): +# outputs = model(**inputs) -model.save_pretrained("./glide-base") \ No newline at end of file +#model.save_pretrained("./glide-base") + +### Convert the UNet + +unet_model = UNetGLIDEModel( + in_channels=3, + model_channels=192, + out_channels=6, + num_res_blocks=3, + attention_resolutions=(2, 4, 8), + dropout=0.1, + channel_mult=(1, 2, 3, 4), + num_heads=1, + num_head_channels=64, + num_heads_upsample=1, + use_scale_shift_norm=True, + resblock_updown=True, +) + +unet_model.load_state_dict(state_dict, strict=False) + +scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2") + +glide = GLIDE(unet=unet_model, noise_scheduler=scheduler, text_encoder=model, tokenizer=tokenizer) + +glide.save_pretrained("./glide-base") \ No newline at end of file diff --git a/models/vision/glide/modeling_glide.py b/models/vision/glide/modeling_glide.py index 747c1732..56c7b35f 100644 --- a/models/vision/glide/modeling_glide.py +++ b/models/vision/glide/modeling_glide.py @@ -14,46 +14,136 @@ # limitations under the License. -from diffusers import DiffusionPipeline -from diffusers import UNetGLIDEModel +from diffusers import DiffusionPipeline, UNetGLIDEModel, ClassifierFreeGuidanceScheduler, CLIPTextModel +from transformers import GPT2Tokenizer import tqdm import torch +import numpy as np + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + torch.zeros(broadcast_shape, device=timesteps.device) class GLIDE(DiffusionPipeline): - def __init__(self, unet: UNetGLIDEModel, noise_scheduler): + def __init__( + self, + unet: UNetGLIDEModel, + noise_scheduler: ClassifierFreeGuidanceScheduler, + text_encoder: CLIPTextModel, + tokenizer: GPT2Tokenizer + ): super().__init__() - self.register_modules(unet=unet, noise_scheduler=noise_scheduler) + self.register_modules(unet=unet, noise_scheduler=noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer) - def __call__(self, generator=None, torch_device=None): + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + + q(x_{t-1} | x_t, x_0) + + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.noise_scheduler.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.noise_scheduler.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.noise_scheduler.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.noise_scheduler.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, model, x, t, transformer_out, clip_denoised=True, model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B,) + model_output = model(x, t, transformer_out) + + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = torch.split(model_output, C, dim=1) + min_log = _extract_into_tensor(self.noise_scheduler.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.noise_scheduler.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = torch.exp(model_log_variance) + + pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + if clip_denoised: + pred_xstart = pred_xstart.clamp(-1, 1) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return model_mean, model_variance, model_log_variance, pred_xstart + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.noise_scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.noise_scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def __call__(self, prompt, generator=None, torch_device=None): torch_device = "cuda" if torch.cuda.is_available() else "cpu" self.unet.to(torch_device) + self.text_encoder.to(torch_device) + # 1. Sample gaussian noise - image = self.noise_scheduler.sample_noise((1, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator) - for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)): - # i) define coefficients for time step t - clip_image_coeff = 1 / torch.sqrt(self.noise_scheduler.get_alpha_prod(t)) - clip_noise_coeff = torch.sqrt(1 / self.noise_scheduler.get_alpha_prod(t) - 1) - image_coeff = (1 - self.noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt(self.noise_scheduler.get_alpha(t)) / (1 - self.noise_scheduler.get_alpha_prod(t)) - clip_coeff = torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1)) * self.noise_scheduler.get_beta(t) / (1 - self.noise_scheduler.get_alpha_prod(t)) + image = self.noise_scheduler.sample_noise((1, self.unet.in_channels, 64, 64), device=torch_device, generator=generator) - # ii) predict noise residual - with torch.no_grad(): - noise_residual = self.unet(image, t) + # 2. Encode tokens + # an empty input is needed to guide the model away from ( + inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt") + transformer_out = self.text_encoder(**inputs).last_hidden_state - # iii) compute predicted image from residual - # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison - pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual - pred_mean = torch.clamp(pred_mean, -1, 1) - prev_image = clip_coeff * pred_mean + image_coeff * image - - # iv) sample variance - prev_variance = self.noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator) - - # v) sample x_{t-1} ~ N(prev_image, prev_variance) - sampled_prev_image = prev_image + prev_variance - image = sampled_prev_image + num_timesteps = len(self.noise_scheduler) + for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps): + t = torch.tensor([i] * image.shape[0], device=torch_device) + mean, variance, log_variance, pred_xstart = self.p_mean_variance(self.unet, transformer_out, image, t) + noise = self.noise_scheduler.sample_noise(image.shape) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(image.shape) - 1))) + ) # no noise when t == 0 + image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise return image diff --git a/models/vision/glide/run_glide.py b/models/vision/glide/run_glide.py index 23cd4e10..4d6d8e2d 100644 --- a/models/vision/glide/run_glide.py +++ b/models/vision/glide/run_glide.py @@ -1,16 +1,11 @@ import torch -from .modeling_glide import GLIDE -from diffusers import UNetGLIDEModel, GaussianDDPMScheduler +from modeling_glide import GLIDE generator = torch.Generator() generator = generator.manual_seed(0) # 1. Load models - -scheduler = GaussianDDPMScheduler.from_config("fusing/glide-base") -model = UNetGLIDEModel.from_pretrained("fusing/glide-base") - -pipeline = GLIDE(model, scheduler) +pipeline = GLIDE.from_pretrained("fusing/glide-base") img = pipeline(generator) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 3ce4142f..14191402 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -7,5 +7,7 @@ __version__ = "0.0.1" from .modeling_utils import ModelMixin from .models.unet import UNetModel from .models.unet_glide import UNetGLIDEModel +from .models.clip_text_transformer import CLIPTextModel from .pipeline_utils import DiffusionPipeline from .schedulers.gaussian_ddpm import GaussianDDPMScheduler +from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 85f1cc03..964c0200 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -18,3 +18,4 @@ from .unet import UNetModel from .unet_glide import UNetGLIDEModel +from .clip_text_transformer import CLIPTextModel diff --git a/models/vision/glide/modelling_text_encoder.py b/src/diffusers/models/clip_text_transformer.py similarity index 100% rename from models/vision/glide/modelling_text_encoder.py rename to src/diffusers/models/clip_text_transformer.py diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py index 363d0113..4b5cc971 100644 --- a/src/diffusers/models/unet_glide.py +++ b/src/diffusers/models/unet_glide.py @@ -470,7 +470,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin): self.channel_mult = channel_mult self.conv_resample = conv_resample self.use_checkpoint = use_checkpoint - self.dtype = torch.float16 if use_fp16 else torch.float32 + #self.dtype = torch.float16 if use_fp16 else torch.float32 self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample @@ -653,13 +653,15 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin): transformer_proj = self.transformer_proj(transformer_out[:, -1]) transformer_out = transformer_out.permute(0, 2, 1) # NLC -> NCL + emb = emb + transformer_proj.to(emb) + h = x.type(self.dtype) for module in self.input_blocks: - h = module(h, emb) + h = module(h, emb, transformer_out) hs.append(h) - h = self.middle_block(h, emb) + h = self.middle_block(h, emb, transformer_out) for module in self.output_blocks: h = torch.cat([h, hs.pop()], dim=1) - h = module(h, emb) + h = module(h, emb, transformer_out) h = h.type(x.dtype) return self.out(h) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 60ece225..dfc7f6d6 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -35,10 +35,12 @@ logger = logging.get_logger(__name__) LOADABLE_CLASSES = { "diffusers": { "ModelMixin": ["save_pretrained", "from_pretrained"], + "CLIPTextModel": ["save_pretrained", "from_pretrained"], # TODO (Anton): move to transformers "GaussianDDPMScheduler": ["save_config", "from_config"], + "ClassifierFreeGuidanceScheduler": ["save_config", "from_config"], }, "transformers": { - "ModelMixin": ["save_pretrained", "from_pretrained"], + "GPT2Tokenizer": ["save_pretrained", "from_pretrained"], }, } diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 81d9601a..82084c6c 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -17,3 +17,4 @@ # limitations under the License. from .gaussian_ddpm import GaussianDDPMScheduler +from .classifier_free_guidance import ClassifierFreeGuidanceScheduler diff --git a/src/diffusers/schedulers/classifier_free_guidance.py b/src/diffusers/schedulers/classifier_free_guidance.py new file mode 100644 index 00000000..17222c17 --- /dev/null +++ b/src/diffusers/schedulers/classifier_free_guidance.py @@ -0,0 +1,102 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import math +from torch import nn +import numpy as np + +from ..configuration_utils import ConfigMixin + + +SAMPLING_CONFIG_NAME = "scheduler_config.json" + + +def linear_beta_schedule(timesteps, beta_start, beta_end): + return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64) + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas, dtype=np.float64) + + +class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin): + + config_name = SAMPLING_CONFIG_NAME + + def __init__( + self, + timesteps=1000, + beta_schedule="squaredcos_cap_v2", + ): + super().__init__() + self.register( + timesteps=timesteps, + beta_schedule=beta_schedule, + ) + self.num_timesteps = int(timesteps) + + if beta_schedule == "squaredcos_cap_v2": + # GLIDE cosine schedule + betas = betas_for_alpha_bar( + timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + ) + + def sample_noise(self, shape, device, generator=None): + # always sample on CPU to be deterministic + return torch.randn(shape, generator=generator).to(device) + + def __len__(self): + return self.num_timesteps From 07ffe73f796db1c19555fee04711f1ab71a92de2 Mon Sep 17 00:00:00 2001 From: anton-l Date: Wed, 8 Jun 2022 11:53:12 +0200 Subject: [PATCH 4/5] Style --- models/vision/glide/convert_weights.py | 15 +-- models/vision/glide/modeling_glide.py | 30 +++--- models/vision/glide/run_glide.py | 2 + src/diffusers/__init__.py | 4 +- src/diffusers/configuration_utils.py | 10 +- src/diffusers/models/__init__.py | 2 +- src/diffusers/models/clip_text_transformer.py | 98 +++++++++---------- src/diffusers/models/unet_glide.py | 2 +- src/diffusers/pipeline_utils.py | 3 +- src/diffusers/schedulers/__init__.py | 2 +- .../schedulers/classifier_free_guidance.py | 19 ++-- 11 files changed, 91 insertions(+), 96 deletions(-) diff --git a/models/vision/glide/convert_weights.py b/models/vision/glide/convert_weights.py index 4f3320d7..a8016406 100644 --- a/models/vision/glide/convert_weights.py +++ b/models/vision/glide/convert_weights.py @@ -1,9 +1,10 @@ import torch from torch import nn -from transformers import CLIPTextConfig, GPT2Tokenizer -from diffusers import UNetGLIDEModel, ClassifierFreeGuidanceScheduler, CLIPTextModel +from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, UNetGLIDEModel from modeling_glide import GLIDE +from transformers import CLIPTextConfig, GPT2Tokenizer + # wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt state_dict = torch.load("base.pt", map_location="cpu") @@ -22,7 +23,7 @@ config = CLIPTextConfig( ) model = CLIPTextModel(config).eval() tokenizer = GPT2Tokenizer("./glide-base/vocab.json", "./glide-base/merges.txt", pad_token="<|endoftext|>") -#tokenizer.save_pretrained("./glide-base") +# tokenizer.save_pretrained("./glide-base") hf_encoder = model.text_model @@ -51,11 +52,11 @@ for layer_idx in range(config.num_hidden_layers): hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"] hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"] -#inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt") -#with torch.no_grad(): +# inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt") +# with torch.no_grad(): # outputs = model(**inputs) -#model.save_pretrained("./glide-base") +# model.save_pretrained("./glide-base") ### Convert the UNet @@ -80,4 +81,4 @@ scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squar glide = GLIDE(unet=unet_model, noise_scheduler=scheduler, text_encoder=model, tokenizer=tokenizer) -glide.save_pretrained("./glide-base") \ No newline at end of file +glide.save_pretrained("./glide-base") diff --git a/models/vision/glide/modeling_glide.py b/models/vision/glide/modeling_glide.py index 56c7b35f..ecd29637 100644 --- a/models/vision/glide/modeling_glide.py +++ b/models/vision/glide/modeling_glide.py @@ -14,12 +14,12 @@ # limitations under the License. -from diffusers import DiffusionPipeline, UNetGLIDEModel, ClassifierFreeGuidanceScheduler, CLIPTextModel -from transformers import GPT2Tokenizer +import numpy as np +import torch import tqdm -import torch -import numpy as np +from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, DiffusionPipeline, UNetGLIDEModel +from transformers import GPT2Tokenizer def _extract_into_tensor(arr, timesteps, broadcast_shape): @@ -40,14 +40,16 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape): class GLIDE(DiffusionPipeline): def __init__( - self, - unet: UNetGLIDEModel, - noise_scheduler: ClassifierFreeGuidanceScheduler, - text_encoder: CLIPTextModel, - tokenizer: GPT2Tokenizer + self, + unet: UNetGLIDEModel, + noise_scheduler: ClassifierFreeGuidanceScheduler, + text_encoder: CLIPTextModel, + tokenizer: GPT2Tokenizer, ): super().__init__() - self.register_modules(unet=unet, noise_scheduler=noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer) + self.register_modules( + unet=unet, noise_scheduler=noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer + ) def q_posterior_mean_variance(self, x_start, x_t, t): """ @@ -129,7 +131,9 @@ class GLIDE(DiffusionPipeline): self.text_encoder.to(torch_device) # 1. Sample gaussian noise - image = self.noise_scheduler.sample_noise((1, self.unet.in_channels, 64, 64), device=torch_device, generator=generator) + image = self.noise_scheduler.sample_noise( + (1, self.unet.in_channels, 64, 64), device=torch_device, generator=generator + ) # 2. Encode tokens # an empty input is needed to guide the model away from ( @@ -141,9 +145,7 @@ class GLIDE(DiffusionPipeline): t = torch.tensor([i] * image.shape[0], device=torch_device) mean, variance, log_variance, pred_xstart = self.p_mean_variance(self.unet, transformer_out, image, t) noise = self.noise_scheduler.sample_noise(image.shape) - nonzero_mask = ( - (t != 0).float().view(-1, *([1] * (len(image.shape) - 1))) - ) # no noise when t == 0 + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1))) # no noise when t == 0 image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise return image diff --git a/models/vision/glide/run_glide.py b/models/vision/glide/run_glide.py index 4d6d8e2d..2c3eafd2 100644 --- a/models/vision/glide/run_glide.py +++ b/models/vision/glide/run_glide.py @@ -1,6 +1,8 @@ import torch + from modeling_glide import GLIDE + generator = torch.Generator() generator = generator.manual_seed(0) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 4269ac0a..fa1c809b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -5,10 +5,10 @@ __version__ = "0.0.1" from .modeling_utils import ModelMixin +from .models.clip_text_transformer import CLIPTextModel from .models.unet import UNetModel from .models.unet_glide import UNetGLIDEModel from .models.unet_ldm import UNetLDMModel -from .models.clip_text_transformer import CLIPTextModel from .pipeline_utils import DiffusionPipeline -from .schedulers.gaussian_ddpm import GaussianDDPMScheduler from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler +from .schedulers.gaussian_ddpm import GaussianDDPMScheduler diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 721f13a2..c3cf8c59 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -89,7 +89,6 @@ class ConfigMixin: self.to_json_file(output_config_file) logger.info(f"ConfigMixinuration saved in {output_config_file}") - @classmethod def get_config_dict( @@ -183,7 +182,7 @@ class ConfigMixin: logger.info(f"loading configuration file {config_file}") else: logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}") - + return config_dict @classmethod @@ -199,9 +198,8 @@ class ConfigMixin: # use value from config dict init_dict[key] = config_dict.pop(key) - unused_kwargs = config_dict.update(kwargs) - + passed_keys = set(init_dict.keys()) if len(expected_keys - passed_keys) > 0: logger.warn( @@ -212,9 +210,7 @@ class ConfigMixin: @classmethod def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs): - config_dict = cls.get_config_dict( - pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs - ) + config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index f383102d..a9642a48 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -16,7 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .clip_text_transformer import CLIPTextModel from .unet import UNetModel from .unet_glide import UNetGLIDEModel from .unet_ldm import UNetLDMModel -from .clip_text_transformer import CLIPTextModel diff --git a/src/diffusers/models/clip_text_transformer.py b/src/diffusers/models/clip_text_transformer.py index ba25cf3a..1cf5aa92 100644 --- a/src/diffusers/models/clip_text_transformer.py +++ b/src/diffusers/models/clip_text_transformer.py @@ -14,14 +14,15 @@ # limitations under the License. """ PyTorch CLIP model.""" -from dataclasses import dataclass import math +from dataclasses import dataclass from typing import Any, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn +from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig from transformers.activations import ACT2FN from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.modeling_utils import PreTrainedModel @@ -32,7 +33,7 @@ from transformers.utils import ( logging, replace_return_docstrings, ) -from transformers import CLIPModel, CLIPConfig, CLIPVisionConfig, CLIPTextConfig + logger = logging.get_logger(__name__) @@ -153,11 +154,11 @@ class CLIPTextEmbeddings(nn.Module): self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.Tensor] = None, + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] @@ -193,16 +194,15 @@ class CLIPAttention(nn.Module): ) self.scale = 1 / math.sqrt(math.sqrt(self.head_dim)) - self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim*3) + self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - causal_attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -212,9 +212,7 @@ class CLIPAttention(nn.Module): qkv_states = qkv_states.view(bsz, tgt_len, self.num_heads, -1) query_states, key_states, value_states = torch.split(qkv_states, self.head_dim, dim=-1) - attn_weights = torch.einsum( - "bthc,bshc->bhts", query_states * self.scale, key_states * self.scale - ) + attn_weights = torch.einsum("bthc,bshc->bhts", query_states * self.scale, key_states * self.scale) wdtype = attn_weights.dtype attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).type(wdtype) @@ -252,11 +250,11 @@ class CLIPEncoderLayer(nn.Module): self.layer_norm2 = nn.LayerNorm(self.embed_dim) def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - causal_attention_mask: torch.Tensor, - output_attentions: Optional[bool] = False, + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, ) -> Tuple[torch.FloatTensor]: """ Args: @@ -313,19 +311,19 @@ class CLIPPreTrainedModel(PreTrainedModel): module.padding_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) elif isinstance(module, CLIPVisionEmbeddings): factor = self.config.initializer_factor - nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim ** -0.5 * factor) + nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) elif isinstance(module, CLIPAttention): factor = self.config.initializer_factor - in_proj_std = (module.embed_dim ** -0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor - out_proj_std = (module.embed_dim ** -0.5) * factor + in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + out_proj_std = (module.embed_dim**-0.5) * factor nn.init.normal_(module.qkv_proj.weight, std=in_proj_std) nn.init.normal_(module.out_proj.weight, std=out_proj_std) elif isinstance(module, CLIPMLP): factor = self.config.initializer_factor in_proj_std = ( - (module.config.hidden_size ** -0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor ) fc_std = (2 * module.config.hidden_size) ** -0.5 * factor nn.init.normal_(module.fc1.weight, std=fc_std) @@ -333,11 +331,11 @@ class CLIPPreTrainedModel(PreTrainedModel): elif isinstance(module, CLIPModel): nn.init.normal_( module.text_projection.weight, - std=module.text_embed_dim ** -0.5 * self.config.initializer_factor, + std=module.text_embed_dim**-0.5 * self.config.initializer_factor, ) nn.init.normal_( module.visual_projection.weight, - std=module.vision_embed_dim ** -0.5 * self.config.initializer_factor, + std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, ) if isinstance(module, nn.LayerNorm): @@ -463,13 +461,13 @@ class CLIPEncoder(nn.Module): self.gradient_checkpointing = False def forward( - self, - inputs_embeds, - attention_mask: Optional[torch.Tensor] = None, - causal_attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutput]: r""" Args: @@ -562,13 +560,13 @@ class CLIPTextTransformer(nn.Module): @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig) def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: @@ -652,13 +650,13 @@ class CLIPTextModel(CLIPPreTrainedModel): @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig) def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: @@ -684,4 +682,4 @@ class CLIPTextModel(CLIPPreTrainedModel): output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - ) \ No newline at end of file + ) diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py index 4b5cc971..4764dbf7 100644 --- a/src/diffusers/models/unet_glide.py +++ b/src/diffusers/models/unet_glide.py @@ -470,7 +470,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin): self.channel_mult = channel_mult self.conv_resample = conv_resample self.use_checkpoint = use_checkpoint - #self.dtype = torch.float16 if use_fp16 else torch.float32 + # self.dtype = torch.float16 if use_fp16 else torch.float32 self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index daef124a..ccc688c3 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -17,6 +17,7 @@ import importlib import os from typing import Optional, Union + from huggingface_hub import snapshot_download # CHANGE to diffusers.utils @@ -64,7 +65,7 @@ class DiffusionPipeline(ConfigMixin): # set models setattr(self, name, module) - register_dict = {"_module" : self.__module__.split(".")[-1] + ".py"} + register_dict = {"_module": self.__module__.split(".")[-1] + ".py"} self.register(**register_dict) def save_pretrained(self, save_directory: Union[str, os.PathLike]): diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 82084c6c..7311088c 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -16,5 +16,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .gaussian_ddpm import GaussianDDPMScheduler from .classifier_free_guidance import ClassifierFreeGuidanceScheduler +from .gaussian_ddpm import GaussianDDPMScheduler diff --git a/src/diffusers/schedulers/classifier_free_guidance.py b/src/diffusers/schedulers/classifier_free_guidance.py index 17222c17..12ec76a2 100644 --- a/src/diffusers/schedulers/classifier_free_guidance.py +++ b/src/diffusers/schedulers/classifier_free_guidance.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import torch import math -from torch import nn + import numpy as np +import torch +from torch import nn from ..configuration_utils import ConfigMixin @@ -80,19 +81,13 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin): self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) # calculations for posterior q(x_{t-1} | x_t, x_0) - self.posterior_variance = ( - betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) - ) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain self.posterior_log_variance_clipped = np.log( - np.append(self.posterior_variance[1], self.posterior_variance[1:]) - ) - self.posterior_mean_coef1 = ( - betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) - ) - self.posterior_mean_coef2 = ( - (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + np.append(self.posterior_variance[1], self.posterior_variance[1:]) ) + self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) def sample_noise(self, shape, device, generator=None): # always sample on CPU to be deterministic From d754ce5f3b9d012131f147bb5ddc261402b62adf Mon Sep 17 00:00:00 2001 From: anton-l Date: Wed, 8 Jun 2022 12:32:46 +0200 Subject: [PATCH 5/5] transformer-guided glide sampling --- models/vision/glide/convert_weights.py | 10 ++------ models/vision/glide/modeling_glide.py | 24 +++++++++++++++---- models/vision/glide/run_glide.py | 2 +- src/diffusers/models/unet_glide.py | 22 ++++++++--------- .../schedulers/classifier_free_guidance.py | 8 +++---- 5 files changed, 37 insertions(+), 29 deletions(-) diff --git a/models/vision/glide/convert_weights.py b/models/vision/glide/convert_weights.py index a8016406..10369fca 100644 --- a/models/vision/glide/convert_weights.py +++ b/models/vision/glide/convert_weights.py @@ -22,8 +22,7 @@ config = CLIPTextConfig( use_padding_embeddings=True, ) model = CLIPTextModel(config).eval() -tokenizer = GPT2Tokenizer("./glide-base/vocab.json", "./glide-base/merges.txt", pad_token="<|endoftext|>") -# tokenizer.save_pretrained("./glide-base") +tokenizer = GPT2Tokenizer("./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>") hf_encoder = model.text_model @@ -52,12 +51,6 @@ for layer_idx in range(config.num_hidden_layers): hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"] hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"] -# inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt") -# with torch.no_grad(): -# outputs = model(**inputs) - -# model.save_pretrained("./glide-base") - ### Convert the UNet unet_model = UNetGLIDEModel( @@ -73,6 +66,7 @@ unet_model = UNetGLIDEModel( num_heads_upsample=1, use_scale_shift_norm=True, resblock_updown=True, + transformer_dim=512, ) unet_model.load_state_dict(state_dict, strict=False) diff --git a/models/vision/glide/modeling_glide.py b/models/vision/glide/modeling_glide.py index ecd29637..cc2880d8 100644 --- a/models/vision/glide/modeling_glide.py +++ b/models/vision/glide/modeling_glide.py @@ -130,21 +130,37 @@ class GLIDE(DiffusionPipeline): self.unet.to(torch_device) self.text_encoder.to(torch_device) + # Create a classifier-free guidance sampling function + guidance_scale = 3.0 + + def model_fn(x_t, ts, transformer_out, **kwargs): + half = x_t[: len(x_t) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.unet(combined, ts, transformer_out, **kwargs) + eps, rest = model_out[:, :3], model_out[:, 3:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + # 1. Sample gaussian noise + batch_size = 2 # second image is empty for classifier-free guidance image = self.noise_scheduler.sample_noise( - (1, self.unet.in_channels, 64, 64), device=torch_device, generator=generator + (batch_size, self.unet.in_channels, 64, 64), device=torch_device, generator=generator ) # 2. Encode tokens # an empty input is needed to guide the model away from ( inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt") - transformer_out = self.text_encoder(**inputs).last_hidden_state + input_ids = inputs["input_ids"].to(torch_device) + attention_mask = inputs["attention_mask"].to(torch_device) + transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state num_timesteps = len(self.noise_scheduler) for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps): t = torch.tensor([i] * image.shape[0], device=torch_device) - mean, variance, log_variance, pred_xstart = self.p_mean_variance(self.unet, transformer_out, image, t) - noise = self.noise_scheduler.sample_noise(image.shape) + mean, variance, log_variance, pred_xstart = self.p_mean_variance(model_fn, image, t, transformer_out) + noise = self.noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator) nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1))) # no noise when t == 0 image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise diff --git a/models/vision/glide/run_glide.py b/models/vision/glide/run_glide.py index 2c3eafd2..1bea36fc 100644 --- a/models/vision/glide/run_glide.py +++ b/models/vision/glide/run_glide.py @@ -9,6 +9,6 @@ generator = generator.manual_seed(0) # 1. Load models pipeline = GLIDE.from_pretrained("fusing/glide-base") -img = pipeline(generator) +img = pipeline("an oil painting of a corgi", generator) print(img) diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py index 4764dbf7..97f9b56e 100644 --- a/src/diffusers/models/unet_glide.py +++ b/src/diffusers/models/unet_glide.py @@ -435,7 +435,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin): num_heads_upsample=-1, use_scale_shift_norm=False, resblock_updown=False, - encoder_channels=None, + transformer_dim=512, ): super().__init__() self.register( @@ -455,7 +455,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin): num_heads_upsample=num_heads_upsample, use_scale_shift_norm=use_scale_shift_norm, resblock_updown=resblock_updown, - encoder_channels=encoder_channels, + transformer_dim=transformer_dim, ) if num_heads_upsample == -1: @@ -482,6 +482,8 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin): linear(time_embed_dim, time_embed_dim), ) + self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4) + ch = input_ch = int(channel_mult[0] * model_channels) self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]) self._feature_size = ch @@ -508,7 +510,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin): use_checkpoint=use_checkpoint, num_heads=num_heads, num_head_channels=num_head_channels, - encoder_channels=encoder_channels, + encoder_channels=transformer_dim, ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) @@ -551,7 +553,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin): use_checkpoint=use_checkpoint, num_heads=num_heads, num_head_channels=num_head_channels, - encoder_channels=encoder_channels, + encoder_channels=transformer_dim, ), ResBlock( ch, @@ -587,7 +589,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin): use_checkpoint=use_checkpoint, num_heads=num_heads_upsample, num_head_channels=num_head_channels, - encoder_channels=encoder_channels, + encoder_channels=transformer_dim, ) ) if level and i == num_res_blocks: @@ -642,10 +644,6 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin): :param y: an [N] Tensor of labels, if class-conditional. :return: an [N x C x ...] Tensor of outputs. """ - assert (y is not None) == ( - self.num_classes is not None - ), "must specify y if and only if the model is class-conditional" - hs = [] emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) @@ -655,13 +653,13 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin): emb = emb + transformer_proj.to(emb) - h = x.type(self.dtype) + h = x for module in self.input_blocks: h = module(h, emb, transformer_out) hs.append(h) h = self.middle_block(h, emb, transformer_out) for module in self.output_blocks: - h = torch.cat([h, hs.pop()], dim=1) + other = hs.pop() + h = torch.cat([h, other], dim=1) h = module(h, emb, transformer_out) - h = h.type(x.dtype) return self.out(h) diff --git a/src/diffusers/schedulers/classifier_free_guidance.py b/src/diffusers/schedulers/classifier_free_guidance.py index 12ec76a2..2cd81521 100644 --- a/src/diffusers/schedulers/classifier_free_guidance.py +++ b/src/diffusers/schedulers/classifier_free_guidance.py @@ -65,14 +65,14 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin): if beta_schedule == "squaredcos_cap_v2": # GLIDE cosine schedule - betas = betas_for_alpha_bar( + self.betas = betas_for_alpha_bar( timesteps, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, ) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - alphas = 1.0 - betas + alphas = 1.0 - self.betas self.alphas_cumprod = np.cumprod(alphas, axis=0) self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) @@ -81,12 +81,12 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin): self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) # calculations for posterior q(x_{t-1} | x_t, x_0) - self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain self.posterior_log_variance_clipped = np.log( np.append(self.posterior_variance[1], self.posterior_variance[1:]) ) - self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef1 = self.betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) def sample_noise(self, shape, device, generator=None):