diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index ba6df510..ac68a6c3 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -10,7 +10,7 @@ from .modeling_utils import ModelMixin from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel from .pipeline_utils import DiffusionPipeline from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline -from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin +from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin, VeSdeScheduler if is_transformers_available(): diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 30671ef2..d46782c7 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -15,10 +15,6 @@ # helpers functions -from ..modeling_utils import ModelMixin -from ..configuration_utils import ConfigMixin - - import functools import math import string @@ -28,16 +24,15 @@ import torch import torch.nn as nn import torch.nn.functional as F +from ..configuration_utils import ConfigMixin +from ..modeling_utils import ModelMixin + def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): - return upfirdn2d_native( - input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] - ) + return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) -def upfirdn2d_native( - input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 -): +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): _, channel, in_h, in_w = input.shape input = input.reshape(-1, in_h, in_w, 1) @@ -48,9 +43,7 @@ def upfirdn2d_native( out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) out = out.view(-1, in_h * up_y, in_w * up_x, minor) - out = F.pad( - out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] - ) + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) out = out[ :, max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), @@ -59,9 +52,7 @@ def upfirdn2d_native( ] out = out.permute(0, 3, 1, 2) - out = out.reshape( - [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] - ) + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) out = F.conv2d(out, w) out = out.reshape( @@ -350,7 +341,7 @@ conv3x3 = ddpm_conv3x3 def _einsum(a, b, c, x, y): - einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c)) + einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c)) return torch.einsum(einsum_str, x, y) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index d26c5fc8..e724149a 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -5,6 +5,9 @@ from .pipeline_ddpm import DDPMPipeline from .pipeline_pndm import PNDMPipeline +# from .pipeline_score_sde import NCSNppPipeline + + if is_transformers_available(): from .pipeline_glide import GlidePipeline from .pipeline_latent_diffusion import LatentDiffusionPipeline diff --git a/src/diffusers/pipelines/pipeline_score_sde.py b/src/diffusers/pipelines/pipeline_score_sde.py new file mode 100755 index 00000000..5b3cb5bc --- /dev/null +++ b/src/diffusers/pipelines/pipeline_score_sde.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 +import numpy as np +import torch + +import PIL +from diffusers import DiffusionPipeline + + +# from configs.ve import ffhq_ncsnpp_continuous as configs +# from configs.ve import cifar10_ncsnpp_continuous as configs + +# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth" +# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth" +# Note usually we need to restore ema etc... +# ema restored checkpoint used from below +torch.backends.cuda.matmul.allow_tf32 = False +torch.manual_seed(0) + + +class NCSNppPipeline(DiffusionPipeline): + def __init__(self, model, scheduler): + super().__init__() + self.register_modules(model=model, scheduler=scheduler) + + def __call__(self, generator=None): + N = self.scheduler.config.N + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + img_size = self.model.config.image_size + channels = self.model.config.num_channels + shape = (1, channels, img_size, img_size) + + model = torch.nn.DataParallel(self.model.to(device)) + + centered = False + n_steps = 1 + + # Initial sample + x = torch.randn(*shape) * self.scheduler.config.sigma_max + x = x.to(device) + + for i in range(N): + sigma_t = self.scheduler.get_sigma_t(i) * torch.ones(shape[0], device=device) + + for _ in range(n_steps): + with torch.no_grad(): + result = model(x, sigma_t) + x = self.scheduler.step_correct(result, x) + + with torch.no_grad(): + result = model(x, sigma_t) + + x, x_mean = self.scheduler.step_pred(result, x, i) + + x = x_mean + + if centered: + x = (x + 1.0) / 2.0 + + return x + + +pipeline = NCSNppPipeline.from_pretrained("/home/patrick/ffhq_ncsnpp") +x = pipeline() + + +# for 5 cifar10 +# x_sum = 106071.9922 +# x_mean = 34.52864456176758 + +# for 1000 cifar10 +# x_sum = 461.9700 +# x_mean = 0.1504 + +# for N=2 for 1024 +x_sum = 3382810112.0 +x_mean = 1075.366455078125 + + +def check_x_sum_x_mean(x, x_sum, x_mean): + assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}" + assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" + + +check_x_sum_x_mean(x, x_sum, x_mean) + + +def save_image(x): + image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8) + image_pil = PIL.Image.fromarray(image_processed[0]) + image_pil.save("../images/hey.png") + + +# save_image(x) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index b2d533d3..ea306266 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -21,3 +21,4 @@ from .scheduling_ddpm import DDPMScheduler from .scheduling_grad_tts import GradTTSScheduler from .scheduling_pndm import PNDMScheduler from .scheduling_utils import SchedulerMixin +from .scheduling_ve_sde import VeSdeScheduler diff --git a/src/diffusers/schedulers/scheduling_ve_sde.py b/src/diffusers/schedulers/scheduling_ve_sde.py new file mode 100644 index 00000000..6f188272 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_ve_sde.py @@ -0,0 +1,73 @@ +# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin +from .scheduling_utils import SchedulerMixin + + +class VeSdeScheduler(SchedulerMixin, ConfigMixin): + def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, N=2, sampling_eps=1e-5, tensor_format="np"): + super().__init__() + self.register_to_config( + snr=snr, + sigma_min=sigma_min, + sigma_max=sigma_max, + N=N, + sampling_eps=sampling_eps, + ) + # (PVP) - clean up with .config. + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.snr = snr + self.N = N + self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N)) + self.timesteps = torch.linspace(1, sampling_eps, N) + + def get_sigma_t(self, t): + return self.sigma_min * (self.sigma_max / self.sigma_min) ** self.timesteps[t] + + def step_pred(self, result, x, t): + t = self.timesteps[t] * torch.ones(x.shape[0], device=x.device) + + timestep = (t * (self.N - 1)).long() + sigma = self.discrete_sigmas.to(t.device)[timestep] + adjacent_sigma = torch.where( + timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(t.device) + ) + f = torch.zeros_like(x) + G = torch.sqrt(sigma**2 - adjacent_sigma**2) + + f = f - G[:, None, None, None] ** 2 * result + + z = torch.randn_like(x) + x_mean = x - f + x = x_mean + G[:, None, None, None] * z + return x, x_mean + + def step_correct(self, result, x): + noise = torch.randn_like(x) + grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean() + noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() + step_size = (self.snr * noise_norm / grad_norm) ** 2 * 2 + step_size = step_size * torch.ones(x.shape[0], device=x.device) + x_mean = x + step_size[:, None, None, None] * result + + x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise + + return x