From 5a784f98a6e4676b64ee5290b577b6d805e60131 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 7 Jun 2022 19:41:50 +0200 Subject: [PATCH 01/10] Dev version --- setup.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index d902811b..17a5dc36 100644 --- a/setup.py +++ b/setup.py @@ -28,11 +28,11 @@ To create the package for pypi. 3. Unpin specific versions from setup.py that use a git install. 4. Checkout the release branch (v-release, for example v4.19-release), and commit these changes with the - message: "Release: " and push. + message: "Release: " and push. 5. Wait for the tests on main to be completed and be green (otherwise revert and fix bugs) -6. Add a tag in git to mark the release: "git tag v -m 'Adds tag v for pypi' " +6. Add a tag in git to mark the release: "git tag v -m 'Adds tag v for pypi' " Push the tag to git: git push --tags origin v-release 7. Build both the sources and the wheel. Do not change anything in setup.py between @@ -189,7 +189,7 @@ extras["sagemaker"] = [ setup( name="diffusers", - version="0.0.1", + version="0.0.2", description="Diffusers", long_description=open("README.md", "r", encoding="utf-8").read(), long_description_content_type="text/markdown", @@ -222,8 +222,8 @@ setup( # Release checklist # 1. Change the version in __init__.py and setup.py. -# 2. Commit these changes with the message: "Release: VERSION" -# 3. Add a tag in git to mark the release: "git tag VERSION -m 'Adds tag VERSION for pypi' " +# 2. Commit these changes with the message: "Release: Release" +# 3. Add a tag in git to mark the release: "git tag RELEASE -m 'Adds tag RELEASE for pypi' " # Push the tag to git: git push --tags origin main # 4. Run the following commands in the top-level directory: # python setup.py bdist_wheel From db3757aa06eee0a4ae4c8dc13d65d1760ac26b54 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 8 Jun 2022 08:42:11 +0000 Subject: [PATCH 02/10] up --- run_inference.py | 23 +++++ src/diffusers/schedulers/ddim.py | 102 +++++++++++++++++++ src/diffusers/schedulers/gaussian_ddpm.py | 25 +---- src/diffusers/schedulers/schedulers_utils.py | 38 +++++++ 4 files changed, 164 insertions(+), 24 deletions(-) create mode 100755 run_inference.py create mode 100644 src/diffusers/schedulers/ddim.py create mode 100644 src/diffusers/schedulers/schedulers_utils.py diff --git a/run_inference.py b/run_inference.py new file mode 100755 index 00000000..38cdd3bb --- /dev/null +++ b/run_inference.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 +# !pip install diffusers +from diffusers import DiffusionPipeline +import PIL.Image +import numpy as np + +model_id = "fusing/ddpm-cifar10" +model_id = "fusing/ddpm-lsun-bedroom" + +# load model and scheduler +ddpm = DiffusionPipeline.from_pretrained(model_id) + +# run pipeline in inference (sample random noise and denoise) +image = ddpm() + +# process image to PIL +image_processed = image.cpu().permute(0, 2, 3, 1) +image_processed = (image_processed + 1.0) * 127.5 +image_processed = image_processed.numpy().astype(np.uint8) +image_pil = PIL.Image.fromarray(image_processed[0]) + +# save image +image_pil.save("/home/patrick/images/show.png") diff --git a/src/diffusers/schedulers/ddim.py b/src/diffusers/schedulers/ddim.py new file mode 100644 index 00000000..0bcf59d2 --- /dev/null +++ b/src/diffusers/schedulers/ddim.py @@ -0,0 +1,102 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import math +from torch import nn + +from ..configuration_utils import ConfigMixin +from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar + + +SAMPLING_CONFIG_NAME = "scheduler_config.json" + + +class GaussianDDPMScheduler(nn.Module, ConfigMixin): + + config_name = SAMPLING_CONFIG_NAME + + def __init__( + self, + timesteps=1000, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="linear", + variance_type="fixed_small", + ): + super().__init__() + self.register( + timesteps=timesteps, + beta_start=beta_start, + beta_end=beta_end, + beta_schedule=beta_schedule, + variance_type=variance_type, + ) + self.num_timesteps = int(timesteps) + + if beta_schedule == "linear": + betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end) + elif beta_schedule == "squaredcos_cap_v2": + # GLIDE cosine schedule + betas = betas_for_alpha_bar( + timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, axis=0) + alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0) + + variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) + + if variance_type == "fixed_small": + log_variance = torch.log(variance.clamp(min=1e-20)) + elif variance_type == "fixed_large": + log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0)) + + self.register_buffer("betas", betas.to(torch.float32)) + self.register_buffer("alphas", alphas.to(torch.float32)) + self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32)) + + self.register_buffer("log_variance", log_variance.to(torch.float32)) + + def get_alpha(self, time_step): + return self.alphas[time_step] + + def get_beta(self, time_step): + return self.betas[time_step] + + def get_alpha_prod(self, time_step): + if time_step < 0: + return torch.tensor(1.0) + return self.alphas_cumprod[time_step] + + def sample_variance(self, time_step, shape, device, generator=None): + variance = self.log_variance[time_step] + nonzero_mask = torch.tensor([1 - (time_step == 0)], device=device).float()[None, :] + + noise = self.sample_noise(shape, device=device, generator=generator) + + sampled_variance = nonzero_mask * (0.5 * variance).exp() + sampled_variance = sampled_variance * noise + + return sampled_variance + + def sample_noise(self, shape, device, generator=None): + # always sample on CPU to be deterministic + return torch.randn(shape, generator=generator).to(device) + + def __len__(self): + return self.num_timesteps diff --git a/src/diffusers/schedulers/gaussian_ddpm.py b/src/diffusers/schedulers/gaussian_ddpm.py index 2a25cbbf..0bcf59d2 100644 --- a/src/diffusers/schedulers/gaussian_ddpm.py +++ b/src/diffusers/schedulers/gaussian_ddpm.py @@ -16,35 +16,12 @@ import math from torch import nn from ..configuration_utils import ConfigMixin +from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar SAMPLING_CONFIG_NAME = "scheduler_config.json" -def linear_beta_schedule(timesteps, beta_start, beta_end): - return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64) - - -def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, - which defines the cumulative product of (1-beta) over time from t = [0,1]. - - :param num_diffusion_timesteps: the number of betas to produce. - :param alpha_bar: a lambda that takes an argument t from 0 to 1 and - produces the cumulative product of (1-beta) up to that - part of the diffusion process. - :param max_beta: the maximum beta to use; use values lower than 1 to - prevent singularities. - """ - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float64) - - class GaussianDDPMScheduler(nn.Module, ConfigMixin): config_name = SAMPLING_CONFIG_NAME diff --git a/src/diffusers/schedulers/schedulers_utils.py b/src/diffusers/schedulers/schedulers_utils.py new file mode 100644 index 00000000..582adfd0 --- /dev/null +++ b/src/diffusers/schedulers/schedulers_utils.py @@ -0,0 +1,38 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + + +def linear_beta_schedule(timesteps, beta_start, beta_end): + return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64) + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float64) From ae81c3d69662bc69a5ec20f65e2eaf7cad329922 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 8 Jun 2022 09:13:47 +0000 Subject: [PATCH 03/10] make from pretrained more general --- models/vision/ddim/README.md | 28 ++++++++ models/vision/ddim/modeling_ddim.py | 67 +++++++++++++++++++ models/vision/ddim/run_ddpm.py | 17 +++++ .../vision/ddim/run_inference.py | 4 +- models/vision/ddpm/modeling_ddpm.py | 2 - src/diffusers/pipeline_utils.py | 16 ++--- 6 files changed, 122 insertions(+), 12 deletions(-) create mode 100644 models/vision/ddim/README.md create mode 100644 models/vision/ddim/modeling_ddim.py create mode 100755 models/vision/ddim/run_ddpm.py rename run_inference.py => models/vision/ddim/run_inference.py (86%) diff --git a/models/vision/ddim/README.md b/models/vision/ddim/README.md new file mode 100644 index 00000000..1070c042 --- /dev/null +++ b/models/vision/ddim/README.md @@ -0,0 +1,28 @@ + + +# Denoising Diffusion Implicit Models (DDIM) + +## Overview + +DDPM was proposed in [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) by *Jiaming Song, Chenlin Meng, Stefano Ermon* + +The abstract from the paper is the following: + +*Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training, yet they require simulating a Markov chain for many steps to produce a sample. To accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models with the same training procedure as DDPMs. In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process. We construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from. We empirically demonstrate that DDIMs can produce high quality samples 10× to 50× faster in terms of wall-clock time compared to DDPMs, allow us to trade off computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space.* + +Tips: + +- ... +- ... + +This model was contributed by [???](https://huggingface.co/???). The original code can be found [here](https://github.com/hojonathanho/diffusion). diff --git a/models/vision/ddim/modeling_ddim.py b/models/vision/ddim/modeling_ddim.py new file mode 100644 index 00000000..dcd084c0 --- /dev/null +++ b/models/vision/ddim/modeling_ddim.py @@ -0,0 +1,67 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +# limitations under the License. + + +from diffusers import DiffusionPipeline +import tqdm +import torch + + +def compute_alpha(beta, t): + beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0) + a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1) + return a + + +class DDIM(DiffusionPipeline): + + def __init__(self, unet, noise_scheduler): + super().__init__() + self.register_modules(unet=unet, noise_scheduler=noise_scheduler) + + def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, inference_time_steps=50): + seq = range(0, self.num_timesteps, self.num_timesteps // inference_time_steps) + b = self.noise_scheduler.betas + if torch_device is None: + torch_device = "cuda" if torch.cuda.is_available() else "cpu" + + self.unet.to(torch_device) + x = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator) + + with torch.no_grad(): + n = batch_size + seq_next = [-1] + list(seq[:-1]) + x0_preds = [] + xs = [x] + for i, j in zip(reversed(seq), reversed(seq_next)): + print(i) + t = (torch.ones(n) * i).to(x.device) + next_t = (torch.ones(n) * j).to(x.device) + at = compute_alpha(b, t.long()) + at_next = compute_alpha(b, next_t.long()) + xt = xs[-1].to('cuda') + et = self.unet(xt, t) + x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt() + x0_preds.append(x0_t.to('cpu')) + # eta + c1 = ( + eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt() + ) + c2 = ((1 - at_next) - c1 ** 2).sqrt() + xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et + xs.append(xt_next.to('cpu')) + + import ipdb; ipdb.set_trace() + return xs, x0_preds diff --git a/models/vision/ddim/run_ddpm.py b/models/vision/ddim/run_ddpm.py new file mode 100755 index 00000000..88de9313 --- /dev/null +++ b/models/vision/ddim/run_ddpm.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python3 +import torch + +from diffusers import GaussianDDPMScheduler, UNetModel + + +model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8)) + +diffusion = GaussianDDPMScheduler(model, image_size=128, timesteps=1000, loss_type="l1") # number of steps # L1 or L2 + +training_images = torch.randn(8, 3, 128, 128) # your images need to be normalized from a range of -1 to +1 +loss = diffusion(training_images) +loss.backward() +# after a lot of training + +sampled_images = diffusion.sample(batch_size=4) +sampled_images.shape # (4, 3, 128, 128) diff --git a/run_inference.py b/models/vision/ddim/run_inference.py similarity index 86% rename from run_inference.py rename to models/vision/ddim/run_inference.py index 38cdd3bb..59ed5865 100755 --- a/run_inference.py +++ b/models/vision/ddim/run_inference.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # !pip install diffusers -from diffusers import DiffusionPipeline +from modeling_ddim import DDIM import PIL.Image import numpy as np @@ -8,7 +8,7 @@ model_id = "fusing/ddpm-cifar10" model_id = "fusing/ddpm-lsun-bedroom" # load model and scheduler -ddpm = DiffusionPipeline.from_pretrained(model_id) +ddpm = DDIM.from_pretrained(model_id) # run pipeline in inference (sample random noise and denoise) image = ddpm() diff --git a/models/vision/ddpm/modeling_ddpm.py b/models/vision/ddpm/modeling_ddpm.py index e85d3cfe..f84ab452 100644 --- a/models/vision/ddpm/modeling_ddpm.py +++ b/models/vision/ddpm/modeling_ddpm.py @@ -21,8 +21,6 @@ import torch class DDPM(DiffusionPipeline): - modeling_file = "modeling_ddpm.py" - def __init__(self, unet, noise_scheduler): super().__init__() self.register_modules(unet=unet, noise_scheduler=noise_scheduler) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 60ece225..2d4803d3 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -55,14 +55,13 @@ class DiffusionPipeline(ConfigMixin): class_name = module.__class__.__name__ register_dict = {name: (library, class_name)} - # save model index config self.register(**register_dict) # set models setattr(self, name, module) - + register_dict = {"_module" : self.__module__.split(".")[-1] + ".py"} self.register(**register_dict) @@ -101,15 +100,15 @@ class DiffusionPipeline(ConfigMixin): cached_folder = pretrained_model_name_or_path config_dict = cls.get_config_dict(cached_folder) - - module = config_dict["_module"] - class_name_ = config_dict["_class_name"] - - if class_name_ == cls.__name__: + + # if we load from explicit class, let's use it + if cls != DiffusionPipeline: pipeline_class = cls else: + # else we need to load the correct module from the Hub + class_name_ = config_dict["_class_name"] + module = config_dict["_module"] pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder) - init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) @@ -120,6 +119,7 @@ class DiffusionPipeline(ConfigMixin): if library_name == module: # TODO(Suraj) + # for vq pass library = importlib.import_module(library_name) From 86064df7b556e1ffb2a37c0d03ec5d10ecba940a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 8 Jun 2022 09:14:50 +0000 Subject: [PATCH 04/10] fix --- src/diffusers/pipeline_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 2d4803d3..bcce66b4 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -101,13 +101,15 @@ class DiffusionPipeline(ConfigMixin): config_dict = cls.get_config_dict(cached_folder) + module_candidate = config_dict["_module"] + # if we load from explicit class, let's use it if cls != DiffusionPipeline: pipeline_class = cls else: # else we need to load the correct module from the Hub class_name_ = config_dict["_class_name"] - module = config_dict["_module"] + module = module_candidate pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder) init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) @@ -117,7 +119,7 @@ class DiffusionPipeline(ConfigMixin): for name, (library_name, class_name) in init_dict.items(): importable_classes = LOADABLE_CLASSES[library_name] - if library_name == module: + if library_name == module_candidate: # TODO(Suraj) # for vq pass From 4ea4429d1a8368e3b3289208f36eb1eba1af0eb0 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 8 Jun 2022 11:29:09 +0200 Subject: [PATCH 05/10] add unet for ldm --- src/diffusers/models/unet_ldm.py | 1297 ++++++++++++++++++++++++++++++ 1 file changed, 1297 insertions(+) create mode 100644 src/diffusers/models/unet_ldm.py diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py new file mode 100644 index 00000000..d6006a49 --- /dev/null +++ b/src/diffusers/models/unet_ldm.py @@ -0,0 +1,1297 @@ +from inspect import isfunction +from abc import abstractmethod +from functools import partial +import math +from typing import Iterable +from flask import Config + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import repeat, rearrange, einsum + +from ..configuration_utils import ConfigMixin +from ..modeling_utils import ModelMixin + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): + super().__init__() + self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) + for d in range(depth)] + ) + + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c') + for block in self.transformer_blocks: + x = block(x, context=context) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) + x = self.proj_out(x) + return x + x_in + +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.float() + if l.bias is not None: + l.bias.data = l.bias.data.float() + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +class GroupNorm32(nn.GroupNorm): + def __init__(self, num_groups, num_channels, swish, eps=1e-5): + super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps) + self.swish = swish + + def forward(self, x): + y = super().forward(x.float()).to(x.dtype) + if self.swish == 1.0: + y = F.silu(y) + elif self.swish: + y = y * F.sigmoid(y * float(self.swish)) + return y + + +def normalization(channels, swish=0.0): + """ + Make a standard normalization layer, with an optional swish activation. + + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=timesteps.device + ) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = torch.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = torch.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += torch.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = torch.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + a = torch.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(ModelMixin, ConfigMixin): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + ): + super().__init__() + + # register all __init__ params with self.register + self.register( + image_size=image_size, + in_channels=in_channels, + model_channels=model_channels, + out_channels=out_channels, + num_res_blocks=num_res_blocks, + attention_resolutions=attention_resolutions, + dropout=dropout, + channel_mult=channel_mult, + conv_resample=conv_resample, + dims=dims, + num_classes=num_classes, + use_checkpoint=use_checkpoint, + use_fp16=use_fp16, + num_heads=num_heads, + num_head_channels=num_head_channels, + num_heads_upsample=num_heads_upsample, + use_scale_shift_norm=use_scale_shift_norm, + resblock_updown=resblock_updown, + use_new_attention_order=use_new_attention_order, + use_spatial_transformer=use_spatial_transformer, + transformer_depth=transformer_depth, + context_dim=context_dim, + n_embed=n_embed, + legacy=legacy, + ) + + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = torch.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) + From 2f24ce1ce3988b2e622bc9a39f29255494bc62a8 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 8 Jun 2022 11:29:28 +0200 Subject: [PATCH 06/10] rename to UNetLDMModel --- src/diffusers/models/unet_ldm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index d6006a49..0b0b3bc1 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -720,7 +720,7 @@ class QKVAttentionLegacy(nn.Module): return count_flops_attn(model, _x, y) -class UNetModel(ModelMixin, ConfigMixin): +class UNetLDMModel(ModelMixin, ConfigMixin): """ The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. From a9374a0228773dd2ba4224922360a6bd6dee10e6 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 8 Jun 2022 11:29:42 +0200 Subject: [PATCH 07/10] remove unused imports --- src/diffusers/models/unet_ldm.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index 0b0b3bc1..00dfeb18 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -1,9 +1,6 @@ from inspect import isfunction from abc import abstractmethod -from functools import partial import math -from typing import Iterable -from flask import Config import numpy as np import torch From b903d3d3c1a786ea5af6087241419dacea8aa425 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 8 Jun 2022 11:30:14 +0200 Subject: [PATCH 08/10] fix einsum --- src/diffusers/models/unet_ldm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index 00dfeb18..465c168c 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -6,7 +6,7 @@ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from einops import repeat, rearrange, einsum +from einops import repeat, rearrange from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin @@ -180,7 +180,7 @@ class CrossAttention(nn.Module): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale if exists(mask): mask = rearrange(mask, 'b ... -> b (...)') @@ -191,7 +191,7 @@ class CrossAttention(nn.Module): # attention, what we cannot get enough of attn = sim.softmax(dim=-1) - out = einsum('b i j, b j d -> b i d', attn, v) + out = torch.einsum('b i j, b j d -> b i d', attn, v) out = rearrange(out, '(b h) n d -> b n (h d)', h=h) return self.to_out(out) From 7a1323b62fa1a8880051ab91d559a6e6248bee1c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 8 Jun 2022 09:42:31 +0000 Subject: [PATCH 09/10] add first version of ddim --- models/vision/ddim/modeling_ddim.py | 12 ++++++++---- tests/test_modeling_utils.py | 30 +++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/models/vision/ddim/modeling_ddim.py b/models/vision/ddim/modeling_ddim.py index dcd084c0..fa486c32 100644 --- a/models/vision/ddim/modeling_ddim.py +++ b/models/vision/ddim/modeling_ddim.py @@ -32,11 +32,16 @@ class DDIM(DiffusionPipeline): self.register_modules(unet=unet, noise_scheduler=noise_scheduler) def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, inference_time_steps=50): - seq = range(0, self.num_timesteps, self.num_timesteps // inference_time_steps) - b = self.noise_scheduler.betas + # eta is η in paper + if torch_device is None: torch_device = "cuda" if torch.cuda.is_available() else "cpu" + num_timesteps = self.noise_scheduler.num_timesteps + + seq = range(0, num_timesteps, num_timesteps // inference_time_steps) + b = self.noise_scheduler.betas.to(torch_device) + self.unet.to(torch_device) x = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator) @@ -63,5 +68,4 @@ class DDIM(DiffusionPipeline): xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et xs.append(xt_next.to('cpu')) - import ipdb; ipdb.set_trace() - return xs, x0_preds + return xt_next diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 0b7dece2..c5b18e4a 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -25,6 +25,7 @@ import torch from diffusers import GaussianDDPMScheduler, UNetModel from diffusers.pipeline_utils import DiffusionPipeline from models.vision.ddpm.modeling_ddpm import DDPM +from models.vision.ddim.modeling_ddim import DDIM global_rng = random.Random() @@ -205,6 +206,7 @@ class SamplerTesterMixin(unittest.TestCase): class PipelineTesterMixin(unittest.TestCase): + def test_from_pretrained_save_pretrained(self): # 1. Load models model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32) @@ -241,3 +243,31 @@ class PipelineTesterMixin(unittest.TestCase): new_image = ddpm_from_hub(generator=generator) assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass" + + @slow + def test_ddpm_cifar10(self): + generator = torch.manual_seed(0) + model_id = "fusing/ddpm-cifar10" + + ddpm = DDPM.from_pretrained(model_id) + image = ddpm(generator=generator) + + image_slice = image[0, -1, -3:, -3:].cpu() + + assert image.shape == (1, 3, 32, 32) + expected_slice = torch.tensor([0.2250, 0.3375, 0.2360, 0.0930, 0.3440, 0.3156, 0.1937, 0.3585, 0.1761]) + assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + + @slow + def test_ddim_cifar10(self): + generator = torch.manual_seed(0) + model_id = "fusing/ddpm-cifar10" + + ddim = DDIM.from_pretrained(model_id) + image = ddim(generator=generator, eta=0.0) + + image_slice = image[0, -1, -3:, -3:].cpu() + + assert image.shape == (1, 3, 32, 32) + expected_slice = torch.tensor([-0.7688, -0.7690, -0.7597, -0.7660, -0.7713, -0.7531, -0.7009, -0.7098, -0.7350]) + assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 From 4d53a521508955e47b8bdac2f76891136135ad16 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 8 Jun 2022 11:44:27 +0200 Subject: [PATCH 10/10] add unet ldm in init --- src/diffusers/__init__.py | 1 + src/diffusers/models/__init__.py | 1 + src/diffusers/models/unet_ldm.py | 4 ++-- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 3ce4142f..8feb9e81 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -7,5 +7,6 @@ __version__ = "0.0.1" from .modeling_utils import ModelMixin from .models.unet import UNetModel from .models.unet_glide import UNetGLIDEModel +from .models.unet_ldm import UNetLDMModel from .pipeline_utils import DiffusionPipeline from .schedulers.gaussian_ddpm import GaussianDDPMScheduler diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 85f1cc03..6d6c4d3d 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -18,3 +18,4 @@ from .unet import UNetModel from .unet_glide import UNetGLIDEModel +from .unet_ldm import UNetLDMModel diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index 465c168c..57dec0b6 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -830,7 +830,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin): self.conv_resample = conv_resample self.num_classes = num_classes self.use_checkpoint = use_checkpoint - self.dtype = torch.float16 if use_fp16 else torch.float32 + self.dtype_ = torch.float16 if use_fp16 else torch.float32 self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample @@ -1060,7 +1060,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin): assert y.shape == (x.shape[0],) emb = emb + self.label_emb(y) - h = x.type(self.dtype) + h = x.type(self.dtype_) for module in self.input_blocks: h = module(h, emb, context) hs.append(h)