diff --git a/debug_conversion.py b/debug_conversion.py deleted file mode 100755 index a32ce784..00000000 --- a/debug_conversion.py +++ /dev/null @@ -1,110 +0,0 @@ -#!/usr/bin/env python3 -import json -import os - -from regex import P -from diffusers import UNetUnconditionalModel -from scripts.convert_ncsnpp_original_checkpoint_to_diffusers import convert_ncsnpp_checkpoint -from huggingface_hub import hf_hub_download -import torch - - - -def convert_checkpoint(model_id, subfolder=None, checkpoint = "diffusion_model.pt", config = "config.json"): - if subfolder is not None: - checkpoint = os.path.join(subfolder, checkpoint) - config = os.path.join(subfolder, config) - - original_checkpoint = torch.load(hf_hub_download(model_id, checkpoint),map_location='cpu') - config_path = hf_hub_download(model_id, config) - - with open(config_path) as f: - config = json.load(f) - - checkpoint = convert_ncsnpp_checkpoint(original_checkpoint, config) - - - def current_codebase_conversion(path): - model = UNetUnconditionalModel.from_pretrained(model_id, subfolder=subfolder, sde=True) - model.eval() - model.config.sde=False - model.save_config(path) - model.config.sde=True - - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size) - time_step = torch.tensor([10] * noise.shape[0]) - - with torch.no_grad(): - output = model(noise, time_step) - - return model.state_dict() - - path = f"{model_id}_converted" - currently_converted_checkpoint = current_codebase_conversion(path) - - - def diff_between_checkpoints(ch_0, ch_1): - all_layers_included = False - - if not set(ch_0.keys()) == set(ch_1.keys()): - print(f"Contained in ch_0 and not in ch_1 (Total: {len((set(ch_0.keys()) - set(ch_1.keys())))})") - for key in sorted(list((set(ch_0.keys()) - set(ch_1.keys())))): - print(f"\t{key}") - - print(f"Contained in ch_1 and not in ch_0 (Total: {len((set(ch_1.keys()) - set(ch_0.keys())))})") - for key in sorted(list((set(ch_1.keys()) - set(ch_0.keys())))): - print(f"\t{key}") - else: - print("Keys are the same between the two checkpoints") - all_layers_included = True - - keys = ch_0.keys() - non_equal_keys = [] - - if all_layers_included: - for key in keys: - try: - if not torch.allclose(ch_0[key].cpu(), ch_1[key].cpu()): - non_equal_keys.append(f'{key}. Diff: {torch.max(torch.abs(ch_0[key].cpu() - ch_1[key].cpu()))}') - - except RuntimeError as e: - print(e) - non_equal_keys.append(f'{key}. 
Diff in shape: {ch_0[key].size()} vs {ch_1[key].size()}') - - if len(non_equal_keys): - non_equal_keys = '\n\t'.join(non_equal_keys) - print(f"These keys do not satisfy equivalence requirement:\n\t{non_equal_keys}") - else: - print("All keys are equal across checkpoints.") - - - diff_between_checkpoints(currently_converted_checkpoint, checkpoint) - os.makedirs( f"{model_id}_converted",exist_ok =True) - torch.save(checkpoint, f"{model_id}_converted/diffusion_model.pt") - - -model_ids = ["fusing/ffhq_ncsnpp","fusing/church_256-ncsnpp-ve", "fusing/celebahq_256-ncsnpp-ve", - "fusing/bedroom_256-ncsnpp-ve","fusing/ffhq_256-ncsnpp-ve","fusing/ncsnpp-ffhq-ve-dummy" - ] -for model in model_ids: - print(f"converting {model}") - try: - convert_checkpoint(model) - except Exception as e: - print(e) - -from tests.test_modeling_utils import PipelineTesterMixin, NCSNppModelTests - -tester1 = NCSNppModelTests() -tester2 = PipelineTesterMixin() - -os.environ["RUN_SLOW"] = '1' -cmd = "export RUN_SLOW=1; echo $RUN_SLOW" # or whatever command -os.system(cmd) -tester2.test_score_sde_ve_pipeline(f"{model_ids[0]}_converted") -tester1.test_output_pretrained_ve_mid(f"{model_ids[2]}_converted") -tester1.test_output_pretrained_ve_large(f"{model_ids[-1]}_converted") diff --git a/run.py b/run.py deleted file mode 100755 index cae97139..00000000 --- a/run.py +++ /dev/null @@ -1,153 +0,0 @@ -#!/usr/bin/env python3 -import numpy as np -import PIL -import torch -#from configs.ve import ffhq_ncsnpp_continuous as configs -# from configs.ve import cifar10_ncsnpp_continuous as configs - - -device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') - -torch.backends.cuda.matmul.allow_tf32 = False -torch.manual_seed(0) - - -class NewReverseDiffusionPredictor: - def __init__(self, score_fn, probability_flow=False, sigma_min=0.0, sigma_max=0.0, N=0): - super().__init__() - self.sigma_min = sigma_min - self.sigma_max = sigma_max - self.N = N - self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N)) - - self.probability_flow = probability_flow - self.score_fn = score_fn - - def discretize(self, x, t): - timestep = (t * (self.N - 1)).long() - sigma = self.discrete_sigmas.to(t.device)[timestep] - adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t), - self.discrete_sigmas[timestep - 1].to(t.device)) - f = torch.zeros_like(x) - G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2) - - labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t - result = self.score_fn(x, labels) - - rev_f = f - G[:, None, None, None] ** 2 * result * (0.5 if self.probability_flow else 1.) 
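
The `labels` used for the score model and the `discrete_sigmas` ladder built in `__init__` are two views of the same geometric noise schedule: the ladder `exp(linspace(log σ_min, log σ_max, N))` equals `σ_min · (σ_max/σ_min)^t` at the grid points `t = i/(N-1)`. A small self-contained check (values taken from the deleted script, except `N`, which is enlarged here purely for illustration):

```python
import math

import torch

sigma_min, sigma_max, N = 0.01, 1348.0, 5  # N is illustrative; the script uses N = 2

discrete_sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), N))
t = torch.arange(N) / (N - 1)
continuous_sigmas = sigma_min * (sigma_max / sigma_min) ** t

# Both parameterizations of the VE noise scale coincide at the discrete timesteps.
assert torch.allclose(discrete_sigmas, continuous_sigmas, rtol=1e-5)
```
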
- rev_G = torch.zeros_like(G) if self.probability_flow else G - return rev_f, rev_G - - def update_fn(self, x, t): - f, G = self.discretize(x, t) - z = torch.randn_like(x) - x_mean = x - f - x = x_mean + G[:, None, None, None] * z - return x, x_mean - - -class NewLangevinCorrector: - def __init__(self, score_fn, snr, n_steps, sigma_min=0.0, sigma_max=0.0): - super().__init__() - self.score_fn = score_fn - self.snr = snr - self.n_steps = n_steps - - self.sigma_min = sigma_min - self.sigma_max = sigma_max - - def update_fn(self, x, t): - score_fn = self.score_fn - n_steps = self.n_steps - target_snr = self.snr -# if isinstance(sde, VPSDE) or isinstance(sde, subVPSDE): -# timestep = (t * (sde.N - 1) / sde.T).long() -# alpha = sde.alphas.to(t.device)[timestep] -# else: - alpha = torch.ones_like(t) - - for i in range(n_steps): - labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t - grad = score_fn(x, labels) - noise = torch.randn_like(x) - grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean() - noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() - step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha - x_mean = x + step_size[:, None, None, None] * grad - x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise - - return x, x_mean - - - -def save_image(x): - image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8) - image_pil = PIL.Image.fromarray(image_processed[0]) - image_pil.save("../images/hey.png") - - -# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth" -#ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth" -# Note usually we need to restore ema etc... -# ema restored checkpoint used from below - -N = 2 -sigma_min = 0.01 -sigma_max = 1348 -sampling_eps = 1e-5 -batch_size = 1 -centered = False - -from diffusers import NCSNpp - -model = NCSNpp.from_pretrained("/home/patrick/ffhq_ncsnpp").to(device) -model = torch.nn.DataParallel(model) - -img_size = model.module.config.image_size -channels = model.module.config.num_channels -shape = (batch_size, channels, img_size, img_size) -probability_flow = False -snr = 0.15 -n_steps = 1 - - -new_corrector = NewLangevinCorrector(score_fn=model, snr=snr, n_steps=n_steps, sigma_min=sigma_min, sigma_max=sigma_max) -new_predictor = NewReverseDiffusionPredictor(score_fn=model, sigma_min=sigma_min, sigma_max=sigma_max, N=N) - -with torch.no_grad(): - # Initial sample - x = torch.randn(*shape) * sigma_max - x = x.to(device) - timesteps = torch.linspace(1, sampling_eps, N, device=device) - - for i in range(N): - t = timesteps[i] - vec_t = torch.ones(shape[0], device=t.device) * t - x, x_mean = new_corrector.update_fn(x, vec_t) - x, x_mean = new_predictor.update_fn(x, vec_t) - - x = x_mean - if centered: - x = (x + 1.) / 2. 
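
Taken together, the two classes and the sampling loop above implement one predictor–corrector sweep over the VE-SDE. Below is a condensed, self-contained sketch of the same updates, with a toy analytic score standing in for the NCSN++ model; the toy score, tensor shapes, and hyperparameter values are illustrative only:

```python
import math

import torch

sigma_min, sigma_max, N, snr = 0.01, 50.0, 10, 0.15  # illustrative values
sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), N))


def toy_score(x, sigma):
    # Stand-in for the NCSN++ score model: exact score of N(0, sigma^2 I).
    return -x / sigma[:, None, None, None] ** 2


def predictor_step(x, t):
    # Mirrors NewReverseDiffusionPredictor.discretize/update_fn (drift f = 0 for the VE-SDE).
    timestep = (t * (N - 1)).long()
    sigma = sigmas[timestep]
    adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t), sigmas[timestep - 1])
    G = torch.sqrt(sigma**2 - adjacent_sigma**2)
    labels = sigma_min * (sigma_max / sigma_min) ** t
    score = toy_score(x, labels)
    x_mean = x + G[:, None, None, None] ** 2 * score
    return x_mean + G[:, None, None, None] * torch.randn_like(x), x_mean


def corrector_step(x, t):
    # Mirrors NewLangevinCorrector.update_fn with alpha = 1 (VE case) and n_steps = 1.
    labels = sigma_min * (sigma_max / sigma_min) ** t
    grad = toy_score(x, labels)
    noise = torch.randn_like(x)
    grad_norm = grad.reshape(grad.shape[0], -1).norm(dim=-1).mean()
    noise_norm = noise.reshape(noise.shape[0], -1).norm(dim=-1).mean()
    step_size = (snr * noise_norm / grad_norm) ** 2 * 2
    x_mean = x + step_size * grad
    return x_mean + torch.sqrt(2 * step_size) * noise, x_mean


x = torch.randn(1, 3, 8, 8) * sigma_max
for t in torch.linspace(1, 1e-5, N):
    vec_t = torch.ones(x.shape[0]) * t
    x, x_mean = corrector_step(x, vec_t)
    x, x_mean = predictor_step(x, vec_t)
```
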
- - -# save_image(x) - -# for 5 cifar10 -x_sum = 106071.9922 -x_mean = 34.52864456176758 - -# for 1000 cifar10 -x_sum = 461.9700 -x_mean = 0.1504 - -# for 2 for 1024 -x_sum = 3382810112.0 -x_mean = 1075.366455078125 - -def check_x_sum_x_mean(x, x_sum, x_mean): - assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}" - assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" - - -check_x_sum_x_mean(x, x_sum, x_mean) diff --git a/scripts/conversion_glide.py b/scripts/conversion_glide.py deleted file mode 100644 index 6cf0133d..00000000 --- a/scripts/conversion_glide.py +++ /dev/null @@ -1,113 +0,0 @@ -import torch -from torch import nn - -from diffusers import ClassifierFreeGuidanceScheduler, DDIMScheduler, GlideSuperResUNetModel, GlideTextToImageUNetModel -from diffusers.pipelines.pipeline_glide import Glide, CLIPTextModel -from transformers import CLIPTextConfig, GPT2Tokenizer - - -# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt -state_dict = torch.load("base.pt", map_location="cpu") -state_dict = {k: nn.Parameter(v) for k, v in state_dict.items()} - -### Convert the text encoder - -config = CLIPTextConfig( - vocab_size=50257, - max_position_embeddings=128, - hidden_size=512, - intermediate_size=2048, - num_hidden_layers=16, - num_attention_heads=8, - use_padding_embeddings=True, -) -model = CLIPTextModel(config).eval() -tokenizer = GPT2Tokenizer( - "./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>" -) - -hf_encoder = model.text_model - -hf_encoder.embeddings.token_embedding.weight = state_dict["token_embedding.weight"] -hf_encoder.embeddings.position_embedding.weight.data = state_dict["positional_embedding"] -hf_encoder.embeddings.padding_embedding.weight.data = state_dict["padding_embedding"] - -hf_encoder.final_layer_norm.weight = state_dict["final_ln.weight"] -hf_encoder.final_layer_norm.bias = state_dict["final_ln.bias"] - -for layer_idx in range(config.num_hidden_layers): - hf_layer = hf_encoder.encoder.layers[layer_idx] - hf_layer.self_attn.qkv_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.weight"] - hf_layer.self_attn.qkv_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.bias"] - - hf_layer.self_attn.out_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.weight"] - hf_layer.self_attn.out_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.bias"] - - hf_layer.layer_norm1.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.weight"] - hf_layer.layer_norm1.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.bias"] - hf_layer.layer_norm2.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.weight"] - hf_layer.layer_norm2.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.bias"] - - hf_layer.mlp.fc1.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.weight"] - hf_layer.mlp.fc1.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.bias"] - hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"] - hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"] - -### Convert the Text-to-Image UNet - -text2im_model = GlideTextToImageUNetModel( - in_channels=3, - model_channels=192, - out_channels=6, - num_res_blocks=3, - attention_resolutions=(2, 4, 8), - dropout=0.1, - channel_mult=(1, 2, 3, 4), - num_heads=1, - num_head_channels=64, - 
num_heads_upsample=1, - use_scale_shift_norm=True, - resblock_updown=True, - transformer_dim=512, -) - -text2im_model.load_state_dict(state_dict, strict=False) - -text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2") - -### Convert the Super-Resolution UNet - -# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt -ups_state_dict = torch.load("upsample.pt", map_location="cpu") - -superres_model = GlideSuperResUNetModel( - in_channels=6, - model_channels=192, - out_channels=6, - num_res_blocks=2, - attention_resolutions=(8, 16, 32), - dropout=0.1, - channel_mult=(1, 1, 2, 2, 4, 4), - num_heads=1, - num_head_channels=64, - num_heads_upsample=1, - use_scale_shift_norm=True, - resblock_updown=True, -) - -superres_model.load_state_dict(ups_state_dict, strict=False) - -upscale_scheduler = DDIMScheduler( - timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02, tensor_format="pt" -) - -glide = Glide( - text_unet=text2im_model, - text_noise_scheduler=text_scheduler, - text_encoder=model, - tokenizer=tokenizer, - upscale_unet=superres_model, - upscale_noise_scheduler=upscale_scheduler, -) - -glide.save_pretrained("./glide-base") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 220f0704..6f504677 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -7,36 +7,13 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode __version__ = "0.0.4" from .modeling_utils import ModelMixin -from .models import ( - AutoencoderKL, - NCSNpp, - UNetConditionalModel, - UNetLDMModel, - UNetModel, - UNetUnconditionalModel, - VQModel, -) +from .models import AutoencoderKL, UNetConditionalModel, UNetUnconditionalModel, VQModel from .pipeline_utils import DiffusionPipeline -from .pipelines import ( - DDIMPipeline, - DDPMPipeline, - LatentDiffusionUncondPipeline, - PNDMPipeline, - ScoreSdeVePipeline, - ScoreSdeVpPipeline, -) -from .schedulers import ( - DDIMScheduler, - DDPMScheduler, - PNDMScheduler, - SchedulerMixin, - ScoreSdeVeScheduler, - ScoreSdeVpScheduler, -) +from .pipelines import DDIMPipeline, DDPMPipeline, LatentDiffusionUncondPipeline, PNDMPipeline, ScoreSdeVePipeline +from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler if is_transformers_available(): - from .models.unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel - from .pipelines import GlidePipeline, LatentDiffusionPipeline + from .pipelines import LatentDiffusionPipeline else: from .utils.dummy_transformers_objects import * diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 2a491c69..f3b2fe9e 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -16,10 +16,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
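
The updated `src/diffusers/__init__.py` above shrinks the public namespace to the consolidated models, pipelines, and schedulers. A quick smoke test of the surviving imports, assuming this revision of the package is installed:

```python
# Names taken verbatim from the "+" lines of the updated src/diffusers/__init__.py.
from diffusers import (
    AutoencoderKL,
    UNetConditionalModel,
    UNetUnconditionalModel,
    VQModel,
    DDIMPipeline,
    DDPMPipeline,
    LatentDiffusionUncondPipeline,
    PNDMPipeline,
    ScoreSdeVePipeline,
    DDIMScheduler,
    DDPMScheduler,
    PNDMScheduler,
    SchedulerMixin,
    ScoreSdeVeScheduler,
)

# Removed by this diff, so these should now fail to import:
# UNetModel, UNetLDMModel, NCSNpp, ScoreSdeVpPipeline, ScoreSdeVpScheduler,
# GlideUNetModel, GlideTextToImageUNetModel, GlideSuperResUNetModel, GlidePipeline.
```
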
-from .unet import UNetModel
 from .unet_conditional import UNetConditionalModel
-from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
-from .unet_ldm import UNetLDMModel
-from .unet_sde_score_estimation import NCSNpp
 from .unet_unconditional import UNetUnconditionalModel
 from .vae import AutoencoderKL, VQModel
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 5872c2a4..8b52859e 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -54,6 +54,43 @@ def get_timestep_embedding(
     return emb
 
 
+class TimestepEmbedding(nn.Module):
+    def __init__(self, channel, time_embed_dim, act_fn="silu"):
+        super().__init__()
+
+        self.linear_1 = nn.Linear(channel, time_embed_dim)
+        self.act = None
+        if act_fn == "silu":
+            self.act = nn.SiLU()
+        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
+
+    def forward(self, sample):
+        sample = self.linear_1(sample)
+
+        if self.act is not None:
+            sample = self.act(sample)
+
+        sample = self.linear_2(sample)
+        return sample
+
+
+class Timesteps(nn.Module):
+    def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift):
+        super().__init__()
+        self.num_channels = num_channels
+        self.flip_sin_to_cos = flip_sin_to_cos
+        self.downscale_freq_shift = downscale_freq_shift
+
+    def forward(self, timesteps):
+        t_emb = get_timestep_embedding(
+            timesteps,
+            self.num_channels,
+            flip_sin_to_cos=self.flip_sin_to_cos,
+            downscale_freq_shift=self.downscale_freq_shift,
+        )
+        return t_emb
+
+
 class GaussianFourierProjection(nn.Module):
     """Gaussian Fourier embeddings for noise levels."""
 
diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py
deleted file mode 100644
index 7c497697..00000000
--- a/src/diffusers/models/unet.py
+++ /dev/null
@@ -1,195 +0,0 @@
-# Copyright 2022 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-
-# limitations under the License.
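
The `embeddings.py` hunk above promotes `Timesteps` and `TimestepEmbedding` to shared modules in `diffusers.models.embeddings`. A minimal usage sketch, assuming that import path at this revision; the channel sizes are illustrative (the UNets pair `block_channels[0]` with `time_embed_dim = block_channels[0] * 4`):

```python
import torch

from diffusers.models.embeddings import TimestepEmbedding, Timesteps

# flip_sin_to_cos and downscale_freq_shift mirror the config fields registered by
# UNetConditionalModel; the concrete values here are illustrative.
time_proj = Timesteps(num_channels=320, flip_sin_to_cos=True, downscale_freq_shift=0)
time_embedding = TimestepEmbedding(channel=320, time_embed_dim=1280)

timesteps = torch.tensor([10, 250, 999])  # one diffusion timestep per batch element
t_emb = time_proj(timesteps)              # sinusoidal features, shape (3, 320)
emb = time_embedding(t_emb)               # learned MLP projection, shape (3, 1280)
print(emb.shape)  # torch.Size([3, 1280])
```
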
- -# helpers functions - -import torch -from torch import nn - -from ..configuration_utils import ConfigMixin -from ..modeling_utils import ModelMixin -from .attention import AttentionBlock -from .embeddings import get_timestep_embedding -from .resnet import Downsample2D, ResnetBlock2D, Upsample2D -from .unet_new import UNetMidBlock2D - - -def nonlinearity(x): - # swish - return x * torch.sigmoid(x) - - -def Normalize(in_channels): - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - - -class UNetModel(ModelMixin, ConfigMixin): - def __init__( - self, - ch=128, - out_ch=3, - ch_mult=(1, 1, 2, 2, 4, 4), - num_res_blocks=2, - attn_resolutions=(16,), - dropout=0.0, - resamp_with_conv=True, - in_channels=3, - resolution=256, - ): - super().__init__() - self.register_to_config( - ch=ch, - out_ch=out_ch, - ch_mult=ch_mult, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - in_channels=in_channels, - resolution=resolution, - ) - ch_mult = tuple(ch_mult) - self.ch = ch - self.temb_ch = self.ch * 4 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - # timestep embedding - self.temb = nn.Module() - self.temb.dense = nn.ModuleList( - [ - torch.nn.Linear(self.ch, self.temb_ch), - torch.nn.Linear(self.temb_ch, self.temb_ch), - ] - ) - - # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) - - curr_res = resolution - in_ch_mult = (1,) + ch_mult - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock2D( - in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttentionBlock(block_in, overwrite_qkv=True)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock2D( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True) - self.mid.block_2 = ResnetBlock2D( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - self.mid_new = UNetMidBlock2D(in_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) - self.mid_new.resnets[0] = self.mid.block_1 - self.mid_new.attentions[0] = self.mid.attn_1 - self.mid_new.resnets[1] = self.mid.block_2 - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - skip_in = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - if i_block == self.num_res_blocks: - skip_in = ch * in_ch_mult[i_level] - block.append( - ResnetBlock2D( - in_channels=block_in + skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttentionBlock(block_in, 
overwrite_qkv=True)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) - - def forward(self, sample, timesteps): - x = sample - assert x.shape[2] == x.shape[3] == self.resolution - - if not torch.is_tensor(timesteps): - timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device) - - # timestep embedding - temb = get_timestep_embedding(timesteps, self.ch) - temb = self.temb.dense[0](temb) - temb = nonlinearity(temb) - temb = self.temb.dense[1](temb) - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = self.mid_new(hs[-1], temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h diff --git a/src/diffusers/models/unet_new.py b/src/diffusers/models/unet_blocks.py similarity index 100% rename from src/diffusers/models/unet_new.py rename to src/diffusers/models/unet_blocks.py diff --git a/src/diffusers/models/unet_conditional.py b/src/diffusers/models/unet_conditional.py index 3e0eb4c9..a034e3f8 100644 --- a/src/diffusers/models/unet_conditional.py +++ b/src/diffusers/models/unet_conditional.py @@ -1,74 +1,12 @@ -import functools -import math from typing import Dict, Union -import numpy as np import torch import torch.nn as nn from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin -from .attention import AttentionBlock, SpatialTransformer -from .embeddings import GaussianFourierProjection, get_timestep_embedding -from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D -from .unet_new import UNetMidBlock2DCrossAttn, get_down_block, get_up_block - - -class Combine(nn.Module): - """Combine information from skip connections.""" - - def __init__(self, dim1, dim2, method="cat"): - super().__init__() - # 1x1 convolution with DDPM initialization. 
- self.Conv_0 = nn.Conv2d(dim1, dim2, kernel_size=1, padding=0) - self.method = method - - -# def forward(self, x, y): -# h = self.Conv_0(x) -# if self.method == "cat": -# return torch.cat([h, y], dim=1) -# elif self.method == "sum": -# return h + y -# else: -# raise ValueError(f"Method {self.method} not recognized.") - - -class TimestepEmbedding(nn.Module): - def __init__(self, channel, time_embed_dim, act_fn="silu"): - super().__init__() - - self.linear_1 = nn.Linear(channel, time_embed_dim) - self.act = None - if act_fn == "silu": - self.act = nn.SiLU() - self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) - - def forward(self, sample): - sample = self.linear_1(sample) - - if self.act is not None: - sample = self.act(sample) - - sample = self.linear_2(sample) - return sample - - -class Timesteps(nn.Module): - def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift): - super().__init__() - self.num_channels = num_channels - self.flip_sin_to_cos = flip_sin_to_cos - self.downscale_freq_shift = downscale_freq_shift - - def forward(self, timesteps): - t_emb = get_timestep_embedding( - timesteps, - self.num_channels, - flip_sin_to_cos=self.flip_sin_to_cos, - downscale_freq_shift=self.downscale_freq_shift, - ) - return t_emb +from .embeddings import TimestepEmbedding, Timesteps +from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block class UNetConditionalModel(ModelMixin, ConfigMixin): @@ -124,38 +62,7 @@ class UNetConditionalModel(ModelMixin, ConfigMixin): downscale_freq_shift=0, mid_block_scale_factor=1, center_input_sample=False, - # TODO(PVP) - to delete later at release - # IMPORTANT: NOT RELEVANT WHEN REVIEWING API - # ====================================== - # LDM - attention_resolutions=(4, 2, 1), - # DDPM - out_ch=None, - resolution=None, - attn_resolutions=None, - resamp_with_conv=None, - ch_mult=None, - ch=None, - ddpm=False, - # SDE - sde=False, - nf=None, - fir=None, - progressive=None, - progressive_combine=None, - scale_by_sigma=None, - skip_rescale=None, - num_channels=None, - centered=False, - conditional=True, - conv_size=3, - fir_kernel=(1, 3, 3, 1), - fourier_scale=16, - init_scale=0.0, - progressive_input="input_skip", - resnet_num_groups=32, - continuous=True, - ldm=False, + resnet_num_groups=30, ): super().__init__() # register all __init__ params to be accessible via `self.config.<...>` @@ -175,21 +82,13 @@ class UNetConditionalModel(ModelMixin, ConfigMixin): num_head_channels=num_head_channels, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=downscale_freq_shift, - attention_resolutions=attention_resolutions, - attn_resolutions=attn_resolutions, mid_block_scale_factor=mid_block_scale_factor, resnet_num_groups=resnet_num_groups, center_input_sample=center_input_sample, ) - self.ldm = ldm - - # TODO(PVP) - to delete later at release - # IMPORTANT: NOT RELEVANT WHEN REVIEWING API - # ====================================== self.image_size = image_size time_embed_dim = block_channels[0] * 4 - # ====================================== # input self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1)) @@ -264,57 +163,18 @@ class UNetConditionalModel(ModelMixin, ConfigMixin): prev_output_channel = output_channel # out - num_groups_out = resnet_num_groups if resnet_num_groups is not None else min(block_channels[0] // 4, 32) - self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=num_groups_out, eps=resnet_eps) + self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], 
num_groups=resnet_num_groups, eps=resnet_eps) self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1) - # ======================== Out ==================== - - # =========== TO DELETE AFTER CONVERSION ========== - # TODO(PVP) - to delete later at release - # IMPORTANT: NOT RELEVANT WHEN REVIEWING API - # ====================================== - self.is_overwritten = False - if ldm: - num_heads = 8 - num_head_channels = -1 - transformer_depth = 1 - use_spatial_transformer = True - context_dim = 1280 - legacy = False - model_channels = block_channels[0] - channel_mult = tuple([x // model_channels for x in block_channels]) - self.init_for_ldm( - in_channels, - model_channels, - channel_mult, - num_res_blocks, - dropout, - time_embed_dim, - attention_resolutions, - num_head_channels, - num_heads, - legacy, - False, - transformer_depth, - context_dim, - conv_resample, - out_channels, - ) - def forward( self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, ) -> Dict[str, torch.FloatTensor]: - # TODO(PVP) - to delete later at release - # IMPORTANT: NOT RELEVANT WHEN REVIEWING API - # ====================================== - if not self.is_overwritten: - self.set_weights() + # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 @@ -329,7 +189,6 @@ class UNetConditionalModel(ModelMixin, ConfigMixin): emb = self.time_embedding(t_emb) # 2. pre-process - skip_sample = sample sample = self.conv_in(sample) # 3. down @@ -349,7 +208,6 @@ class UNetConditionalModel(ModelMixin, ConfigMixin): sample = self.mid(sample, emb, encoder_hidden_states=encoder_hidden_states) # 5. up - skip_sample = None for upsample_block in self.upsample_blocks: res_samples = down_block_res_samples[-len(upsample_block.resnets) :] @@ -374,259 +232,3 @@ class UNetConditionalModel(ModelMixin, ConfigMixin): output = {"sample": sample} return output - - # !!!IMPORTANT - ALL OF THE FOLLOWING CODE WILL BE DELETED AT RELEASE TIME AND SHOULD NOT BE TAKEN INTO CONSIDERATION WHEN EVALUATING THE API ### - # ================================================================================================================================================= - - def set_weights(self): - self.is_overwritten = True - if self.ldm: - self.time_embedding.linear_1.weight.data = self.time_embed[0].weight.data - self.time_embedding.linear_1.bias.data = self.time_embed[0].bias.data - self.time_embedding.linear_2.weight.data = self.time_embed[2].weight.data - self.time_embedding.linear_2.bias.data = self.time_embed[2].bias.data - - self.conv_in.weight.data = self.input_blocks[0][0].weight.data - self.conv_in.bias.data = self.input_blocks[0][0].bias.data - - # ================ SET WEIGHTS OF ALL WEIGHTS ================== - for i, input_layer in enumerate(self.input_blocks[1:]): - block_id = i // (self.config.num_res_blocks + 1) - layer_in_block_id = i % (self.config.num_res_blocks + 1) - - if layer_in_block_id == 2: - self.downsample_blocks[block_id].downsamplers[0].conv.weight.data = input_layer[0].op.weight.data - self.downsample_blocks[block_id].downsamplers[0].conv.bias.data = input_layer[0].op.bias.data - elif len(input_layer) > 1: - self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0]) - self.downsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1]) - else: - self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0]) - - 
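
The retained forward path hands each down block's residual states to the matching up block by slicing from the tail of `down_block_res_samples` (`res_samples = down_block_res_samples[-len(upsample_block.resnets):]`). A stripped-down sketch of that bookkeeping with plain lists; the block sizes and the explicit tail-trimming step are illustrative rather than copied from the model:

```python
# Each down block appends one entry per resnet; each up block consumes the same
# number from the end of the list, so skip connections pair up symmetrically.
resnets_per_block = [2, 2, 2, 2]  # illustrative

down_block_res_samples = []
for block_id, num_resnets in enumerate(resnets_per_block):  # "3. down"
    down_block_res_samples += [f"down{block_id}_res{i}" for i in range(num_resnets)]

for num_resnets in reversed(resnets_per_block):  # "5. up"
    res_samples = down_block_res_samples[-num_resnets:]
    down_block_res_samples = down_block_res_samples[:-num_resnets]
    print(res_samples)
```
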
self.mid.resnets[0].set_weight(self.middle_block[0]) - self.mid.resnets[1].set_weight(self.middle_block[2]) - self.mid.attentions[0].set_weight(self.middle_block[1]) - - for i, input_layer in enumerate(self.output_blocks): - block_id = i // (self.config.num_res_blocks + 1) - layer_in_block_id = i % (self.config.num_res_blocks + 1) - - if len(input_layer) > 2: - self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0]) - self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1]) - self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[2].conv.weight.data - self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[2].conv.bias.data - elif len(input_layer) > 1 and "Upsample2D" in input_layer[1].__class__.__name__: - self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0]) - self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[1].conv.weight.data - self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[1].conv.bias.data - elif len(input_layer) > 1: - self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0]) - self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1]) - else: - self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0]) - - self.conv_norm_out.weight.data = self.out[0].weight.data - self.conv_norm_out.bias.data = self.out[0].bias.data - self.conv_out.weight.data = self.out[2].weight.data - self.conv_out.bias.data = self.out[2].bias.data - - self.remove_ldm() - - def init_for_ldm( - self, - in_channels, - model_channels, - channel_mult, - num_res_blocks, - dropout, - time_embed_dim, - attention_resolutions, - num_head_channels, - num_heads, - legacy, - use_spatial_transformer, - transformer_depth, - context_dim, - conv_resample, - out_channels, - ): - # TODO(PVP) - delete after weight conversion - class TimestepEmbedSequential(nn.Sequential): - """ - A sequential module that passes timestep embeddings to the children that support it as an extra input. - """ - - pass - - # TODO(PVP) - delete after weight conversion - def conv_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D convolution module. 
- """ - if dims == 1: - return nn.Conv1d(*args, **kwargs) - elif dims == 2: - return nn.Conv2d(*args, **kwargs) - elif dims == 3: - return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - self.time_embed = nn.Sequential( - nn.Linear(model_channels, time_embed_dim), - nn.SiLU(), - nn.Linear(time_embed_dim, time_embed_dim), - ) - - dims = 2 - self.input_blocks = nn.ModuleList( - [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))] - ) - - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - for level, mult in enumerate(channel_mult): - for _ in range(num_res_blocks): - layers = [ - ResnetBlock2D( - in_channels=ch, - out_channels=mult * model_channels, - dropout=dropout, - temb_channels=time_embed_dim, - eps=1e-5, - non_linearity="silu", - overwrite_for_ldm=True, - ) - ] - ch = mult * model_channels - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - # num_heads = 1 - dim_head = num_head_channels - layers.append( - SpatialTransformer( - ch, - num_heads, - dim_head, - depth=transformer_depth, - context_dim=context_dim, - ), - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op") - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - # num_heads = 1 - dim_head = num_head_channels - - if dim_head < 0: - dim_head = None - - # TODO(Patrick) - delete after weight conversion - # init to be able to overwrite `self.mid` - self.middle_block = TimestepEmbedSequential( - ResnetBlock2D( - in_channels=ch, - out_channels=None, - dropout=dropout, - temb_channels=time_embed_dim, - eps=1e-5, - non_linearity="silu", - overwrite_for_ldm=True, - ), - SpatialTransformer( - ch, - num_heads, - dim_head, - depth=transformer_depth, - context_dim=context_dim, - ), - ResnetBlock2D( - in_channels=ch, - out_channels=None, - dropout=dropout, - temb_channels=time_embed_dim, - eps=1e-5, - non_linearity="silu", - overwrite_for_ldm=True, - ), - ) - self._feature_size += ch - - self.output_blocks = nn.ModuleList([]) - for level, mult in list(enumerate(channel_mult))[::-1]: - for i in range(num_res_blocks + 1): - ich = input_block_chans.pop() - layers = [ - ResnetBlock2D( - in_channels=ch + ich, - out_channels=model_channels * mult, - dropout=dropout, - temb_channels=time_embed_dim, - eps=1e-5, - non_linearity="silu", - overwrite_for_ldm=True, - ), - ] - ch = model_channels * mult - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - # num_heads = 1 - dim_head = num_head_channels - layers.append( - SpatialTransformer( - ch, - num_heads, - dim_head, - depth=transformer_depth, - context_dim=context_dim, - ) - ) - if level and i == num_res_blocks: - out_ch = ch - layers.append(Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch)) - ds //= 2 - self.output_blocks.append(TimestepEmbedSequential(*layers)) - 
self._feature_size += ch - - self.out = nn.Sequential( - nn.GroupNorm(num_channels=model_channels, num_groups=32, eps=1e-5), - nn.SiLU(), - nn.Conv2d(model_channels, out_channels, 3, padding=1), - ) - - def remove_ldm(self): - del self.time_embed - del self.input_blocks - del self.middle_block - del self.output_blocks - del self.out diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py deleted file mode 100644 index e26fcbdd..00000000 --- a/src/diffusers/models/unet_glide.py +++ /dev/null @@ -1,554 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from ..configuration_utils import ConfigMixin -from ..modeling_utils import ModelMixin -from .attention import AttentionBlock -from .embeddings import get_timestep_embedding -from .resnet import Downsample2D, ResnetBlock2D, Upsample2D -from .unet_new import UNetMidBlock2D - - -def convert_module_to_f16(l): - """ - Convert primitive modules to float16. - """ - if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): - l.weight.data = l.weight.data.half() - if l.bias is not None: - l.bias.data = l.bias.data.half() - - -def convert_module_to_f32(l): - """ - Convert primitive modules to float32, undoing convert_module_to_f16(). - """ - if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): - l.weight.data = l.weight.data.float() - if l.bias is not None: - l.bias.data = l.bias.data.float() - - -def conv_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D convolution module. - """ - if dims == 1: - return nn.Conv1d(*args, **kwargs) - elif dims == 2: - return nn.Conv2d(*args, **kwargs) - elif dims == 3: - return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -def linear(*args, **kwargs): - """ - Create a linear module. - """ - return nn.Linear(*args, **kwargs) - - -class GroupNorm32(nn.GroupNorm): - def __init__(self, num_groups, num_channels, swish, eps=1e-5): - super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps) - self.swish = swish - - def forward(self, x): - y = super().forward(x.float()).to(x.dtype) - if self.swish == 1.0: - y = F.silu(y) - elif self.swish: - y = y * F.sigmoid(y * float(self.swish)) - return y - - -def normalization(channels, swish=0.0): - """ - Make a standard normalization layer, with an optional swish activation. - - :param channels: number of input channels. :return: an nn.Module for normalization. - """ - return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - -class TimestepEmbedSequential(nn.Sequential): - """ - A sequential module that passes timestep embeddings to the children that support it as an extra input. - """ - - def forward(self, x, emb, encoder_out=None): - for layer in self: - if isinstance(layer, ResnetBlock2D) or isinstance(layer, TimestepEmbedSequential): - x = layer(x, emb) - elif isinstance(layer, AttentionBlock): - x = layer(x, encoder_out) - else: - x = layer(x) - return x - - -class GlideUNetModel(ModelMixin, ConfigMixin): - """ - The full UNet model with attention and timestep embedding. - - :param in_channels: channels in the input Tensor. :param model_channels: base channel count for the model. :param - out_channels: channels in the output Tensor. :param num_res_blocks: number of residual blocks per downsample. 
- :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x - downsampling, attention will be used. - :param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param - conv_resample: if True, use learned convolutions for upsampling and - downsampling. - :param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this - model will be - class-conditional with `num_classes` classes. - :param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention - heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use - a fixed channel width per attention head. - :param num_heads_upsample: works with num_heads to set a different number - of heads for upsampling. Deprecated. - :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks - for up/downsampling. - """ - - def __init__( - self, - in_channels=3, - resolution=64, - model_channels=192, - out_channels=6, - num_res_blocks=3, - attention_resolutions=(2, 4, 8), - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - use_checkpoint=False, - use_fp16=False, - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - transformer_dim=None, - ): - super().__init__() - - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - self.in_channels = in_channels - self.resolution = resolution - self.model_channels = model_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.use_checkpoint = use_checkpoint - # self.dtype = torch.float16 if use_fp16 else torch.float32 - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - - time_embed_dim = model_channels * 4 - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - ch = input_ch = int(channel_mult[0] * model_channels) - self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]) - self._feature_size = ch - input_block_chans = [ch] - ds = 1 - for level, mult in enumerate(channel_mult): - for _ in range(num_res_blocks): - layers = [ - ResnetBlock2D( - in_channels=ch, - out_channels=mult * model_channels, - dropout=dropout, - temb_channels=time_embed_dim, - eps=1e-5, - non_linearity="silu", - time_embedding_norm="scale_shift" if use_scale_shift_norm else "default", - overwrite_for_glide=True, - ) - ] - ch = int(mult * model_channels) - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - num_heads=num_heads, - num_head_channels=num_head_channels, - encoder_channels=transformer_dim, - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - ResnetBlock2D( - in_channels=ch, - out_channels=out_ch, - dropout=dropout, - temb_channels=time_embed_dim, - eps=1e-5, - non_linearity="silu", - 
time_embedding_norm="scale_shift" if use_scale_shift_norm else "default", - overwrite_for_glide=True, - down=True, - ) - if resblock_updown - else Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op") - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - self.mid = UNetMidBlock2D( - in_channels=ch, - dropout=dropout, - temb_channels=time_embed_dim, - resnet_eps=1e-5, - resnet_act_fn="silu", - resnet_time_scale_shift="scale_shift" if use_scale_shift_norm else "default", - attn_num_heads=num_heads, - attn_num_head_channels=num_head_channels, - attn_encoder_channels=transformer_dim, - ) - - # TODO(Patrick) - delete after weight conversion - # init to be able to overwrite `self.mid` - self.middle_block = TimestepEmbedSequential( - ResnetBlock2D( - in_channels=ch, - dropout=dropout, - temb_channels=time_embed_dim, - eps=1e-5, - non_linearity="silu", - time_embedding_norm="scale_shift" if use_scale_shift_norm else "default", - overwrite_for_glide=True, - ), - AttentionBlock( - ch, - num_heads=num_heads, - num_head_channels=num_head_channels, - encoder_channels=transformer_dim, - ), - ResnetBlock2D( - in_channels=ch, - dropout=dropout, - temb_channels=time_embed_dim, - eps=1e-5, - non_linearity="silu", - time_embedding_norm="scale_shift" if use_scale_shift_norm else "default", - overwrite_for_glide=True, - ), - ) - self.mid.resnets[0] = self.middle_block[0] - self.mid.attentions[0] = self.middle_block[1] - self.mid.resnets[1] = self.middle_block[2] - - self._feature_size += ch - - self.output_blocks = nn.ModuleList([]) - for level, mult in list(enumerate(channel_mult))[::-1]: - for i in range(num_res_blocks + 1): - ich = input_block_chans.pop() - layers = [ - ResnetBlock2D( - in_channels=ch + ich, - out_channels=model_channels * mult, - dropout=dropout, - temb_channels=time_embed_dim, - eps=1e-5, - non_linearity="silu", - time_embedding_norm="scale_shift" if use_scale_shift_norm else "default", - overwrite_for_glide=True, - ), - ] - ch = int(model_channels * mult) - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - num_heads=num_heads_upsample, - num_head_channels=num_head_channels, - encoder_channels=transformer_dim, - ) - ) - if level and i == num_res_blocks: - out_ch = ch - layers.append( - ResnetBlock2D( - in_channels=ch, - out_channels=out_ch, - dropout=dropout, - temb_channels=time_embed_dim, - eps=1e-5, - non_linearity="silu", - time_embedding_norm="scale_shift" if use_scale_shift_norm else "default", - overwrite_for_glide=True, - up=True, - ) - if resblock_updown - else Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch) - ) - ds //= 2 - self.output_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - - self.out = nn.Sequential( - normalization(ch, swish=1.0), - nn.Identity(), - zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)), - ) - self.use_fp16 = use_fp16 - - def convert_to_fp16(self): - """ - Convert the torso of the model to float16. - """ - self.input_blocks.apply(convert_module_to_f16) - self.middle_block.apply(convert_module_to_f16) - self.output_blocks.apply(convert_module_to_f16) - - def convert_to_fp32(self): - """ - Convert the torso of the model to float32. - """ - self.input_blocks.apply(convert_module_to_f32) - self.middle_block.apply(convert_module_to_f32) - self.output_blocks.apply(convert_module_to_f32) - - def forward(self, x, timesteps): - """ - Apply the model to an input batch. - - :param x: an [N x C x ...] Tensor of inputs. 
:param timesteps: a 1-D batch of timesteps. :param y: an [N] - Tensor of labels, if class-conditional. :return: an [N x C x ...] Tensor of outputs. - """ - - hs = [] - emb = self.time_embed( - get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0) - ) - - h = x.type(self.dtype) - for module in self.input_blocks: - h = module(h, emb) - hs.append(h) - h = self.mid(h, emb) - for module in self.output_blocks: - h = torch.cat([h, hs.pop()], dim=1) - h = module(h, emb) - h = h.type(x.dtype) - return self.out(h) - - -class GlideTextToImageUNetModel(GlideUNetModel): - """ - A UNetModel that performs super-resolution. - - Expects an extra kwarg `low_res` to condition on a low-resolution image. - """ - - def __init__( - self, - in_channels=3, - resolution=64, - model_channels=192, - out_channels=6, - num_res_blocks=3, - attention_resolutions=(2, 4, 8), - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - use_checkpoint=False, - use_fp16=False, - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - transformer_dim=512, - ): - super().__init__( - in_channels=in_channels, - resolution=resolution, - model_channels=model_channels, - out_channels=out_channels, - num_res_blocks=num_res_blocks, - attention_resolutions=attention_resolutions, - dropout=dropout, - channel_mult=channel_mult, - conv_resample=conv_resample, - dims=dims, - use_checkpoint=use_checkpoint, - use_fp16=use_fp16, - num_heads=num_heads, - num_head_channels=num_head_channels, - num_heads_upsample=num_heads_upsample, - use_scale_shift_norm=use_scale_shift_norm, - resblock_updown=resblock_updown, - transformer_dim=transformer_dim, - ) - self.register_to_config( - in_channels=in_channels, - resolution=resolution, - model_channels=model_channels, - out_channels=out_channels, - num_res_blocks=num_res_blocks, - attention_resolutions=attention_resolutions, - dropout=dropout, - channel_mult=channel_mult, - conv_resample=conv_resample, - dims=dims, - use_checkpoint=use_checkpoint, - use_fp16=use_fp16, - num_heads=num_heads, - num_head_channels=num_head_channels, - num_heads_upsample=num_heads_upsample, - use_scale_shift_norm=use_scale_shift_norm, - resblock_updown=resblock_updown, - transformer_dim=transformer_dim, - ) - - self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4) - - def forward(self, sample, timestep, transformer_out=None): - timesteps = timestep - x = sample - hs = [] - emb = self.time_embed( - get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0) - ) - - # project the last token - transformer_proj = self.transformer_proj(transformer_out[:, -1]) - transformer_out = transformer_out.permute(0, 2, 1) # NLC -> NCL - - emb = emb + transformer_proj.to(emb) - - h = x - for module in self.input_blocks: - h = module(h, emb, transformer_out) - hs.append(h) - h = self.mid(h, emb, transformer_out) - for module in self.output_blocks: - other = hs.pop() - h = torch.cat([h, other], dim=1) - h = module(h, emb, transformer_out) - return self.out(h) - - -class GlideSuperResUNetModel(GlideUNetModel): - """ - A UNetModel that performs super-resolution. - - Expects an extra kwarg `low_res` to condition on a low-resolution image. 
- """ - - def __init__( - self, - in_channels=3, - resolution=256, - model_channels=192, - out_channels=6, - num_res_blocks=3, - attention_resolutions=(2, 4, 8), - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - use_checkpoint=False, - use_fp16=False, - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - ): - super().__init__( - in_channels=in_channels, - resolution=resolution, - model_channels=model_channels, - out_channels=out_channels, - num_res_blocks=num_res_blocks, - attention_resolutions=attention_resolutions, - dropout=dropout, - channel_mult=channel_mult, - conv_resample=conv_resample, - dims=dims, - use_checkpoint=use_checkpoint, - use_fp16=use_fp16, - num_heads=num_heads, - num_head_channels=num_head_channels, - num_heads_upsample=num_heads_upsample, - use_scale_shift_norm=use_scale_shift_norm, - resblock_updown=resblock_updown, - ) - self.register_to_config( - in_channels=in_channels, - resolution=resolution, - model_channels=model_channels, - out_channels=out_channels, - num_res_blocks=num_res_blocks, - attention_resolutions=attention_resolutions, - dropout=dropout, - channel_mult=channel_mult, - conv_resample=conv_resample, - dims=dims, - use_checkpoint=use_checkpoint, - use_fp16=use_fp16, - num_heads=num_heads, - num_head_channels=num_head_channels, - num_heads_upsample=num_heads_upsample, - use_scale_shift_norm=use_scale_shift_norm, - resblock_updown=resblock_updown, - ) - - def forward(self, sample, timestep, low_res=None): - timesteps = timestep - x = sample - _, _, new_height, new_width = x.shape - upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") - x = torch.cat([x, upsampled], dim=1) - - hs = [] - emb = self.time_embed( - get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0) - ) - - h = x - for module in self.input_blocks: - h = module(h, emb) - hs.append(h) - h = self.mid(h, emb) - for module in self.output_blocks: - h = torch.cat([h, hs.pop()], dim=1) - h = module(h, emb) - - return self.out(h) diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py deleted file mode 100644 index 7913abd6..00000000 --- a/src/diffusers/models/unet_ldm.py +++ /dev/null @@ -1,627 +0,0 @@ -import math -from inspect import isfunction - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F - -from ..configuration_utils import ConfigMixin -from ..modeling_utils import ModelMixin -from .attention import AttentionBlock -from .embeddings import get_timestep_embedding -from .resnet import Downsample2D, ResnetBlock2D, Upsample2D -from .unet_new import UNetMidBlock2D - - -# from .resnet import ResBlock - - -def exists(val): - return val is not None - - -def uniq(arr): - return {el: True for el in arr}.keys() - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def max_neg_value(t): - return -torch.finfo(t.dtype).max - - -def init_(tensor): - dim = tensor.shape[-1] - std = 1 / math.sqrt(dim) - tensor.uniform_(-std, std) - return tensor - - -# feedforward -class GEGLU(nn.Module): - def __init__(self, dim_in, dim_out): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) - - def forward(self, x): - x, gate = self.proj(x).chunk(2, dim=-1) - return x * F.gelu(gate) - - -class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): - super().__init__() - inner_dim = int(dim * 
mult) - dim_out = default(dim_out, dim) - project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) - - self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) - - def forward(self, x): - return self.net(x) - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - -def Normalize(in_channels): - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - - -def convert_module_to_f16(l): - """ - Convert primitive modules to float16. - """ - if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): - l.weight.data = l.weight.data.half() - if l.bias is not None: - l.bias.data = l.bias.data.half() - - -def convert_module_to_f32(l): - """ - Convert primitive modules to float32, undoing convert_module_to_f16(). - """ - if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): - l.weight.data = l.weight.data.float() - if l.bias is not None: - l.bias.data = l.bias.data.float() - - -def conv_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D convolution module. - """ - if dims == 1: - return nn.Conv1d(*args, **kwargs) - elif dims == 2: - return nn.Conv2d(*args, **kwargs) - elif dims == 3: - return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -def linear(*args, **kwargs): - """ - Create a linear module. - """ - return nn.Linear(*args, **kwargs) - - -class GroupNorm32(nn.GroupNorm): - def __init__(self, num_groups, num_channels, swish, eps=1e-5): - super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps) - self.swish = swish - - def forward(self, x): - y = super().forward(x.float()).to(x.dtype) - if self.swish == 1.0: - y = F.silu(y) - elif self.swish: - y = y * F.sigmoid(y * float(self.swish)) - return y - - -def normalization(channels, swish=0.0): - """ - Make a standard normalization layer, with an optional swish activation. - - :param channels: number of input channels. :return: an nn.Module for normalization. - """ - return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) - - -class TimestepEmbedSequential(nn.Sequential): - """ - A sequential module that passes timestep embeddings to the children that support it as an extra input. - """ - - def forward(self, x, emb, context=None): - for layer in self: - if isinstance(layer, ResnetBlock2D) or isinstance(layer, TimestepEmbedSequential): - x = layer(x, emb) - elif isinstance(layer, SpatialTransformer): - x = layer(x, context) - else: - x = layer(x) - return x - - -def count_flops_attn(model, _x, y): - """ - A counter for the `thop` package to count the operations in an attention operation. Meant to be used like: - macs, params = thop.profile( - model, inputs=(inputs, timestamps), custom_ops={QKVAttention: QKVAttention.count_flops}, - ) - """ - b, c, *spatial = y[0].shape - num_spatial = int(np.prod(spatial)) - # We perform two matmuls with the same number of ops. - # The first computes the weight matrix, the second computes - # the combination of the value vectors. - matmul_ops = 2 * b * (num_spatial**2) * c - model.total_ops += torch.DoubleTensor([matmul_ops]) - - -class UNetLDMModel(ModelMixin, ConfigMixin): - """ - The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param - model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. 
:param - num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample - rates at which - attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x - downsampling, attention will be used. - :param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param - conv_resample: if True, use learned convolutions for upsampling and - downsampling. - :param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this - model will be - class-conditional with `num_classes` classes. - :param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention - heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use - a fixed channel width per attention head. - :param num_heads_upsample: works with num_heads to set a different number - of heads for upsampling. Deprecated. - :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks - for up/downsampling. :param use_new_attention_order: use a different attention pattern for potentially - increased efficiency. - """ - - def __init__( - self, - image_size, - in_channels, - model_channels, - out_channels, - num_res_blocks, - attention_resolutions, - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - num_classes=None, - use_checkpoint=False, - use_fp16=False, - num_heads=-1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - use_new_attention_order=False, - use_spatial_transformer=False, # custom transformer support - transformer_depth=1, # custom transformer support - context_dim=None, # custom transformer support - n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model - legacy=True, - ): - super().__init__() - - # register all __init__ params with self.register - self.register_to_config( - image_size=image_size, - in_channels=in_channels, - model_channels=model_channels, - out_channels=out_channels, - num_res_blocks=num_res_blocks, - attention_resolutions=attention_resolutions, - dropout=dropout, - channel_mult=channel_mult, - conv_resample=conv_resample, - dims=dims, - num_classes=num_classes, - use_fp16=use_fp16, - num_heads=num_heads, - num_head_channels=num_head_channels, - num_heads_upsample=num_heads_upsample, - use_scale_shift_norm=use_scale_shift_norm, - resblock_updown=resblock_updown, - use_spatial_transformer=use_spatial_transformer, - transformer_depth=transformer_depth, - context_dim=context_dim, - n_embed=n_embed, - legacy=legacy, - ) - - if use_spatial_transformer: - assert ( - context_dim is not None - ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." - - if context_dim is not None: - assert ( - use_spatial_transformer - ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." 
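(Aside on the deleted `attention_resolutions` handling above: the constructor keeps a running downsample factor `ds`, doubling it at every level except the last, and inserts an attention layer whenever `ds` appears in `attention_resolutions`. A minimal standalone sketch of that rule, using hypothetical values rather than the real model config:)

```python
# Illustrative sketch (not part of the diff): which UNet levels get attention,
# following the same `ds in attention_resolutions` rule as the deleted code.
attention_resolutions = (2, 4, 8)   # hypothetical example values
channel_mult = (1, 2, 4, 8)

ds = 1
for level, _mult in enumerate(channel_mult):
    print(f"level {level}: downsample x{ds}, attention={ds in attention_resolutions}")
    if level != len(channel_mult) - 1:
        ds *= 2  # every level except the last halves the spatial resolution
```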
- - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - if num_heads == -1: - assert num_head_channels != -1, "Either num_heads or num_head_channels has to be set" - - if num_head_channels == -1: - assert num_heads != -1, "Either num_heads or num_head_channels has to be set" - - self.image_size = image_size - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.num_classes = num_classes - self.dtype_ = torch.float16 if use_fp16 else torch.float32 - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - self.predict_codebook_ids = n_embed is not None - - time_embed_dim = model_channels * 4 - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - if self.num_classes is not None: - self.label_emb = nn.Embedding(num_classes, time_embed_dim) - - self.input_blocks = nn.ModuleList( - [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))] - ) - - self.down_in_conv = self.input_blocks[0][0] - self.downsample_blocks = nn.ModuleList([]) - self.upsample_blocks = nn.ModuleList([]) - - # ========================= Down (OLD) =================== # - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - for level, mult in enumerate(channel_mult): - for _ in range(num_res_blocks): - layers = [ - ResnetBlock2D( - in_channels=ch, - out_channels=mult * model_channels, - dropout=dropout, - temb_channels=time_embed_dim, - eps=1e-5, - non_linearity="silu", - overwrite_for_ldm=True, - ) - ] - ch = mult * model_channels - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - # num_heads = 1 - dim_head = ch // num_heads if use_spatial_transformer else num_head_channels - layers.append( - AttentionBlock( - ch, - num_heads=num_heads, - num_head_channels=dim_head, - ) - if not use_spatial_transformer - else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op") - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - input_channels = [model_channels * mult for mult in [1] + list(channel_mult[:-1])] - output_channels = [model_channels * mult for mult in channel_mult] - - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - # num_heads = 1 - dim_head = ch // num_heads if use_spatial_transformer else num_head_channels - - if dim_head < 0: - dim_head = None - - # ========================= MID (New) =================== # - self.mid = UNetMidBlock2D( - in_channels=ch, - dropout=dropout, - temb_channels=time_embed_dim, - resnet_eps=1e-5, - resnet_act_fn="silu", - resnet_time_scale_shift="scale_shift" if use_scale_shift_norm else 
"default", - attention_layer_type="self" if not use_spatial_transformer else "spatial", - attn_num_heads=num_heads, - attn_num_head_channels=dim_head, - attn_depth=transformer_depth, - attn_encoder_channels=context_dim, - ) - - # TODO(Patrick) - delete after weight conversion - # init to be able to overwrite `self.mid` - self.middle_block = TimestepEmbedSequential( - ResnetBlock2D( - in_channels=ch, - out_channels=None, - dropout=dropout, - temb_channels=time_embed_dim, - eps=1e-5, - non_linearity="silu", - overwrite_for_ldm=True, - ), - AttentionBlock( - ch, - num_heads=num_heads, - num_head_channels=dim_head, - ) - if not use_spatial_transformer - else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim), - ResnetBlock2D( - in_channels=ch, - out_channels=None, - dropout=dropout, - temb_channels=time_embed_dim, - eps=1e-5, - non_linearity="silu", - overwrite_for_ldm=True, - ), - ) - self.mid.resnets[0] = self.middle_block[0] - self.mid.attentions[0] = self.middle_block[1] - self.mid.resnets[1] = self.middle_block[2] - - self._feature_size += ch - - # ========================= Up (Old) =================== # - self.output_blocks = nn.ModuleList([]) - for level, mult in list(enumerate(channel_mult))[::-1]: - for i in range(num_res_blocks + 1): - ich = input_block_chans.pop() - layers = [ - ResnetBlock2D( - in_channels=ch + ich, - out_channels=model_channels * mult, - dropout=dropout, - temb_channels=time_embed_dim, - eps=1e-5, - non_linearity="silu", - overwrite_for_ldm=True, - ), - ] - ch = model_channels * mult - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - # num_heads = 1 - dim_head = ch // num_heads if use_spatial_transformer else num_head_channels - layers.append( - AttentionBlock( - ch, - num_heads=num_heads_upsample, - num_head_channels=dim_head, - ) - if not use_spatial_transformer - else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim - ) - ) - if level and i == num_res_blocks: - out_ch = ch - layers.append(Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch)) - ds //= 2 - self.output_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), - ) - - def forward(self, x, timesteps=None, context=None, y=None, **kwargs): - """ - Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch - of timesteps. :param context: conditioning plugged in via crossattn :param y: an [N] Tensor of labels, if - class-conditional. :return: an [N x C x ...] Tensor of outputs. 
- """ - assert (y is not None) == ( - self.num_classes is not None - ), "must specify y if and only if the model is class-conditional" - hs = [] - if not torch.is_tensor(timesteps): - timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device) - t_emb = get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0) - emb = self.time_embed(t_emb) - - if self.num_classes is not None: - assert y.shape == (x.shape[0],) - emb = emb + self.label_emb(y) - - h = x.type(self.dtype_) - - for module in self.input_blocks: - h = module(h, emb, context) - hs.append(h) - - h = self.mid(h, emb, context) - - for module in self.output_blocks: - h = torch.cat([h, hs.pop()], dim=1) - h = module(h, emb, context) - - return self.out(h) - - -class SpatialTransformer(nn.Module): - """ - Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply - standard transformer action. Finally, reshape to image - """ - - def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None): - super().__init__() - self.in_channels = in_channels - inner_dim = n_heads * d_head - self.norm = Normalize(in_channels) - - self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) - for d in range(depth) - ] - ) - - self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) - - def forward(self, x, context=None): - # note: if no context is given, cross-attention defaults to self-attention - b, c, h, w = x.shape - x_in = x - x = self.norm(x) - x = self.proj_in(x) - x = x.permute(0, 2, 3, 1).reshape(b, h * w, c) - for block in self.transformer_blocks: - x = block(x, context=context) - x = x.reshape(b, h, w, c).permute(0, 3, 1, 2) - x = self.proj_out(x) - return x + x_in - - -class BasicTransformerBlock(nn.Module): - def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True): - super().__init__() - self.attn1 = CrossAttention( - query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout - ) # is a self-attention - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = CrossAttention( - query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout - ) # is self-attn if context is none - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - self.norm3 = nn.LayerNorm(dim) - self.checkpoint = checkpoint - - def forward(self, x, context=None): - x = self.attn1(self.norm1(x)) + x - x = self.attn2(self.norm2(x), context=context) + x - x = self.ff(self.norm3(x)) + x - return x - - -class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): - super().__init__() - inner_dim = dim_head * heads - context_dim = default(context_dim, query_dim) - - self.scale = dim_head**-0.5 - self.heads = heads - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(context_dim, inner_dim, bias=False) - self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - - self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) - - def reshape_heads_to_batch_dim(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = 
tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) - return tensor - - def reshape_batch_dim_to_heads(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor - - def forward(self, x, context=None, mask=None): - batch_size, sequence_length, dim = x.shape - - h = self.heads - - q = self.to_q(x) - context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) - - q = self.reshape_heads_to_batch_dim(q) - k = self.reshape_heads_to_batch_dim(k) - v = self.reshape_heads_to_batch_dim(v) - - sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale - - if exists(mask): - mask = mask.reshape(batch_size, -1) - max_neg_value = -torch.finfo(sim.dtype).max - mask = mask[:, None, :].repeat(h, 1, 1) - sim.masked_fill_(~mask, max_neg_value) - - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) - - out = torch.einsum("b i j, b j d -> b i d", attn, v) - out = self.reshape_batch_dim_to_heads(out) - return self.to_out(out) diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py deleted file mode 100644 index e8961452..00000000 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ /dev/null @@ -1,471 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and - -# limitations under the License. - -# helpers functions - -import functools -import math - -import numpy as np -import torch -import torch.nn as nn - -from ..configuration_utils import ConfigMixin -from ..modeling_utils import ModelMixin -from .attention import AttentionBlock -from .embeddings import GaussianFourierProjection, get_timestep_embedding -from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D -from .unet_new import UNetMidBlock2D - - -class Combine(nn.Module): - """Combine information from skip connections.""" - - def __init__(self, dim1, dim2, method="cat"): - super().__init__() - # 1x1 convolution with DDPM initialization. 
- self.Conv_0 = nn.Conv2d(dim1, dim2, kernel_size=1, padding=0) - self.method = method - - def forward(self, x, y): - h = self.Conv_0(x) - if self.method == "cat": - return torch.cat([h, y], dim=1) - elif self.method == "sum": - return h + y - else: - raise ValueError(f"Method {self.method} not recognized.") - - -class NCSNpp(ModelMixin, ConfigMixin): - """NCSN++ model""" - - def __init__( - self, - image_size=1024, - num_channels=3, - centered=False, - attn_resolutions=(16,), - ch_mult=(1, 2, 4, 8, 16, 32, 32, 32), - conditional=True, - conv_size=3, - dropout=0.0, - embedding_type="fourier", - fir=True, - fir_kernel=(1, 3, 3, 1), - fourier_scale=16, - init_scale=0.0, - nf=16, - num_res_blocks=1, - progressive="output_skip", - progressive_combine="sum", - progressive_input="input_skip", - resamp_with_conv=True, - scale_by_sigma=True, - skip_rescale=True, - continuous=True, - ): - super().__init__() - self.register_to_config( - image_size=image_size, - num_channels=num_channels, - centered=centered, - attn_resolutions=attn_resolutions, - ch_mult=ch_mult, - conditional=conditional, - conv_size=conv_size, - dropout=dropout, - embedding_type=embedding_type, - fir=fir, - fir_kernel=fir_kernel, - fourier_scale=fourier_scale, - init_scale=init_scale, - nf=nf, - num_res_blocks=num_res_blocks, - progressive=progressive, - progressive_combine=progressive_combine, - progressive_input=progressive_input, - resamp_with_conv=resamp_with_conv, - scale_by_sigma=scale_by_sigma, - skip_rescale=skip_rescale, - continuous=continuous, - ) - self.act = nn.SiLU() - self.nf = nf - self.num_res_blocks = num_res_blocks - self.attn_resolutions = attn_resolutions - self.num_resolutions = len(ch_mult) - self.all_resolutions = all_resolutions = [image_size // (2**i) for i in range(self.num_resolutions)] - - self.conditional = conditional - self.skip_rescale = skip_rescale - self.progressive = progressive - self.progressive_input = progressive_input - self.embedding_type = embedding_type - assert progressive in ["none", "output_skip", "residual"] - assert progressive_input in ["none", "input_skip", "residual"] - assert embedding_type in ["fourier", "positional"] - combine_method = progressive_combine.lower() - combiner = functools.partial(Combine, method=combine_method) - - modules = [] - # timestep/noise_level embedding; only for continuous training - if embedding_type == "fourier": - # Gaussian Fourier features embeddings. 
- modules.append(GaussianFourierProjection(embedding_size=nf, scale=fourier_scale)) - embed_dim = 2 * nf - - elif embedding_type == "positional": - embed_dim = nf - - else: - raise ValueError(f"embedding type {embedding_type} unknown.") - - modules.append(nn.Linear(embed_dim, nf * 4)) - modules.append(nn.Linear(nf * 4, nf * 4)) - - AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0)) - - if self.fir: - Up_sample = functools.partial(FirUpsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv) - else: - Up_sample = functools.partial(Upsample2D, name="Conv2d_0") - - if progressive == "output_skip": - self.pyramid_upsample = Up_sample(channels=None, use_conv=False) - elif progressive == "residual": - pyramid_upsample = functools.partial(Up_sample, use_conv=True) - - if self.fir: - Down_sample = functools.partial(FirDownsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv) - else: - Down_sample = functools.partial(Downsample2D, padding=0, name="Conv2d_0") - - if progressive_input == "input_skip": - self.pyramid_downsample = Down_sample(channels=None, use_conv=False) - elif progressive_input == "residual": - pyramid_downsample = functools.partial(Down_sample, use_conv=True) - - channels = num_channels - if progressive_input != "none": - input_pyramid_ch = channels - - modules.append(nn.Conv2d(channels, nf, kernel_size=3, padding=1)) - hs_c = [nf] - - in_ch = nf - for i_level in range(self.num_resolutions): - # Residual blocks for this resolution - for i_block in range(num_res_blocks): - out_ch = nf * ch_mult[i_level] - modules.append( - ResnetBlock2D( - in_channels=in_ch, - out_channels=out_ch, - temb_channels=4 * nf, - output_scale_factor=np.sqrt(2.0), - non_linearity="silu", - groups=min(in_ch // 4, 32), - groups_out=min(out_ch // 4, 32), - overwrite_for_score_vde=True, - ) - ) - in_ch = out_ch - - if all_resolutions[i_level] in attn_resolutions: - modules.append(AttnBlock(channels=in_ch)) - hs_c.append(in_ch) - - if i_level != self.num_resolutions - 1: - modules.append( - ResnetBlock2D( - in_channels=in_ch, - temb_channels=4 * nf, - output_scale_factor=np.sqrt(2.0), - non_linearity="silu", - groups=min(in_ch // 4, 32), - groups_out=min(out_ch // 4, 32), - overwrite_for_score_vde=True, - down=True, - kernel="fir" if self.fir else "sde_vp", - use_nin_shortcut=True, - ) - ) - - if progressive_input == "input_skip": - modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch)) - if combine_method == "cat": - in_ch *= 2 - - elif progressive_input == "residual": - modules.append(pyramid_downsample(channels=input_pyramid_ch, out_channels=in_ch)) - input_pyramid_ch = in_ch - - hs_c.append(in_ch) - - # mid - self.mid = UNetMidBlock2D( - in_channels=in_ch, - temb_channels=4 * nf, - output_scale_factor=math.sqrt(2.0), - resnet_act_fn="silu", - resnet_groups=min(in_ch // 4, 32), - dropout=dropout, - ) - - in_ch = hs_c[-1] - modules.append( - ResnetBlock2D( - in_channels=in_ch, - temb_channels=4 * nf, - output_scale_factor=np.sqrt(2.0), - non_linearity="silu", - groups=min(in_ch // 4, 32), - groups_out=min(out_ch // 4, 32), - overwrite_for_score_vde=True, - ) - ) - modules.append(AttnBlock(channels=in_ch)) - modules.append( - ResnetBlock2D( - in_channels=in_ch, - temb_channels=4 * nf, - output_scale_factor=np.sqrt(2.0), - non_linearity="silu", - groups=min(in_ch // 4, 32), - groups_out=min(out_ch // 4, 32), - overwrite_for_score_vde=True, - ) - ) - # self.mid.resnets[0] = modules[len(modules) - 3] - # self.mid.attentions[0] = 
modules[len(modules) - 2] - # self.mid.resnets[1] = modules[len(modules) - 1] - - pyramid_ch = 0 - # Upsampling block - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(num_res_blocks + 1): - out_ch = nf * ch_mult[i_level] - in_ch = in_ch + hs_c.pop() - modules.append( - ResnetBlock2D( - in_channels=in_ch, - out_channels=out_ch, - temb_channels=4 * nf, - output_scale_factor=np.sqrt(2.0), - non_linearity="silu", - groups=min(in_ch // 4, 32), - groups_out=min(out_ch // 4, 32), - overwrite_for_score_vde=True, - ) - ) - in_ch = out_ch - - if all_resolutions[i_level] in attn_resolutions: - modules.append(AttnBlock(channels=in_ch)) - - if progressive != "none": - if i_level == self.num_resolutions - 1: - if progressive == "output_skip": - modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) - modules.append(nn.Conv2d(in_ch, channels, kernel_size=3, padding=1)) - pyramid_ch = channels - # elif progressive == "residual": - # modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) - # modules.append(nn.Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1)) - # pyramid_ch = in_ch - # else: - # raise ValueError(f"{progressive} is not a valid name.") - else: - if progressive == "output_skip": - modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) - modules.append(nn.Conv2d(in_ch, channels, bias=True, kernel_size=3, padding=1)) - pyramid_ch = channels - # elif progressive == "residual": - # modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch)) - # pyramid_ch = in_ch - # else: - # raise ValueError(f"{progressive} is not a valid name") - - if i_level != 0: - modules.append( - ResnetBlock2D( - in_channels=in_ch, - temb_channels=4 * nf, - output_scale_factor=np.sqrt(2.0), - non_linearity="silu", - groups=min(in_ch // 4, 32), - groups_out=min(out_ch // 4, 32), - overwrite_for_score_vde=True, - up=True, - kernel="fir" if self.fir else "sde_vp", - use_nin_shortcut=True, - ) - ) - - assert not hs_c - - if progressive != "output_skip": - modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) - modules.append(nn.Conv2d(in_ch, channels, kernel_size=3, padding=1)) - - self.all_modules = nn.ModuleList(modules) - - def forward(self, sample, timestep, sigmas=None): - timesteps = timestep - x = sample - # timestep/noise_level embedding; only for continuous training - modules = self.all_modules - m_idx = 0 - if self.embedding_type == "fourier": - # Gaussian Fourier features embeddings. - used_sigmas = timesteps - temb = modules[m_idx](used_sigmas) - m_idx += 1 - - elif self.embedding_type == "positional": - # Sinusoidal positional embeddings. 
- timesteps = timesteps - used_sigmas = sigmas - temb = get_timestep_embedding(timesteps, self.nf) - - else: - raise ValueError(f"embedding type {self.embedding_type} unknown.") - - if self.conditional: - temb = modules[m_idx](temb) - m_idx += 1 - temb = modules[m_idx](self.act(temb)) - m_idx += 1 - else: - temb = None - - # If input data is in [0, 1] - if not self.config.centered: - x = 2 * x - 1.0 - - # Downsampling block - input_pyramid = None - if self.progressive_input != "none": - input_pyramid = x - - hs = [modules[m_idx](x)] - m_idx += 1 - - for i_level in range(self.num_resolutions): - # Residual blocks for this resolution - for i_block in range(self.num_res_blocks): - h = modules[m_idx](hs[-1], temb) - m_idx += 1 - if h.shape[-1] in self.attn_resolutions: - h = modules[m_idx](h) - m_idx += 1 - - hs.append(h) - - if i_level != self.num_resolutions - 1: - h = modules[m_idx](hs[-1], temb) - m_idx += 1 - - if self.progressive_input == "input_skip": - input_pyramid = self.pyramid_downsample(input_pyramid) - h = modules[m_idx](input_pyramid, h) - m_idx += 1 - - elif self.progressive_input == "residual": - input_pyramid = modules[m_idx](input_pyramid) - m_idx += 1 - if self.skip_rescale: - input_pyramid = (input_pyramid + h) / np.sqrt(2.0) - else: - input_pyramid = input_pyramid + h - h = input_pyramid - - hs.append(h) - - h = hs[-1] - h = modules[m_idx](h, temb) - m_idx += 1 - h = modules[m_idx](h) - m_idx += 1 - h = modules[m_idx](h, temb) - m_idx += 1 - - pyramid = None - - # Upsampling block - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb) - m_idx += 1 - - if h.shape[-1] in self.attn_resolutions: - h = modules[m_idx](h) - m_idx += 1 - - if self.progressive != "none": - if i_level == self.num_resolutions - 1: - if self.progressive == "output_skip": - pyramid = self.act(modules[m_idx](h)) - m_idx += 1 - pyramid = modules[m_idx](pyramid) - m_idx += 1 - # elif self.progressive == "residual": - # pyramid = self.act(modules[m_idx](h)) - # m_idx += 1 - # pyramid = modules[m_idx](pyramid) - # m_idx += 1 - # else: - # raise ValueError(f"{self.progressive} is not a valid name.") - else: - if self.progressive == "output_skip": - pyramid_h = self.act(modules[m_idx](h)) - m_idx += 1 - pyramid_h = modules[m_idx](pyramid_h) - m_idx += 1 - - skip_sample = self.pyramid_upsample(pyramid) - pyramid = skip_sample + pyramid_h - # elif self.progressive == "residual": - # pyramid = modules[m_idx](pyramid) - # m_idx += 1 - # if self.skip_rescale: - # pyramid = (pyramid + h) / np.sqrt(2.0) - # else: - # pyramid = pyramid + h - # h = pyramid - # else: - # raise ValueError(f"{self.progressive} is not a valid name") - - if i_level != 0: - h = modules[m_idx](h, temb) - m_idx += 1 - - assert not hs - - if self.progressive == "output_skip": - h = pyramid - else: - h = self.act(modules[m_idx](h)) - m_idx += 1 - h = modules[m_idx](h) - m_idx += 1 - - assert m_idx == len(modules) - if self.config.scale_by_sigma: - used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:])))) - h = h / used_sigmas - - return h diff --git a/src/diffusers/models/unet_unconditional.py b/src/diffusers/models/unet_unconditional.py index 60cac787..610f7125 100644 --- a/src/diffusers/models/unet_unconditional.py +++ b/src/diffusers/models/unet_unconditional.py @@ -1,74 +1,12 @@ -import functools -import math from typing import Dict, Union -import numpy as np import torch import torch.nn as nn from 
..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin -from .attention import AttentionBlock -from .embeddings import GaussianFourierProjection, get_timestep_embedding -from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D -from .unet_new import UNetMidBlock2D, get_down_block, get_up_block - - -class Combine(nn.Module): - """Combine information from skip connections.""" - - def __init__(self, dim1, dim2, method="cat"): - super().__init__() - # 1x1 convolution with DDPM initialization. - self.Conv_0 = nn.Conv2d(dim1, dim2, kernel_size=1, padding=0) - self.method = method - - -# def forward(self, x, y): -# h = self.Conv_0(x) -# if self.method == "cat": -# return torch.cat([h, y], dim=1) -# elif self.method == "sum": -# return h + y -# else: -# raise ValueError(f"Method {self.method} not recognized.") - - -class TimestepEmbedding(nn.Module): - def __init__(self, channel, time_embed_dim, act_fn="silu"): - super().__init__() - - self.linear_1 = nn.Linear(channel, time_embed_dim) - self.act = None - if act_fn == "silu": - self.act = nn.SiLU() - self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) - - def forward(self, sample): - sample = self.linear_1(sample) - - if self.act is not None: - sample = self.act(sample) - - sample = self.linear_2(sample) - return sample - - -class Timesteps(nn.Module): - def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift): - super().__init__() - self.num_channels = num_channels - self.flip_sin_to_cos = flip_sin_to_cos - self.downscale_freq_shift = downscale_freq_shift - - def forward(self, timesteps): - t_emb = get_timestep_embedding( - timesteps, - self.num_channels, - flip_sin_to_cos=self.flip_sin_to_cos, - downscale_freq_shift=self.downscale_freq_shift, - ) - return t_emb +from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps +from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block class UNetUnconditionalModel(ModelMixin, ConfigMixin): @@ -120,39 +58,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): time_embedding_type="positional", mid_block_scale_factor=1, center_input_sample=False, - # TODO(PVP) - to delete later at release - # IMPORTANT: NOT RELEVANT WHEN REVIEWING API - # ====================================== - # LDM - attention_resolutions=(8, 4, 2), - ldm=False, - # DDPM - out_ch=None, - resolution=None, - attn_resolutions=None, - resamp_with_conv=None, - ch_mult=None, - ch=None, - ddpm=False, - # SDE - sde=False, - nf=None, - fir=None, - progressive=None, - progressive_combine=None, - scale_by_sigma=None, - skip_rescale=None, - num_channels=None, - centered=False, - conditional=True, - conv_size=3, - fir_kernel=(1, 3, 3, 1), - fourier_scale=16, - init_scale=0.0, - progressive_input="input_skip", resnet_num_groups=32, - continuous=True, - **kwargs, ): super().__init__() # register all __init__ params to be accessible via `self.config.<...>` @@ -173,51 +79,20 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=downscale_freq_shift, time_embedding_type=time_embedding_type, - attention_resolutions=attention_resolutions, - attn_resolutions=attn_resolutions, mid_block_scale_factor=mid_block_scale_factor, resnet_num_groups=resnet_num_groups, center_input_sample=center_input_sample, - # to delete later - ldm=ldm, - ddpm=ddpm, - sde=sde, ) - # if sde: - # block_channels = [nf * x for x in ch_mult] - # in_channels = out_channels = num_channels - # conv_resample = 
resamp_with_conv - # time_embedding_type = "fourier" - # self.config.time_embedding_type = time_embedding_type - # self.config.resnet_eps = 1e-6 - # self.config.mid_block_scale_factor = math.sqrt(2.0) - # self.config.resnet_num_groups = None - # down_blocks = ( - # "UNetResSkipDownBlock2D", - # "UNetResAttnSkipDownBlock2D", - # "UNetResSkipDownBlock2D", - # "UNetResSkipDownBlock2D", - # ) - # up_blocks = ( - # "UNetResSkipUpBlock2D", - # "UNetResSkipUpBlock2D", - # "UNetResAttnSkipUpBlock2D", - # "UNetResSkipUpBlock2D", - # ) - # TODO(PVP) - to delete later at release - # IMPORTANT: NOT RELEVANT WHEN REVIEWING API - # ====================================== self.image_size = image_size time_embed_dim = block_channels[0] * 4 - # ====================================== # input self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1)) # time if time_embedding_type == "fourier": - self.time_steps = GaussianFourierProjection(embedding_size=block_channels[0], scale=fourier_scale) + self.time_steps = GaussianFourierProjection(embedding_size=block_channels[0], scale=16) timestep_input_dim = 2 * block_channels[0] elif time_embedding_type == "positional": self.time_steps = Timesteps(block_channels[0], flip_sin_to_cos, downscale_freq_shift) @@ -251,30 +126,17 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): self.downsample_blocks.append(down_block) # mid - if ddpm: - self.mid_new_2 = UNetMidBlock2D( - in_channels=block_channels[-1], - dropout=dropout, - temb_channels=time_embed_dim, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift="default", - attn_num_head_channels=num_head_channels, - resnet_groups=resnet_num_groups, - ) - else: - self.mid = UNetMidBlock2D( - in_channels=block_channels[-1], - dropout=dropout, - temb_channels=time_embed_dim, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift="default", - attn_num_head_channels=num_head_channels, - resnet_groups=resnet_num_groups, - ) + self.mid = UNetMidBlock2D( + in_channels=block_channels[-1], + dropout=dropout, + temb_channels=time_embed_dim, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift="default", + attn_num_head_channels=num_head_channels, + resnet_groups=resnet_num_groups, + ) # up reversed_block_channels = list(reversed(block_channels)) @@ -307,116 +169,11 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1) - # ======================== Out ==================== - - # =========== TO DELETE AFTER CONVERSION ========== - # TODO(PVP) - to delete later at release - # IMPORTANT: NOT RELEVANT WHEN REVIEWING API - # ====================================== - self.is_overwritten = False - if ldm: - transformer_depth = 1 - context_dim = None - legacy = True - num_heads = -1 - model_channels = block_channels[0] - channel_mult = tuple([x // model_channels for x in block_channels]) - self.init_for_ldm( - in_channels, - model_channels, - channel_mult, - num_res_blocks, - dropout, - time_embed_dim, - attention_resolutions, - num_head_channels, - num_heads, - legacy, - False, - transformer_depth, - context_dim, - conv_resample, - out_channels, - ) - elif ddpm: - out_channels = out_ch - image_size = resolution - block_channels = [x * ch for x in ch_mult] - conv_resample = resamp_with_conv 
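(The deleted DDPM branch around this point only translates between the new `block_channels` config and the legacy `ch`/`ch_mult`/`resolution` names and back again. A self-contained sketch of that round-trip, with made-up values for illustration:)

```python
# Round-trip between the new and legacy DDPM config names, mirroring the
# deleted branch above (values are hypothetical).
block_channels = [128, 256, 512, 512]

ch = block_channels[0]                       # legacy base channel count
ch_mult = [b // ch for b in block_channels]  # legacy multipliers: [1, 2, 4, 4]

# Reconstructing block_channels from the legacy names is the inverse mapping.
assert [ch * m for m in ch_mult] == block_channels
```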
- out_ch = out_channels - resolution = image_size - ch = block_channels[0] - ch_mult = [b // ch for b in block_channels] - resamp_with_conv = conv_resample - self.init_for_ddpm( - ch_mult, - ch, - num_res_blocks, - resolution, - in_channels, - resamp_with_conv, - attn_resolutions, - out_ch, - dropout=0.1, - ) - elif sde: - nf = block_channels[0] - ch_mult = [x // nf for x in block_channels] - num_channels = in_channels - # in_channels = out_channels = num_channels = in_channels - # block_channels = [nf * x for x in ch_mult] - # conv_resample = resamp_with_conv - resamp_with_conv = conv_resample - time_embedding_type = self.config.time_embedding_type - # time_embedding_type = "fourier" - # self.config.time_embedding_type = time_embedding_type - fir = True - progressive = "output_skip" - progressive_combine = "sum" - scale_by_sigma = True - skip_rescale = True - centered = False - conditional = True - conv_size = 3 - fir_kernel = (1, 3, 3, 1) - fourier_scale = 16 - init_scale = 0.0 - progressive_input = "input_skip" - continuous = True - self.init_for_sde( - image_size, - num_channels, - centered, - attn_resolutions, - ch_mult, - conditional, - conv_size, - dropout, - time_embedding_type, - fir, - fir_kernel, - fourier_scale, - init_scale, - nf, - num_res_blocks, - progressive, - progressive_combine, - progressive_input, - resamp_with_conv, - scale_by_sigma, - skip_rescale, - continuous, - ) - def forward( self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int] ) -> Dict[str, torch.FloatTensor]: - # TODO(PVP) - to delete later at release - # IMPORTANT: NOT RELEVANT WHEN REVIEWING API - # ====================================== - if not self.is_overwritten: - self.set_weights() + # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 @@ -447,10 +204,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): down_block_res_samples += res_samples # 4. mid - if self.config.ddpm: - sample = self.mid_new_2(sample, emb) - else: - sample = self.mid(sample, emb) + sample = self.mid(sample, emb) # 5. up skip_sample = None @@ -464,7 +218,6 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): sample = upsample_block(sample, res_samples, emb) # 6. 
post-process - sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) @@ -472,724 +225,10 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): if skip_sample is not None: sample += skip_sample - if ( - self.config.time_embedding_type == "fourier" - or self.time_steps.__class__.__name__ == "GaussianFourierProjection" - ): + if self.config.time_embedding_type == "fourier": timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:])))) sample = sample / timesteps output = {"sample": sample} return output - - # !!!IMPORTANT - ALL OF THE FOLLOWING CODE WILL BE DELETED AT RELEASE TIME AND SHOULD NOT BE TAKEN INTO CONSIDERATION WHEN EVALUATING THE API ### - # ================================================================================================================================================= - - def set_weights(self): - self.is_overwritten = True - if self.config.ldm: - self.time_embedding.linear_1.weight.data = self.time_embed[0].weight.data - self.time_embedding.linear_1.bias.data = self.time_embed[0].bias.data - self.time_embedding.linear_2.weight.data = self.time_embed[2].weight.data - self.time_embedding.linear_2.bias.data = self.time_embed[2].bias.data - - # ================ SET WEIGHTS OF ALL WEIGHTS ================== - for i, input_layer in enumerate(self.input_blocks[1:]): - block_id = i // (self.config.num_res_blocks + 1) - layer_in_block_id = i % (self.config.num_res_blocks + 1) - - if layer_in_block_id == 2: - self.downsample_blocks[block_id].downsamplers[0].conv.weight.data = input_layer[0].op.weight.data - self.downsample_blocks[block_id].downsamplers[0].conv.bias.data = input_layer[0].op.bias.data - elif len(input_layer) > 1: - self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0]) - self.downsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1]) - else: - self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0]) - - self.mid.resnets[0].set_weight(self.middle_block[0]) - self.mid.resnets[1].set_weight(self.middle_block[2]) - self.mid.attentions[0].set_weight(self.middle_block[1]) - - for i, input_layer in enumerate(self.output_blocks): - block_id = i // (self.config.num_res_blocks + 1) - layer_in_block_id = i % (self.config.num_res_blocks + 1) - - if len(input_layer) > 2: - self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0]) - self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1]) - self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[2].conv.weight.data - self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[2].conv.bias.data - elif len(input_layer) > 1 and "Upsample2D" in input_layer[1].__class__.__name__: - self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0]) - self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[1].conv.weight.data - self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[1].conv.bias.data - elif len(input_layer) > 1: - self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0]) - self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1]) - else: - self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0]) - - self.conv_in.weight.data = self.input_blocks[0][0].weight.data - self.conv_in.bias.data = self.input_blocks[0][0].bias.data - - 
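(Every branch of the deleted `set_weights` helper uses the same porting idiom: point the new module's `weight.data`/`bias.data` at the tensors of the corresponding legacy module. A toy illustration of that idiom with two throwaway conv layers, not the actual UNet modules:)

```python
import torch
import torch.nn as nn

old_conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)  # stand-in for a legacy layer
new_conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)  # stand-in for its replacement

# Same assignment pattern as the deleted set_weights(): the new layer now
# shares the legacy layer's parameter tensors.
new_conv.weight.data = old_conv.weight.data
new_conv.bias.data = old_conv.bias.data

x = torch.randn(1, 3, 16, 16)
assert torch.allclose(new_conv(x), old_conv(x))
```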
self.conv_norm_out.weight.data = self.out[0].weight.data - self.conv_norm_out.bias.data = self.out[0].bias.data - self.conv_out.weight.data = self.out[2].weight.data - self.conv_out.bias.data = self.out[2].bias.data - - self.remove_ldm() - - elif self.config.ddpm: - self.time_embedding.linear_1.weight.data = self.temb.dense[0].weight.data - self.time_embedding.linear_1.bias.data = self.temb.dense[0].bias.data - self.time_embedding.linear_2.weight.data = self.temb.dense[1].weight.data - self.time_embedding.linear_2.bias.data = self.temb.dense[1].bias.data - - for i, block in enumerate(self.down): - if hasattr(block, "downsample"): - self.downsample_blocks[i].downsamplers[0].conv.weight.data = block.downsample.conv.weight.data - self.downsample_blocks[i].downsamplers[0].conv.bias.data = block.downsample.conv.bias.data - if hasattr(block, "block") and len(block.block) > 0: - for j in range(self.num_res_blocks): - self.downsample_blocks[i].resnets[j].set_weight(block.block[j]) - if hasattr(block, "attn") and len(block.attn) > 0: - for j in range(self.num_res_blocks): - self.downsample_blocks[i].attentions[j].set_weight(block.attn[j]) - - self.mid_new_2.resnets[0].set_weight(self.mid.block_1) - self.mid_new_2.resnets[1].set_weight(self.mid.block_2) - self.mid_new_2.attentions[0].set_weight(self.mid.attn_1) - - for i, block in enumerate(self.up): - k = len(self.up) - 1 - i - if hasattr(block, "upsample"): - self.upsample_blocks[k].upsamplers[0].conv.weight.data = block.upsample.conv.weight.data - self.upsample_blocks[k].upsamplers[0].conv.bias.data = block.upsample.conv.bias.data - if hasattr(block, "block") and len(block.block) > 0: - for j in range(self.num_res_blocks + 1): - self.upsample_blocks[k].resnets[j].set_weight(block.block[j]) - if hasattr(block, "attn") and len(block.attn) > 0: - for j in range(self.num_res_blocks + 1): - self.upsample_blocks[k].attentions[j].set_weight(block.attn[j]) - - self.conv_norm_out.weight.data = self.norm_out.weight.data - self.conv_norm_out.bias.data = self.norm_out.bias.data - - self.remove_ddpm() - elif self.config.sde: - self.time_steps.weight = self.all_modules[0].weight - self.time_embedding.linear_1.weight.data = self.all_modules[1].weight.data - self.time_embedding.linear_1.bias.data = self.all_modules[1].bias.data - self.time_embedding.linear_2.weight.data = self.all_modules[2].weight.data - self.time_embedding.linear_2.bias.data = self.all_modules[2].bias.data - - self.conv_in.weight.data = self.all_modules[3].weight.data - self.conv_in.bias.data = self.all_modules[3].bias.data - - module_index = 4 - for i, block in enumerate(self.downsample_blocks): - has_attentios = hasattr(block, "attentions") - if has_attentios: - for j in range(len(block.attentions)): - block.resnets[j].set_weight(self.all_modules[module_index]) - module_index += 1 - block.attentions[j].set_weight(self.all_modules[module_index]) - module_index += 1 - if hasattr(block, "downsamplers") and block.downsamplers is not None: - block.resnet_down.set_weight(self.all_modules[module_index]) - module_index += 1 - block.skip_conv.weight.data = self.all_modules[module_index].Conv_0.weight.data - block.skip_conv.bias.data = self.all_modules[module_index].Conv_0.bias.data - module_index += 1 - else: - for j in range(len(block.resnets)): - block.resnets[j].set_weight(self.all_modules[module_index]) - module_index += 1 - if hasattr(block, "downsamplers") and block.downsamplers is not None: - block.resnet_down.set_weight(self.all_modules[module_index]) - module_index += 1 - 
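(This SDE branch, like the deleted NCSNpp `forward`, assumes the network is one flat `nn.ModuleList` consumed strictly in construction order by a running index, which is why the deleted `forward()` ends with an index-equals-length sanity check. A toy version of that convention, not the real score model:)

```python
import torch
import torch.nn as nn

# Flat module list consumed in order by a running index, as in the deleted
# `all_modules` / `m_idx` code (toy layers only).
all_modules = nn.ModuleList([nn.Linear(4, 8), nn.SiLU(), nn.Linear(8, 2)])

h = torch.randn(1, 4)
m_idx = 0
for _ in range(len(all_modules)):
    h = all_modules[m_idx](h)
    m_idx += 1

assert m_idx == len(all_modules)  # same invariant the deleted forward() asserts
```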
block.skip_conv.weight.data = self.all_modules[module_index].Conv_0.weight.data - block.skip_conv.bias.data = self.all_modules[module_index].Conv_0.bias.data - module_index += 1 - - self.mid.resnets[0].set_weight(self.all_modules[module_index]) - module_index += 1 - self.mid.attentions[0].set_weight(self.all_modules[module_index]) - module_index += 1 - self.mid.resnets[1].set_weight(self.all_modules[module_index]) - module_index += 1 - - for i, block in enumerate(self.upsample_blocks): - for j in range(len(block.resnets)): - block.resnets[j].set_weight(self.all_modules[module_index]) - module_index += 1 - if hasattr(block, "attentions") and block.attentions is not None: - block.attentions[0].set_weight(self.all_modules[module_index]) - module_index += 1 - if hasattr(block, "resnet_up") and block.resnet_up is not None: - block.skip_norm.weight.data = self.all_modules[module_index].weight.data - block.skip_norm.bias.data = self.all_modules[module_index].bias.data - module_index += 1 - block.skip_conv.weight.data = self.all_modules[module_index].weight.data - block.skip_conv.bias.data = self.all_modules[module_index].bias.data - module_index += 1 - block.resnet_up.set_weight(self.all_modules[module_index]) - module_index += 1 - - self.conv_norm_out.weight.data = self.all_modules[module_index].weight.data - self.conv_norm_out.bias.data = self.all_modules[module_index].bias.data - module_index += 1 - self.conv_out.weight.data = self.all_modules[module_index].weight.data - self.conv_out.bias.data = self.all_modules[module_index].bias.data - - self.remove_sde() - - def init_for_ddpm( - self, - ch_mult, - ch, - num_res_blocks, - resolution, - in_channels, - resamp_with_conv, - attn_resolutions, - out_ch, - dropout=0.1, - ): - ch_mult = tuple(ch_mult) - self.ch = ch - self.temb_ch = self.ch * 4 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - - # timestep embedding - self.temb = nn.Module() - self.temb.dense = nn.ModuleList( - [ - torch.nn.Linear(self.ch, self.temb_ch), - torch.nn.Linear(self.temb_ch, self.temb_ch), - ] - ) - - # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) - - curr_res = resolution - in_ch_mult = (1,) + ch_mult - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock2D( - in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttentionBlock(block_in, overwrite_qkv=True)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock2D( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True) - self.mid.block_2 = ResnetBlock2D( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - self.mid_new = UNetMidBlock2D(in_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) - self.mid_new.resnets[0] = self.mid.block_1 - self.mid_new.attentions[0] = 
self.mid.attn_1 - self.mid_new.resnets[1] = self.mid.block_2 - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - skip_in = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - if i_block == self.num_res_blocks: - skip_in = ch * in_ch_mult[i_level] - block.append( - ResnetBlock2D( - in_channels=block_in + skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttentionBlock(block_in, overwrite_qkv=True)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) - - def init_for_ldm( - self, - in_channels, - model_channels, - channel_mult, - num_res_blocks, - dropout, - time_embed_dim, - attention_resolutions, - num_head_channels, - num_heads, - legacy, - use_spatial_transformer, - transformer_depth, - context_dim, - conv_resample, - out_channels, - ): - # TODO(PVP) - delete after weight conversion - class TimestepEmbedSequential(nn.Sequential): - """ - A sequential module that passes timestep embeddings to the children that support it as an extra input. - """ - - pass - - # TODO(PVP) - delete after weight conversion - def conv_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D convolution module. - """ - if dims == 1: - return nn.Conv1d(*args, **kwargs) - elif dims == 2: - return nn.Conv2d(*args, **kwargs) - elif dims == 3: - return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - self.time_embed = nn.Sequential( - nn.Linear(model_channels, time_embed_dim), - nn.SiLU(), - nn.Linear(time_embed_dim, time_embed_dim), - ) - - dims = 2 - self.input_blocks = nn.ModuleList( - [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))] - ) - - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - for level, mult in enumerate(channel_mult): - for _ in range(num_res_blocks): - layers = [ - ResnetBlock2D( - in_channels=ch, - out_channels=mult * model_channels, - dropout=dropout, - temb_channels=time_embed_dim, - eps=1e-5, - non_linearity="silu", - overwrite_for_ldm=True, - ) - ] - ch = mult * model_channels - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - # num_heads = 1 - dim_head = num_head_channels - layers.append( - AttentionBlock( - ch, - num_heads=num_heads, - num_head_channels=dim_head, - ), - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op") - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - # num_heads = 1 - dim_head = num_head_channels - 
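(For reference, the `num_heads` / `num_head_channels` resolution that the deleted `init_for_ldm` repeats at every attention site boils down to one small rule, ignoring the extra `legacy` override: exactly one of the two is set, and the other is derived from the channel count. A standalone sketch of it:)

```python
def resolve_heads(ch: int, num_heads: int = -1, num_head_channels: int = -1):
    """Mirror of the head-resolution rule in the deleted init code (illustrative)."""
    if num_head_channels == -1:
        # head count fixed, per-head width derived from the channel count
        return num_heads, ch // num_heads
    # per-head width fixed, head count derived from the channel count
    return ch // num_head_channels, num_head_channels

print(resolve_heads(512, num_heads=8))            # -> (8, 64)
print(resolve_heads(512, num_head_channels=64))   # -> (8, 64)
```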
- if dim_head < 0: - dim_head = None - - # TODO(Patrick) - delete after weight conversion - # init to be able to overwrite `self.mid` - self.middle_block = TimestepEmbedSequential( - ResnetBlock2D( - in_channels=ch, - out_channels=None, - dropout=dropout, - temb_channels=time_embed_dim, - eps=1e-5, - non_linearity="silu", - overwrite_for_ldm=True, - ), - AttentionBlock( - ch, - num_heads=num_heads, - num_head_channels=dim_head, - ), - ResnetBlock2D( - in_channels=ch, - out_channels=None, - dropout=dropout, - temb_channels=time_embed_dim, - eps=1e-5, - non_linearity="silu", - overwrite_for_ldm=True, - ), - ) - self._feature_size += ch - - self.output_blocks = nn.ModuleList([]) - for level, mult in list(enumerate(channel_mult))[::-1]: - for i in range(num_res_blocks + 1): - ich = input_block_chans.pop() - layers = [ - ResnetBlock2D( - in_channels=ch + ich, - out_channels=model_channels * mult, - dropout=dropout, - temb_channels=time_embed_dim, - eps=1e-5, - non_linearity="silu", - overwrite_for_ldm=True, - ), - ] - ch = model_channels * mult - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - # num_heads = 1 - dim_head = num_head_channels - layers.append( - AttentionBlock( - ch, - num_heads=-1, - num_head_channels=dim_head, - ), - ) - if level and i == num_res_blocks: - out_ch = ch - layers.append(Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch)) - ds //= 2 - self.output_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - - self.out = nn.Sequential( - nn.GroupNorm(num_channels=model_channels, num_groups=32, eps=1e-5), - nn.SiLU(), - nn.Conv2d(model_channels, out_channels, 3, padding=1), - ) - - def init_for_sde( - self, - image_size, - num_channels, - centered, - attn_resolutions, - ch_mult, - conditional, - conv_size, - dropout, - embedding_type, - fir, - fir_kernel, - fourier_scale, - init_scale, - nf, - num_res_blocks, - progressive, - progressive_combine, - progressive_input, - resamp_with_conv, - scale_by_sigma, - skip_rescale, - continuous, - ): - self.act = nn.SiLU() - self.nf = nf - self.num_res_blocks = num_res_blocks - self.attn_resolutions = attn_resolutions - self.num_resolutions = len(ch_mult) - self.all_resolutions = all_resolutions = [image_size // (2**i) for i in range(self.num_resolutions)] - - self.conditional = conditional - self.skip_rescale = skip_rescale - self.progressive = progressive - self.progressive_input = progressive_input - self.embedding_type = embedding_type - assert progressive in ["none", "output_skip", "residual"] - assert progressive_input in ["none", "input_skip", "residual"] - assert embedding_type in ["fourier", "positional"] - combine_method = progressive_combine.lower() - combiner = functools.partial(Combine, method=combine_method) - - modules = [] - # timestep/noise_level embedding; only for continuous training - if embedding_type == "fourier": - # Gaussian Fourier features embeddings. 
- modules.append(GaussianFourierProjection(embedding_size=nf, scale=fourier_scale)) - embed_dim = 2 * nf - - elif embedding_type == "positional": - embed_dim = nf - - else: - raise ValueError(f"embedding type {embedding_type} unknown.") - - modules.append(nn.Linear(embed_dim, nf * 4)) - modules.append(nn.Linear(nf * 4, nf * 4)) - - AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0)) - - if fir: - Up_sample = functools.partial(FirUpsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv) - else: - Up_sample = functools.partial(Upsample2D, name="Conv2d_0") - - if progressive == "output_skip": - self.pyramid_upsample = Up_sample(channels=None, use_conv=False) - elif progressive == "residual": - pyramid_upsample = functools.partial(Up_sample, use_conv=True) - - if fir: - Down_sample = functools.partial(FirDownsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv) - else: - Down_sample = functools.partial(Downsample2D, padding=0, name="Conv2d_0") - - if progressive_input == "input_skip": - self.pyramid_downsample = Down_sample(channels=None, use_conv=False) - elif progressive_input == "residual": - pyramid_downsample = functools.partial(Down_sample, use_conv=True) - - channels = num_channels - if progressive_input != "none": - input_pyramid_ch = channels - - modules.append(nn.Conv2d(channels, nf, kernel_size=3, padding=1)) - hs_c = [nf] - - in_ch = nf - for i_level in range(self.num_resolutions): - # Residual blocks for this resolution - for i_block in range(num_res_blocks): - out_ch = nf * ch_mult[i_level] - modules.append( - ResnetBlock2D( - in_channels=in_ch, - out_channels=out_ch, - temb_channels=4 * nf, - output_scale_factor=np.sqrt(2.0), - non_linearity="silu", - groups=min(in_ch // 4, 32), - groups_out=min(out_ch // 4, 32), - overwrite_for_score_vde=True, - ) - ) - in_ch = out_ch - - if all_resolutions[i_level] in attn_resolutions: - modules.append(AttnBlock(channels=in_ch)) - hs_c.append(in_ch) - - if i_level != self.num_resolutions - 1: - modules.append( - ResnetBlock2D( - in_channels=in_ch, - temb_channels=4 * nf, - output_scale_factor=np.sqrt(2.0), - non_linearity="silu", - groups=min(in_ch // 4, 32), - groups_out=min(out_ch // 4, 32), - overwrite_for_score_vde=True, - down=True, - kernel="fir" if fir else "sde_vp", - use_nin_shortcut=True, - ) - ) - - if progressive_input == "input_skip": - modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch)) - if combine_method == "cat": - in_ch *= 2 - - elif progressive_input == "residual": - modules.append(pyramid_downsample(channels=input_pyramid_ch, out_channels=in_ch)) - input_pyramid_ch = in_ch - - hs_c.append(in_ch) - - # mid - in_ch = hs_c[-1] - modules.append( - ResnetBlock2D( - in_channels=in_ch, - temb_channels=4 * nf, - output_scale_factor=np.sqrt(2.0), - non_linearity="silu", - groups=min(in_ch // 4, 32), - groups_out=min(out_ch // 4, 32), - overwrite_for_score_vde=True, - ) - ) - modules.append(AttnBlock(channels=in_ch)) - modules.append( - ResnetBlock2D( - in_channels=in_ch, - temb_channels=4 * nf, - output_scale_factor=np.sqrt(2.0), - non_linearity="silu", - groups=min(in_ch // 4, 32), - groups_out=min(out_ch // 4, 32), - overwrite_for_score_vde=True, - ) - ) - - pyramid_ch = 0 - # Upsampling block - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(num_res_blocks + 1): - out_ch = nf * ch_mult[i_level] - in_ch = in_ch + hs_c.pop() - modules.append( - ResnetBlock2D( - in_channels=in_ch, - out_channels=out_ch, - temb_channels=4 * nf, - 
output_scale_factor=np.sqrt(2.0), - non_linearity="silu", - groups=min(in_ch // 4, 32), - groups_out=min(out_ch // 4, 32), - overwrite_for_score_vde=True, - ) - ) - in_ch = out_ch - - if all_resolutions[i_level] in attn_resolutions: - modules.append(AttnBlock(channels=in_ch)) - - if progressive != "none": - if i_level == self.num_resolutions - 1: - if progressive == "output_skip": - modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) - modules.append(nn.Conv2d(in_ch, channels, kernel_size=3, padding=1)) - pyramid_ch = channels - elif progressive == "residual": - modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) - modules.append(nn.Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1)) - pyramid_ch = in_ch - else: - raise ValueError(f"{progressive} is not a valid name.") - else: - if progressive == "output_skip": - modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) - modules.append(nn.Conv2d(in_ch, channels, bias=True, kernel_size=3, padding=1)) - pyramid_ch = channels - elif progressive == "residual": - modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch)) - pyramid_ch = in_ch - else: - raise ValueError(f"{progressive} is not a valid name") - - if i_level != 0: - modules.append( - ResnetBlock2D( - in_channels=in_ch, - temb_channels=4 * nf, - output_scale_factor=np.sqrt(2.0), - non_linearity="silu", - groups=min(in_ch // 4, 32), - groups_out=min(out_ch // 4, 32), - overwrite_for_score_vde=True, - up=True, - kernel="fir" if fir else "sde_vp", - use_nin_shortcut=True, - ) - ) - - assert not hs_c - - if progressive != "output_skip": - modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) - modules.append(nn.Conv2d(in_ch, channels, kernel_size=3, padding=1)) - - self.all_modules = nn.ModuleList(modules) - - def remove_ldm(self): - del self.time_embed - del self.input_blocks - del self.middle_block - del self.output_blocks - del self.out - - def remove_ddpm(self): - del self.temb - del self.down - del self.mid_new - del self.up - del self.norm_out - - def remove_sde(self): - del self.all_modules - - -def nonlinearity(x): - # swish - return x * torch.sigmoid(x) - - -def Normalize(in_channels): - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 52ea279f..9b19419d 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -4,9 +4,7 @@ from .ddpm import DDPMPipeline from .latent_diffusion_uncond import LatentDiffusionUncondPipeline from .pndm import PNDMPipeline from .score_sde_ve import ScoreSdeVePipeline -from .score_sde_vp import ScoreSdeVpPipeline if is_transformers_available(): - from .glide import GlidePipeline from .latent_diffusion import LatentDiffusionPipeline diff --git a/src/diffusers/pipelines/glide/__init__.py b/src/diffusers/pipelines/glide/__init__.py deleted file mode 100644 index d4bac0eb..00000000 --- a/src/diffusers/pipelines/glide/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from ...utils import is_transformers_available - - -if is_transformers_available(): - from .pipeline_glide import CLIPTextModel, GlidePipeline diff --git a/src/diffusers/pipelines/glide/pipeline_glide.py b/src/diffusers/pipelines/glide/pipeline_glide.py deleted file mode 100644 index 5ba28d87..00000000 --- a/src/diffusers/pipelines/glide/pipeline_glide.py +++ 
/dev/null @@ -1,843 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The OpenAI Team Authors and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch CLIP model.""" - -import math -from dataclasses import dataclass -from typing import Any, Optional, Tuple, Union - -import torch -import torch.utils.checkpoint -from torch import nn - -from tqdm.auto import tqdm -from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer -from transformers.activations import ACT2FN -from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings - -from ...models import GlideSuperResUNetModel, GlideTextToImageUNetModel -from ...pipeline_utils import DiffusionPipeline -from ...schedulers import DDIMScheduler, DDPMScheduler -from ...utils import logging - - -##################### -# START OF THE CLIP MODEL COPY-PASTE (with a modified attention module) -##################### - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "fusing/glide-base" - -CLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "fusing/glide-base", - # See all CLIP models at https://huggingface.co/models?filter=clip -] - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -# contrastive loss function, adapted from -# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html -def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: - return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) - - -def clip_loss(similarity: torch.Tensor) -> torch.Tensor: - caption_loss = contrastive_loss(similarity) - image_loss = contrastive_loss(similarity.T) - return (caption_loss + image_loss) / 2.0 - - -@dataclass -class CLIPOutput(ModelOutput): - """ - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): - Contrastive loss for image-text similarity. - logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): - The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text - similarity scores. - logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): - The scaled dot product scores between `text_embeds` and `image_embeds`. 
This represents the text-image - similarity scores. - text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): - The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`]. - image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): - The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`]. - text_model_output(`BaseModelOutputWithPooling`): - The output of the [`CLIPTextModel`]. - vision_model_output(`BaseModelOutputWithPooling`): - The output of the [`CLIPVisionModel`]. - """ - - loss: Optional[torch.FloatTensor] = None - logits_per_image: torch.FloatTensor = None - logits_per_text: torch.FloatTensor = None - text_embeds: torch.FloatTensor = None - image_embeds: torch.FloatTensor = None - text_model_output: BaseModelOutputWithPooling = None - vision_model_output: BaseModelOutputWithPooling = None - - def to_tuple(self) -> Tuple[Any]: - return tuple( - self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() - for k in self.keys() - ) - - -class CLIPVisionEmbeddings(nn.Module): - def __init__(self, config: CLIPVisionConfig): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.image_size = config.image_size - self.patch_size = config.patch_size - - self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) - - self.patch_embedding = nn.Conv2d( - in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False - ) - - self.num_patches = (self.image_size // self.patch_size) ** 2 - self.num_positions = self.num_patches + 1 - self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) - self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1))) - - def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] - patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] - patch_embeds = patch_embeds.flatten(2).transpose(1, 2) - - class_embeds = self.class_embedding.expand(batch_size, 1, -1) - embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding(self.position_ids) - return embeddings - - -class CLIPTextEmbeddings(nn.Module): - def __init__(self, config: CLIPTextConfig): - super().__init__() - embed_dim = config.hidden_size - - self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) - self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) - self.use_padding_embeddings = config.use_padding_embeddings - if self.use_padding_embeddings: - self.padding_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) - - # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] - - if position_ids is None: - position_ids = self.position_ids[:, :seq_length] - - if inputs_embeds is None: - inputs_embeds = self.token_embedding(input_ids) - - position_embeddings = 
self.position_embedding(position_ids) - embeddings = inputs_embeds + position_embeddings - - if self.use_padding_embeddings and attention_mask is not None: - padding_embeddings = self.padding_embedding(position_ids) - embeddings = torch.where(attention_mask.bool().unsqueeze(-1), embeddings, padding_embeddings) - - return embeddings - - -class CLIPAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads})." - ) - self.scale = 1 / math.sqrt(math.sqrt(self.head_dim)) - - self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3) - self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - causal_attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - bsz, tgt_len, embed_dim = hidden_states.size() - - qkv_states = self.qkv_proj(hidden_states) - qkv_states = qkv_states.view(bsz, tgt_len, self.num_heads, -1) - query_states, key_states, value_states = torch.split(qkv_states, self.head_dim, dim=-1) - - attn_weights = torch.einsum("bthc,bshc->bhts", query_states * self.scale, key_states * self.scale) - - wdtype = attn_weights.dtype - attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).type(wdtype) - - attn_output = torch.einsum("bhts,bshc->bthc", attn_weights, value_states) - attn_output = attn_output.reshape(bsz, tgt_len, -1) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights - - -class CLIPMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.activation_fn = ACT2FN[config.hidden_act] - self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) - self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - hidden_states = self.fc2(hidden_states) - return hidden_states - - -class CLIPEncoderLayer(nn.Module): - def __init__(self, config: CLIPConfig): - super().__init__() - self.embed_dim = config.hidden_size - self.self_attn = CLIPAttention(config) - self.layer_norm1 = nn.LayerNorm(self.embed_dim) - self.mlp = CLIPMLP(config) - self.layer_norm2 = nn.LayerNorm(self.embed_dim) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - causal_attention_mask: torch.Tensor, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - `(config.encoder_attention_heads,)`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. 
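(Aside: `CLIPAttention` above computes queries, keys and values from a single fused `qkv_proj` and scales both q and k by `1 / sqrt(sqrt(head_dim))`, which is numerically the same as dividing the attention logits by `sqrt(head_dim)` but keeps intermediates smaller. A self-contained sketch of that pattern, with illustrative shapes and names rather than the module's real interface:)

```python
# Sketch of fused-QKV attention with symmetric q/k scaling (illustrative only).
import math
import torch
import torch.nn as nn


def fused_qkv_attention(hidden_states: torch.Tensor, qkv_proj: nn.Linear, num_heads: int) -> torch.Tensor:
    bsz, seq_len, embed_dim = hidden_states.shape
    head_dim = embed_dim // num_heads
    # Scaling q and k each by head_dim ** -0.25 matches dividing the logits by sqrt(head_dim).
    scale = 1 / math.sqrt(math.sqrt(head_dim))

    qkv = qkv_proj(hidden_states).view(bsz, seq_len, num_heads, 3 * head_dim)
    q, k, v = torch.split(qkv, head_dim, dim=-1)

    attn_weights = torch.einsum("bthc,bshc->bhts", q * scale, k * scale).softmax(dim=-1)
    attn_output = torch.einsum("bhts,bshc->bthc", attn_weights, v)
    return attn_output.reshape(bsz, seq_len, embed_dim)


x = torch.randn(2, 7, 64)
out = fused_qkv_attention(x, nn.Linear(64, 3 * 64), num_heads=8)
print(out.shape)  # torch.Size([2, 7, 64])
```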
See `attentions` under - returned tensors for more detail. - """ - residual = hidden_states - - hidden_states = self.layer_norm1(hidden_states) - hidden_states, attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - causal_attention_mask=causal_attention_mask, - output_attentions=output_attentions, - ) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.layer_norm2(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -class CLIPPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = CLIPConfig - base_model_prefix = "clip" - supports_gradient_checkpointing = True - _keys_to_ignore_on_load_missing = [r"position_ids"] - - def _init_weights(self, module): - """Initialize the weights""" - factor = self.config.initializer_factor - if isinstance(module, CLIPTextEmbeddings): - module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) - module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) - if hasattr(module, "padding_embedding"): - module.padding_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) - elif isinstance(module, CLIPVisionEmbeddings): - factor = self.config.initializer_factor - nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) - nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) - nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) - elif isinstance(module, CLIPAttention): - factor = self.config.initializer_factor - in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor - out_proj_std = (module.embed_dim**-0.5) * factor - nn.init.normal_(module.qkv_proj.weight, std=in_proj_std) - nn.init.normal_(module.out_proj.weight, std=out_proj_std) - elif isinstance(module, CLIPMLP): - factor = self.config.initializer_factor - in_proj_std = ( - (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor - ) - fc_std = (2 * module.config.hidden_size) ** -0.5 * factor - nn.init.normal_(module.fc1.weight, std=fc_std) - nn.init.normal_(module.fc2.weight, std=in_proj_std) - elif isinstance(module, CLIPModel): - nn.init.normal_( - module.text_projection.weight, - std=module.text_embed_dim**-0.5 * self.config.initializer_factor, - ) - nn.init.normal_( - module.visual_projection.weight, - std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, - ) - - if isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, CLIPEncoder): - module.gradient_checkpointing = value - - -CLIP_START_DOCSTRING = r""" - This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it - as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and - behavior. - - Parameters: - config ([`CLIPConfig`]): Model configuration class with all the parameters of the model. 
- Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -CLIP_TEXT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - -CLIP_VISION_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using - [`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - -CLIP_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. 
- - [What are position IDs?](../glossary#position-ids) - pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using - [`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details. - return_loss (`bool`, *optional*): - Whether or not to return the contrastive loss. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -class CLIPEncoder(nn.Module): - """ - Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a - [`CLIPEncoderLayer`]. - - Args: - config: CLIPConfig - """ - - def __init__(self, config: CLIPConfig): - super().__init__() - self.config = config - self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]) - self.gradient_checkpointing = False - - def forward( - self, - inputs_embeds, - attention_mask: Optional[torch.Tensor] = None, - causal_attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutput]: - r""" - Args: - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Causal mask for the text model. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
- """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - hidden_states = inputs_embeds - for idx, encoder_layer in enumerate(self.layers): - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), - hidden_states, - attention_mask, - causal_attention_mask, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions - ) - - -class CLIPTextTransformer(nn.Module): - def __init__(self, config: CLIPTextConfig): - super().__init__() - self.config = config - embed_dim = config.hidden_size - self.embeddings = CLIPTextEmbeddings(config) - self.encoder = CLIPEncoder(config) - self.final_layer_norm = nn.LayerNorm(embed_dim) - - @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig) - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPooling]: - r""" - Returns: - - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if input_ids is None: - raise ValueError("You have to specify either input_ids") - - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - - hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask) - - bsz, seq_len = input_shape - # CLIP's text model uses causal mask, prepare it here. 
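(Aside: both mask kinds referenced here are additive: positions to be ignored receive a large negative value so they vanish after the softmax, using the same `[bsz, 1, tgt_len, src_len]` layout as `_expand_mask` earlier in this file and the `triu_`-based causal mask built by `_build_causal_attention_mask` below. A small illustrative sketch of that convention, not the library helpers themselves:)

```python
# Illustrative sketch of additive attention masks (names and layout assumed from the code above).
import torch


def causal_mask(bsz: int, seq_len: int, dtype=torch.float32) -> torch.Tensor:
    # Future positions (strict upper triangle) get a large negative value;
    # the diagonal and lower triangle stay at 0.
    mask = torch.full((seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype)
    mask.triu_(1)
    return mask[None, None].expand(bsz, 1, seq_len, seq_len)


def padding_mask(attention_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # attention_mask: [bsz, seq_len] with 1 = keep, 0 = pad, as in the docstrings above.
    inverted = 1.0 - attention_mask[:, None, None, :].to(dtype)
    return inverted.masked_fill(inverted.bool(), torch.finfo(dtype).min)


scores = torch.randn(2, 1, 4, 4) + causal_mask(2, 4) + padding_mask(torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]))
probs = scores.softmax(dim=-1)  # masked positions end up with ~0 probability
```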
- # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 - # causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len).to(hidden_states.device) - - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, hidden_states.dtype) - - encoder_outputs = self.encoder( - inputs_embeds=hidden_states, - attention_mask=None, - causal_attention_mask=None, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - last_hidden_state = encoder_outputs[0] - last_hidden_state = self.final_layer_norm(last_hidden_state) - - # text_embeds.shape = [batch_size, sequence_length, transformer.width] - # take features from the eot embedding (eot_token is the highest number in each sequence) - pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)] - - if not return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] - - return BaseModelOutputWithPooling( - last_hidden_state=last_hidden_state, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - def _build_causal_attention_mask(self, bsz, seq_len): - # lazily create causal attention mask, with full attention between the vision tokens - # pytorch uses additive attention mask; fill with -inf - mask = torch.empty(bsz, seq_len, seq_len) - mask.fill_(torch.tensor(float("-inf"))) - mask.triu_(1) # zero out the lower diagonal - mask = mask.unsqueeze(1) # expand mask - return mask - - -class CLIPTextModel(CLIPPreTrainedModel): - config_class = CLIPTextConfig - - def __init__(self, config: CLIPTextConfig): - super().__init__(config) - self.text_model = CLIPTextTransformer(config) - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self) -> nn.Module: - return self.text_model.embeddings.token_embedding - - def set_input_embeddings(self, value): - self.text_model.embeddings.token_embedding = value - - @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig) - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPooling]: - r""" - Returns: - - Examples: - - ```python - >>> from transformers import CLIPTokenizer, CLIPTextModel - - >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") - >>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32") - - >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") - - >>> outputs = model(**inputs) - >>> last_hidden_state = outputs.last_hidden_state - >>> pooled_output = outputs.pooler_output # pooled (EOS token) states - ```""" - return self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - -##################### -# END OF THE CLIP MODEL COPY-PASTE -##################### - - -def _extract_into_tensor(arr, timesteps, 
broadcast_shape): - """ - Extract values from a 1-D numpy array for a batch of indices. - - :param arr: the 1-D numpy array. :param timesteps: a tensor of indices into the array to extract. :param - broadcast_shape: a larger shape of K dimensions with the batch - dimension equal to the length of timesteps. - :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. - """ - res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float() - while len(res.shape) < len(broadcast_shape): - res = res[..., None] - return res + torch.zeros(broadcast_shape, device=timesteps.device) - - -class GlidePipeline(DiffusionPipeline): - def __init__( - self, - text_unet: GlideTextToImageUNetModel, - text_scheduler: DDPMScheduler, - text_encoder: CLIPTextModel, - tokenizer: GPT2Tokenizer, - upscale_unet: GlideSuperResUNetModel, - upscale_scheduler: DDIMScheduler, - ): - super().__init__() - self.register_modules( - text_unet=text_unet, - text_scheduler=text_scheduler, - text_encoder=text_encoder, - tokenizer=tokenizer, - upscale_unet=upscale_unet, - upscale_scheduler=upscale_scheduler, - ) - - @torch.no_grad() - def __call__( - self, - prompt, - generator=None, - torch_device=None, - num_inference_steps_upscale=50, - guidance_scale=3.0, - eta=0.0, - upsample_temp=0.997, - ): - - torch_device = "cuda" if torch.cuda.is_available() else "cpu" - - self.text_unet.to(torch_device) - self.text_encoder.to(torch_device) - self.upscale_unet.to(torch_device) - - def text_model_fn(x_t, timesteps, transformer_out, **kwargs): - half = x_t[: len(x_t) // 2] - combined = torch.cat([half, half], dim=0) - model_out = self.text_unet(combined, timesteps, transformer_out, **kwargs) - eps, rest = model_out[:, :3], model_out[:, 3:] - cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) - half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps) - eps = torch.cat([half_eps, half_eps], dim=0) - return torch.cat([eps, rest], dim=1) - - # 1. Sample gaussian noise - batch_size = 2 # second image is empty for classifier-free guidance - image = torch.randn( - ( - batch_size, - self.text_unet.in_channels, - self.text_unet.resolution, - self.text_unet.resolution, - ), - generator=generator, - ).to(torch_device) - - # 2. Encode tokens - # an empty input is needed to guide the model away from it - inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt") - input_ids = inputs["input_ids"].to(torch_device) - attention_mask = inputs["attention_mask"].to(torch_device) - transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state - - # 3. Run the text2image generation step - num_prediction_steps = len(self.text_scheduler) - for t in tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps): - with torch.no_grad(): - time_input = torch.tensor([t] * image.shape[0], device=torch_device) - model_output = text_model_fn(image, time_input, transformer_out) - noise_residual, model_var_values = torch.split(model_output, 3, dim=1) - - min_log = self.text_scheduler.get_variance(t, "fixed_small_log") - max_log = self.text_scheduler.get_variance(t, "fixed_large_log") - # The model_var_values is [-1, 1] for [min_var, max_var]. 
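(Aside: two pieces of math from this sampling step are worth isolating. `text_model_fn` above performs classifier-free guidance, recombining the conditional and unconditional noise predictions as `eps_uncond + guidance_scale * (eps_cond - eps_uncond)`, and the learned `model_var_values` in `[-1, 1]` are mapped linearly between the small and large log-variances. A standalone sketch of just these two operations; shapes, the placeholder `min_log`/`max_log` values and function names are assumptions:)

```python
# Standalone sketch of classifier-free guidance and learned-variance interpolation
# as used in the sampling loop above (illustrative values and names).
import torch


def classifier_free_guidance(eps_cond: torch.Tensor, eps_uncond: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Push the prediction away from the unconditional direction.
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)


def interpolate_log_variance(model_var_values: torch.Tensor, min_log: torch.Tensor, max_log: torch.Tensor) -> torch.Tensor:
    # The network outputs a value in [-1, 1]; map it linearly between the
    # "fixed_small_log" and "fixed_large_log" variances.
    frac = (model_var_values + 1) / 2
    return frac * max_log + (1 - frac) * min_log


eps_cond, eps_uncond = torch.randn(2, 1, 3, 8, 8)
eps = classifier_free_guidance(eps_cond, eps_uncond, guidance_scale=3.0)
log_var = interpolate_log_variance(torch.tanh(torch.randn(1, 3, 8, 8)), torch.tensor(-9.0), torch.tensor(-2.0))
```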
- frac = (model_var_values + 1) / 2 - model_log_variance = frac * max_log + (1 - frac) * min_log - - pred_prev_image = self.text_scheduler.step(noise_residual, image, t) - noise = torch.randn(image.shape, generator=generator).to(torch_device) - variance = torch.exp(0.5 * model_log_variance) * noise - - # set current image to prev_image: x_t -> x_t-1 - image = pred_prev_image + variance - - # 4. Run the upscaling step - batch_size = 1 - image = image[:1] - low_res = ((image + 1) * 127.5).round() / 127.5 - 1 - - # Sample gaussian noise to begin loop - image = torch.randn( - ( - batch_size, - self.upscale_unet.in_channels // 2, - self.upscale_unet.resolution, - self.upscale_unet.resolution, - ), - generator=generator, - ).to(torch_device) - image = image * upsample_temp - - num_trained_timesteps = self.upscale_scheduler.timesteps - inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale) - - for t in tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale): - # 1. predict noise residual - with torch.no_grad(): - time_input = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device) - model_output = self.upscale_unet(image, time_input, low_res) - noise_residual, pred_variance = torch.split(model_output, 3, dim=1) - - # 2. predict previous mean of image x_t-1 - pred_prev_image = self.upscale_scheduler.step( - noise_residual, image, t, num_inference_steps_upscale, eta, use_clipped_residual=True - ) - - # 3. optionally sample variance - variance = 0 - if eta > 0: - noise = torch.randn(image.shape, generator=generator).to(torch_device) - variance = self.upscale_scheduler.get_variance(t, num_inference_steps_upscale).sqrt() * eta * noise - - # 4. set current image to prev_image: x_t -> x_t-1 - image = pred_prev_image + variance - - image = image.clamp(-1, 1).permute(0, 2, 3, 1) - - return image diff --git a/src/diffusers/pipelines/score_sde_vp/__init__.py b/src/diffusers/pipelines/score_sde_vp/__init__.py deleted file mode 100644 index 40660ec6..00000000 --- a/src/diffusers/pipelines/score_sde_vp/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .pipeline_score_sde_vp import ScoreSdeVpPipeline diff --git a/src/diffusers/pipelines/score_sde_vp/pipeline_score_sde_vp.py b/src/diffusers/pipelines/score_sde_vp/pipeline_score_sde_vp.py deleted file mode 100644 index b9cf0884..00000000 --- a/src/diffusers/pipelines/score_sde_vp/pipeline_score_sde_vp.py +++ /dev/null @@ -1,38 +0,0 @@ -#!/usr/bin/env python3 -import torch - -from diffusers import DiffusionPipeline - - -# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names -class ScoreSdeVpPipeline(DiffusionPipeline): - def __init__(self, model, scheduler): - super().__init__() - self.register_modules(model=model, scheduler=scheduler) - - def __call__(self, num_inference_steps=1000, generator=None): - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - - img_size = self.model.config.image_size - channels = self.model.config.num_channels - shape = (1, channels, img_size, img_size) - - model = self.model.to(device) - - x = torch.randn(*shape).to(device) - - self.scheduler.set_timesteps(num_inference_steps) - - for t in self.scheduler.timesteps: - t = t * torch.ones(shape[0], device=device) - scaled_t = t * (num_inference_steps - 1) - - # TODO add corrector - with torch.no_grad(): - result = model(x, scaled_t) - - x, x_mean = self.scheduler.step_pred(result, x, t) - - x_mean = (x_mean + 1.0) / 2.0 - - return x_mean diff 
--git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 64db9c9b..76bc7bef 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -255,8 +255,6 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase): def prepare_init_args_and_inputs_for_common(self): init_dict = { - "ch": 32, - "ch_mult": (1, 2), "block_channels": (32, 64), "down_blocks": ("UNetResDownBlock2D", "UNetResAttnDownBlock2D"), "up_blocks": ("UNetResAttnUpBlock2D", "UNetResUpBlock2D"), @@ -264,8 +262,6 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase): "out_channels": 3, "in_channels": 3, "num_res_blocks": 2, - "attn_resolutions": (16,), - "resolution": 32, "image_size": 32, } inputs_dict = self.dummy_input @@ -322,13 +318,11 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): "in_channels": 4, "out_channels": 4, "num_res_blocks": 2, - "attention_resolutions": (16,), "block_channels": (32, 64), "num_head_channels": 32, "conv_resample": True, "down_blocks": ("UNetResDownBlock2D", "UNetResDownBlock2D"), "up_blocks": ("UNetResUpBlock2D", "UNetResUpBlock2D"), - "ldm": True, } inputs_dict = self.dummy_input return init_dict, inputs_dict @@ -529,8 +523,8 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase): "ch": 64, "out_ch": 3, "num_res_blocks": 1, - "attn_resolutions": [], "in_channels": 3, + "attn_resolutions": [], "resolution": 32, "z_channels": 3, "n_embed": 256, @@ -605,11 +599,11 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase): "ch_mult": (1,), "embed_dim": 4, "in_channels": 3, + "attn_resolutions": [], "num_res_blocks": 1, "out_ch": 3, "resolution": 32, "z_channels": 4, - "attn_resolutions": [], } inputs_dict = self.dummy_input return init_dict, inputs_dict @@ -655,7 +649,6 @@ class PipelineTesterMixin(unittest.TestCase): model = UNetUnconditionalModel( block_channels=(32, 64), num_res_blocks=2, - attn_resolutions=(16,), image_size=32, in_channels=3, out_channels=3,
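(The test updates above track the slimmed-down `UNetUnconditionalModel` signature: legacy keys such as `ch`, `ch_mult`, `attn_resolutions`, `resolution`, `attention_resolutions` and `ldm` are dropped in favour of `block_channels`, `image_size` and the block-name lists. A rough usage sketch based only on the argument names visible in this diff; it is not a verified call against the final API, and omitted keyword arguments are left at their defaults.)

```python
# Rough sketch of the updated constructor arguments, copied from the test diff above.
from diffusers import UNetUnconditionalModel

model = UNetUnconditionalModel(
    block_channels=(32, 64),
    down_blocks=("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
    up_blocks=("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
    num_res_blocks=2,
    image_size=32,
    in_channels=3,
    out_channels=3,
)
```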