From 433cb3f801470feef0a6bab3c90c7b303926ec98 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 25 Jun 2022 18:25:43 +0000 Subject: [PATCH 1/6] clean up sde ve more --- README.md | 24 ++++++ src/diffusers/__init__.py | 11 ++- src/diffusers/pipelines/__init__.py | 3 +- ..._score_sde.py => pipeline_score_sde_ve.py} | 76 +++++++++---------- src/diffusers/schedulers/__init__.py | 2 +- ...eduling_ve_sde.py => scheduling_sde_ve.py} | 42 ++++++---- tests/test_modeling_utils.py | 20 +++++ 7 files changed, 120 insertions(+), 58 deletions(-) rename src/diffusers/pipelines/{pipeline_score_sde.py => pipeline_score_sde_ve.py} (53%) rename src/diffusers/schedulers/{scheduling_ve_sde.py => scheduling_sde_ve.py} (63%) diff --git a/README.md b/README.md index 6c2c9799..bee5d880 100644 --- a/README.md +++ b/README.md @@ -226,6 +226,30 @@ image_pil = PIL.Image.fromarray(image_processed[0]) image_pil.save("test.png") ``` +#### **Example 1024x1024 image generation with SDE VE** + +See [paper](https://arxiv.org/abs/2011.13456) for more information on SDE VE. + +```python +from diffusers import DiffusionPipeline +import torch +import PIL.Image + +torch.manual_seed(32) + +score_sde_sv = DiffusionPipeline.from_pretrained("fusing/ffhq_ncsnpp") + +# Note this might take up to 3 minutes on a GPU +image = score_sde_sv(num_inference_steps=2000) + +image = image.permute(0, 2, 3, 1).cpu().numpy() +image = np.clip(image * 255, 0, 255).astype(np.uint8) +image_pil = PIL.Image.fromarray(image[0]) + +# save image +image_pil.save("test.png") +``` + #### **Text to Image generation with Latent Diffusion** _Note: To use latent diffusion install transformers from [this branch](https://github.com/patil-suraj/transformers/tree/ldm-bert)._ diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index ac68a6c3..d8516083 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -9,8 +9,15 @@ __version__ = "0.0.4" from .modeling_utils import ModelMixin from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel from .pipeline_utils import DiffusionPipeline -from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline -from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin, VeSdeScheduler +from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline, ScoreSdeVePipeline +from .schedulers import ( + DDIMScheduler, + DDPMScheduler, + GradTTSScheduler, + PNDMScheduler, + SchedulerMixin, + ScoreSdeVeScheduler, +) if is_transformers_available(): diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index e724149a..b579652e 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -3,9 +3,10 @@ from .pipeline_bddm import BDDMPipeline from .pipeline_ddim import DDIMPipeline from .pipeline_ddpm import DDPMPipeline from .pipeline_pndm import PNDMPipeline +from .pipeline_score_sde_ve import ScoreSdeVePipeline -# from .pipeline_score_sde import NCSNppPipeline +# from .pipeline_score_sde import ScoreSdeVePipeline if is_transformers_available(): diff --git a/src/diffusers/pipelines/pipeline_score_sde.py b/src/diffusers/pipelines/pipeline_score_sde_ve.py similarity index 53% rename from src/diffusers/pipelines/pipeline_score_sde.py rename to src/diffusers/pipelines/pipeline_score_sde_ve.py index 5b3cb5bc..ca759249 100755 --- a/src/diffusers/pipelines/pipeline_score_sde.py +++ b/src/diffusers/pipelines/pipeline_score_sde_ve.py @@ -6,51 +6,44 @@ import PIL from 
diffusers import DiffusionPipeline -# from configs.ve import ffhq_ncsnpp_continuous as configs -# from configs.ve import cifar10_ncsnpp_continuous as configs - -# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth" -# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth" -# Note usually we need to restore ema etc... -# ema restored checkpoint used from below -torch.backends.cuda.matmul.allow_tf32 = False -torch.manual_seed(0) +# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names -class NCSNppPipeline(DiffusionPipeline): +class ScoreSdeVePipeline(DiffusionPipeline): def __init__(self, model, scheduler): super().__init__() self.register_modules(model=model, scheduler=scheduler) - def __call__(self, generator=None): - N = self.scheduler.config.N + def __call__(self, num_inference_steps=2000, generator=None): device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") img_size = self.model.config.image_size channels = self.model.config.num_channels shape = (1, channels, img_size, img_size) - model = torch.nn.DataParallel(self.model.to(device)) + model = self.model.to(device) centered = False n_steps = 1 - # Initial sample x = torch.randn(*shape) * self.scheduler.config.sigma_max x = x.to(device) - for i in range(N): - sigma_t = self.scheduler.get_sigma_t(i) * torch.ones(shape[0], device=device) + self.scheduler.set_timesteps(num_inference_steps) + self.scheduler.set_sigmas(num_inference_steps) + + for i, t in enumerate(self.scheduler.timesteps): + sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device) for _ in range(n_steps): with torch.no_grad(): - result = model(x, sigma_t) + result = self.model(x, sigma_t) x = self.scheduler.step_correct(result, x) with torch.no_grad(): result = model(x, sigma_t) - x, x_mean = self.scheduler.step_pred(result, x, i) + x, x_mean = self.scheduler.step_pred(result, x, t) x = x_mean @@ -60,9 +53,16 @@ class NCSNppPipeline(DiffusionPipeline): return x -pipeline = NCSNppPipeline.from_pretrained("/home/patrick/ffhq_ncsnpp") -x = pipeline() +# from configs.ve import ffhq_ncsnpp_continuous as configs +# from configs.ve import cifar10_ncsnpp_continuous as configs +# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth" +# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth" +# Note usually we need to restore ema etc... 
+# ema restored checkpoint used from below + +# pipeline = ScoreSdeVePipeline.from_pretrained("/home/patrick/ffhq_ncsnpp") +# x = pipeline(num_inference_steps=2) # for 5 cifar10 # x_sum = 106071.9922 @@ -73,22 +73,22 @@ x = pipeline() # x_mean = 0.1504 # for N=2 for 1024 -x_sum = 3382810112.0 -x_mean = 1075.366455078125 - - -def check_x_sum_x_mean(x, x_sum, x_mean): - assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}" - assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" - - -check_x_sum_x_mean(x, x_sum, x_mean) - - -def save_image(x): - image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8) - image_pil = PIL.Image.fromarray(image_processed[0]) - image_pil.save("../images/hey.png") - - +# x_sum = 3382810112.0 +# x_mean = 1075.366455078125 +# +# +# def check_x_sum_x_mean(x, x_sum, x_mean): +# assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}" +# assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" +# +# +# check_x_sum_x_mean(x, x_sum, x_mean) +# +# +# def save_image(x): +# image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8) +# image_pil = PIL.Image.fromarray(image_processed[0]) +# image_pil.save("../images/hey.png") +# +# # save_image(x) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index ea306266..36bc441b 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -21,4 +21,4 @@ from .scheduling_ddpm import DDPMScheduler from .scheduling_grad_tts import GradTTSScheduler from .scheduling_pndm import PNDMScheduler from .scheduling_utils import SchedulerMixin -from .scheduling_ve_sde import VeSdeScheduler +from .scheduling_sde_ve import ScoreSdeVeScheduler diff --git a/src/diffusers/schedulers/scheduling_ve_sde.py b/src/diffusers/schedulers/scheduling_sde_ve.py similarity index 63% rename from src/diffusers/schedulers/scheduling_ve_sde.py rename to src/diffusers/schedulers/scheduling_sde_ve.py index 6f188272..652314b9 100644 --- a/src/diffusers/schedulers/scheduling_ve_sde.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -1,4 +1,4 @@ -# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved. +# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim +# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch + +# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit import numpy as np import torch @@ -21,7 +23,7 @@ from ..configuration_utils import ConfigMixin from .scheduling_utils import SchedulerMixin -class VeSdeScheduler(SchedulerMixin, ConfigMixin): +class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, N=2, sampling_eps=1e-5, tensor_format="np"): super().__init__() self.register_to_config( @@ -31,24 +33,32 @@ class VeSdeScheduler(SchedulerMixin, ConfigMixin): N=N, sampling_eps=sampling_eps, ) - # (PVP) - clean up with .config. 
- self.sigma_min = sigma_min - self.sigma_max = sigma_max - self.snr = snr - self.N = N - self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N)) - self.timesteps = torch.linspace(1, sampling_eps, N) - def get_sigma_t(self, t): - return self.sigma_min * (self.sigma_max / self.sigma_min) ** self.timesteps[t] + self.sigmas = None + self.discrete_sigmas = None + self.timesteps = None + + def set_timesteps(self, num_inference_steps): + self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) + + def set_sigmas(self, num_inference_steps): + if self.timesteps is None: + self.set_timesteps(num_inference_steps) + + self.discrete_sigmas = torch.exp( + torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps) + ) + self.sigmas = torch.tensor( + [self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps] + ) def step_pred(self, result, x, t): - t = self.timesteps[t] * torch.ones(x.shape[0], device=x.device) + t = t * torch.ones(x.shape[0], device=x.device) + timestep = (t * (2 - 1)).long() - timestep = (t * (self.N - 1)).long() sigma = self.discrete_sigmas.to(t.device)[timestep] adjacent_sigma = torch.where( - timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(t.device) + timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(timestep.device) ) f = torch.zeros_like(x) G = torch.sqrt(sigma**2 - adjacent_sigma**2) @@ -64,7 +74,7 @@ class VeSdeScheduler(SchedulerMixin, ConfigMixin): noise = torch.randn_like(x) grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean() noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() - step_size = (self.snr * noise_norm / grad_norm) ** 2 * 2 + step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2 step_size = step_size * torch.ones(x.shape[0], device=x.device) x_mean = x + step_size[:, None, None, None] * result diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index db4ed6eb..15547afb 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -33,8 +33,11 @@ from diffusers import ( GradTTSPipeline, GradTTSScheduler, LatentDiffusionPipeline, + NCSNpp, PNDMPipeline, PNDMScheduler, + ScoreSdeVePipeline, + ScoreSdeVeScheduler, UNetGradTTSModel, UNetLDMModel, UNetModel, @@ -721,6 +724,23 @@ class PipelineTesterMixin(unittest.TestCase): ) assert (mel_spec[0, :3, :3].cpu().flatten() - expected_slice).abs().max() < 1e-2 + @slow + def test_score_sde_ve_pipeline(self): + torch.manual_seed(0) + + model = NCSNpp.from_pretrained("fusing/ffhq_ncsnpp") + scheduler = ScoreSdeVeScheduler.from_config("fusing/ffhq_ncsnpp") + + sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler) + + image = sde_ve(num_inference_steps=2) + + expected_image_sum = 3382810112.0 + expected_image_mean = 1075.366455078125 + + assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2 + assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4 + def test_module_from_pipeline(self): model = DiffWave(num_res_layers=4) noise_scheduler = DDPMScheduler(timesteps=12) From 135acd83af86b02c1dfb3bdb5650d19ef10332b2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 26 Jun 2022 00:56:18 +0000 Subject: [PATCH 2/6] fix bug --- src/diffusers/schedulers/scheduling_sde_ve.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py 
b/src/diffusers/schedulers/scheduling_sde_ve.py index 652314b9..2456afad 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -24,13 +24,12 @@ from .scheduling_utils import SchedulerMixin class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): - def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, N=2, sampling_eps=1e-5, tensor_format="np"): + def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, sampling_eps=1e-5, tensor_format="np"): super().__init__() self.register_to_config( snr=snr, sigma_min=sigma_min, sigma_max=sigma_max, - N=N, sampling_eps=sampling_eps, ) @@ -54,7 +53,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): def step_pred(self, result, x, t): t = t * torch.ones(x.shape[0], device=x.device) - timestep = (t * (2 - 1)).long() + timestep = (t * (len(self.timesteps) - 1)).long() sigma = self.discrete_sigmas.to(t.device)[timestep] adjacent_sigma = torch.where( From d5c527a499cf284f6756e0a28b68e14e808dfcc9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 26 Jun 2022 11:02:57 +0000 Subject: [PATCH 3/6] clean up --- .../pipelines/pipeline_score_sde_ve.py | 55 +------------------ 1 file changed, 2 insertions(+), 53 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_score_sde_ve.py b/src/diffusers/pipelines/pipeline_score_sde_ve.py index ca759249..a1a4843a 100755 --- a/src/diffusers/pipelines/pipeline_score_sde_ve.py +++ b/src/diffusers/pipelines/pipeline_score_sde_ve.py @@ -1,14 +1,9 @@ #!/usr/bin/env python3 -import numpy as np import torch - -import PIL from diffusers import DiffusionPipeline # TODO(Patrick, Anton, Suraj) - rename `x` to better variable names - - class ScoreSdeVePipeline(DiffusionPipeline): def __init__(self, model, scheduler): super().__init__() @@ -23,7 +18,7 @@ class ScoreSdeVePipeline(DiffusionPipeline): model = self.model.to(device) - centered = False + # TODO(Patrick) move to scheduler config n_steps = 1 x = torch.randn(*shape) * self.scheduler.config.sigma_max @@ -45,50 +40,4 @@ class ScoreSdeVePipeline(DiffusionPipeline): x, x_mean = self.scheduler.step_pred(result, x, t) - x = x_mean - - if centered: - x = (x + 1.0) / 2.0 - - return x - - -# from configs.ve import ffhq_ncsnpp_continuous as configs -# from configs.ve import cifar10_ncsnpp_continuous as configs - -# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth" -# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth" -# Note usually we need to restore ema etc... 
-# ema restored checkpoint used from below - -# pipeline = ScoreSdeVePipeline.from_pretrained("/home/patrick/ffhq_ncsnpp") -# x = pipeline(num_inference_steps=2) - -# for 5 cifar10 -# x_sum = 106071.9922 -# x_mean = 34.52864456176758 - -# for 1000 cifar10 -# x_sum = 461.9700 -# x_mean = 0.1504 - -# for N=2 for 1024 -# x_sum = 3382810112.0 -# x_mean = 1075.366455078125 -# -# -# def check_x_sum_x_mean(x, x_sum, x_mean): -# assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}" -# assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" -# -# -# check_x_sum_x_mean(x, x_sum, x_mean) -# -# -# def save_image(x): -# image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8) -# image_pil = PIL.Image.fromarray(image_processed[0]) -# image_pil.save("../images/hey.png") -# -# -# save_image(x) + return x_mean From dc6d028654c7a6f1ae22728bddf4509206127ac0 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 26 Jun 2022 23:41:55 +0000 Subject: [PATCH 4/6] add vp sampler --- src/diffusers/__init__.py | 3 +- .../models/unet_sde_score_estimation.py | 2 +- src/diffusers/pipelines/__init__.py | 1 + .../pipelines/pipeline_score_sde_ve.py | 0 .../pipelines/pipeline_score_sde_vp.py | 42 ++++++++++++++ src/diffusers/schedulers/__init__.py | 1 + src/diffusers/schedulers/scheduling_sde_vp.py | 55 +++++++++++++++++++ tests/test_modeling_utils.py | 19 +++++++ 8 files changed, 121 insertions(+), 2 deletions(-) mode change 100755 => 100644 src/diffusers/pipelines/pipeline_score_sde_ve.py create mode 100644 src/diffusers/pipelines/pipeline_score_sde_vp.py create mode 100644 src/diffusers/schedulers/scheduling_sde_vp.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index d8516083..213b9a5b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -9,7 +9,7 @@ __version__ = "0.0.4" from .modeling_utils import ModelMixin from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel from .pipeline_utils import DiffusionPipeline -from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline, ScoreSdeVePipeline +from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline, ScoreSdeVePipeline, ScoreSdeVpPipeline from .schedulers import ( DDIMScheduler, DDPMScheduler, @@ -17,6 +17,7 @@ from .schedulers import ( PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler, + ScoreSdeVpScheduler, ) diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index d46782c7..784d528d 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -766,7 +766,7 @@ class NCSNpp(ModelMixin, ConfigMixin): continuous=continuous, ) self.act = act = get_act(nonlinearity) - # self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config))) + self.register_buffer('sigmas', torch.tensor(np.linspace(np.log(50), np.log(0.01), 10))) self.nf = nf self.num_res_blocks = num_res_blocks diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index b579652e..5d7b1f14 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -4,6 +4,7 @@ from .pipeline_ddim import DDIMPipeline from .pipeline_ddpm import DDPMPipeline from .pipeline_pndm import PNDMPipeline from .pipeline_score_sde_ve import ScoreSdeVePipeline +from .pipeline_score_sde_vp import ScoreSdeVpPipeline # from .pipeline_score_sde import 
ScoreSdeVePipeline diff --git a/src/diffusers/pipelines/pipeline_score_sde_ve.py b/src/diffusers/pipelines/pipeline_score_sde_ve.py old mode 100755 new mode 100644 diff --git a/src/diffusers/pipelines/pipeline_score_sde_vp.py b/src/diffusers/pipelines/pipeline_score_sde_vp.py new file mode 100644 index 00000000..9eb88629 --- /dev/null +++ b/src/diffusers/pipelines/pipeline_score_sde_vp.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +import torch +from diffusers import DiffusionPipeline + + +# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names +class ScoreSdeVpPipeline(DiffusionPipeline): + def __init__(self, model, scheduler): + super().__init__() + self.register_modules(model=model, scheduler=scheduler) + + def __call__(self, num_inference_steps=1000, generator=None): + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + img_size = self.model.config.image_size + channels = self.model.config.num_channels + shape = (1, channels, img_size, img_size) + + beta_min, beta_max = 0.1, 20 + + model = self.model.to(device) + + x = torch.randn(*shape).to(device) + + self.scheduler.set_timesteps(num_inference_steps) + + for i, t in enumerate(self.scheduler.timesteps): + t = t * torch.ones(shape[0], device=device) + sigma_t = t * (num_inference_steps - 1) + + with torch.no_grad(): + result = model(x, sigma_t) + + log_mean_coeff = -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min + std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff)) + result = -result / std[:, None, None, None] + + x, x_mean = self.scheduler.step_pred(result, x, t) + + x_mean = (x_mean + 1.) / 2. + + return x_mean diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 36bc441b..6a6d6286 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -22,3 +22,4 @@ from .scheduling_grad_tts import GradTTSScheduler from .scheduling_pndm import PNDMScheduler from .scheduling_utils import SchedulerMixin from .scheduling_sde_ve import ScoreSdeVeScheduler +from .scheduling_sde_vp import ScoreSdeVpScheduler diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py new file mode 100644 index 00000000..c7b64971 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -0,0 +1,55 @@ +# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch + +# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin +from .scheduling_utils import SchedulerMixin + + +class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): + def __init__(self, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"): + super().__init__() + self.register_to_config( + beta_min=beta_min, + beta_max=beta_max, + sampling_eps=sampling_eps, + ) + + self.sigmas = None + self.discrete_sigmas = None + self.timesteps = None + + def set_timesteps(self, num_inference_steps): + self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) + + def step_pred(self, result, x, t): + dt = -1. / len(self.timesteps) + z = torch.randn_like(x) + + beta_t = self.beta_min + t * (self.beta_max - self.beta_min) + drift = -0.5 * beta_t[:, None, None, None] * x + diffusion = torch.sqrt(beta_t) + + drift = drift - diffusion[:, None, None, None] ** 2 * result + + x_mean = x + drift * dt + x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * z + + return x, x_mean diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 15547afb..32bc3003 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -38,6 +38,8 @@ from diffusers import ( PNDMScheduler, ScoreSdeVePipeline, ScoreSdeVeScheduler, + ScoreSdeVpPipeline, + ScoreSdeVpScheduler, UNetGradTTSModel, UNetLDMModel, UNetModel, @@ -741,6 +743,23 @@ class PipelineTesterMixin(unittest.TestCase): assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2 assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4 + @slow + def test_score_sde_vp_pipeline(self): + + model = NCSNpp.from_pretrained("/home/patrick/cifar10-ddpmpp-vp") + scheduler = ScoreSdeVpScheduler() + + sde_vp = ScoreSdeVpPipeline(model=model, scheduler=scheduler) + + torch.manual_seed(0) + image = sde_vp(num_inference_steps=10) + + expected_image_sum = 4183.2012 + expected_image_mean = 1.3617 + + assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2 + assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4 + def test_module_from_pipeline(self): model = DiffWave(num_res_layers=4) noise_scheduler = DDPMScheduler(timesteps=12) From ba264419f40b94fd2e8135096db4780e1c188aef Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 27 Jun 2022 00:07:57 +0000 Subject: [PATCH 5/6] finish vp --- .../models/unet_sde_score_estimation.py | 5 ++--- .../pipelines/pipeline_score_sde_ve.py | 1 + .../pipelines/pipeline_score_sde_vp.py | 15 +++++---------- src/diffusers/schedulers/__init__.py | 2 +- src/diffusers/schedulers/scheduling_sde_ve.py | 2 ++ src/diffusers/schedulers/scheduling_sde_vp.py | 19 ++++++++++++++----- tests/test_modeling_utils.py | 4 ++-- 7 files changed, 27 insertions(+), 21 deletions(-) diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 784d528d..299f96c9 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -766,7 +766,6 @@ class NCSNpp(ModelMixin, ConfigMixin): continuous=continuous, ) self.act = act = get_act(nonlinearity) - self.register_buffer('sigmas', torch.tensor(np.linspace(np.log(50), np.log(0.01), 10))) self.nf = nf self.num_res_blocks = num_res_blocks @@ -939,7 +938,7 
@@ class NCSNpp(ModelMixin, ConfigMixin): self.all_modules = nn.ModuleList(modules) - def forward(self, x, time_cond): + def forward(self, x, time_cond, sigmas=None): # timestep/noise_level embedding; only for continuous training modules = self.all_modules m_idx = 0 @@ -952,7 +951,7 @@ class NCSNpp(ModelMixin, ConfigMixin): elif self.embedding_type == "positional": # Sinusoidal positional embeddings. timesteps = time_cond - used_sigmas = self.sigmas[time_cond.long()] + used_sigmas = sigmas temb = get_timestep_embedding(timesteps, self.nf) else: diff --git a/src/diffusers/pipelines/pipeline_score_sde_ve.py b/src/diffusers/pipelines/pipeline_score_sde_ve.py index a1a4843a..1dfd304d 100644 --- a/src/diffusers/pipelines/pipeline_score_sde_ve.py +++ b/src/diffusers/pipelines/pipeline_score_sde_ve.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 import torch + from diffusers import DiffusionPipeline diff --git a/src/diffusers/pipelines/pipeline_score_sde_vp.py b/src/diffusers/pipelines/pipeline_score_sde_vp.py index 9eb88629..29551d9a 100644 --- a/src/diffusers/pipelines/pipeline_score_sde_vp.py +++ b/src/diffusers/pipelines/pipeline_score_sde_vp.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 import torch + from diffusers import DiffusionPipeline @@ -16,27 +17,21 @@ class ScoreSdeVpPipeline(DiffusionPipeline): channels = self.model.config.num_channels shape = (1, channels, img_size, img_size) - beta_min, beta_max = 0.1, 20 - model = self.model.to(device) x = torch.randn(*shape).to(device) self.scheduler.set_timesteps(num_inference_steps) - for i, t in enumerate(self.scheduler.timesteps): + for t in self.scheduler.timesteps: t = t * torch.ones(shape[0], device=device) - sigma_t = t * (num_inference_steps - 1) + scaled_t = t * (num_inference_steps - 1) with torch.no_grad(): - result = model(x, sigma_t) - - log_mean_coeff = -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min - std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff)) - result = -result / std[:, None, None, None] + result = model(x, scaled_t) x, x_mean = self.scheduler.step_pred(result, x, t) - x_mean = (x_mean + 1.) / 2. 
+ x_mean = (x_mean + 1.0) / 2.0 return x_mean diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 6a6d6286..ad66fe59 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -20,6 +20,6 @@ from .scheduling_ddim import DDIMScheduler from .scheduling_ddpm import DDPMScheduler from .scheduling_grad_tts import GradTTSScheduler from .scheduling_pndm import PNDMScheduler -from .scheduling_utils import SchedulerMixin from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_sde_vp import ScoreSdeVpScheduler +from .scheduling_utils import SchedulerMixin diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 2456afad..79936105 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -52,6 +52,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ) def step_pred(self, result, x, t): + # TODO(Patrick) better comments + non-PyTorch t = t * torch.ones(x.shape[0], device=x.device) timestep = (t * (len(self.timesteps) - 1)).long() @@ -70,6 +71,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): return x, x_mean def step_correct(self, result, x): + # TODO(Patrick) better comments + non-PyTorch noise = torch.randn_like(x) grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean() noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index c7b64971..dda32a27 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -40,16 +40,25 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) def step_pred(self, result, x, t): - dt = -1. 
/ len(self.timesteps) - z = torch.randn_like(x) + # TODO(Patrick) better comments + non-PyTorch + # postprocess model result + log_mean_coeff = ( + -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min + ) + std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff)) + result = -result / std[:, None, None, None] - beta_t = self.beta_min + t * (self.beta_max - self.beta_min) + # compute + dt = -1.0 / len(self.timesteps) + + beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min) drift = -0.5 * beta_t[:, None, None, None] * x diffusion = torch.sqrt(beta_t) - drift = drift - diffusion[:, None, None, None] ** 2 * result - x_mean = x + drift * dt + + # add noise + z = torch.randn_like(x) x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * z return x, x_mean diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 32bc3003..6c5c115f 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -746,8 +746,8 @@ class PipelineTesterMixin(unittest.TestCase): @slow def test_score_sde_vp_pipeline(self): - model = NCSNpp.from_pretrained("/home/patrick/cifar10-ddpmpp-vp") - scheduler = ScoreSdeVpScheduler() + model = NCSNpp.from_pretrained("fusing/cifar10-ddpmpp-vp") + scheduler = ScoreSdeVpScheduler.from_config("fusing/cifar10-ddpmpp-vp") sde_vp = ScoreSdeVpPipeline(model=model, scheduler=scheduler) From 9a4d53a4762e6b4c8766f66fcb02f78b99f170b5 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 27 Jun 2022 02:09:49 +0200 Subject: [PATCH 6/6] Update README.md --- README.md | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/README.md b/README.md index bee5d880..7f2704e5 100644 --- a/README.md +++ b/README.md @@ -234,6 +234,7 @@ See [paper](https://arxiv.org/abs/2011.13456) for more information on SDE VE. from diffusers import DiffusionPipeline import torch import PIL.Image +import numpy as np torch.manual_seed(32) @@ -249,6 +250,31 @@ image_pil = PIL.Image.fromarray(image[0]) # save image image_pil.save("test.png") ``` +#### **Example 32x32 image generation with SDE VP** + +See [paper](https://arxiv.org/abs/2011.13456) for more information on SDE VE. + +```python +from diffusers import DiffusionPipeline +import torch +import PIL.Image +import numpy as np + +torch.manual_seed(32) + +score_sde_sv = DiffusionPipeline.from_pretrained("fusing/cifar10-ddpmpp-deep-vp") + +# Note this might take up to 3 minutes on a GPU +image = score_sde_sv(num_inference_steps=1000) + +image = image.permute(0, 2, 3, 1).cpu().numpy() +image = np.clip(image * 255, 0, 255).astype(np.uint8) +image_pil = PIL.Image.fromarray(image[0]) + +# save image +image_pil.save("test.png") +``` + #### **Text to Image generation with Latent Diffusion**
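
For readers skimming the diffs above, the noise schedule introduced by `ScoreSdeVeScheduler.set_timesteps` and `set_sigmas` is easy to reproduce in isolation: the continuous noise levels follow sigma(t) = sigma_min * (sigma_max / sigma_min) ** t with t running from 1 down to `sampling_eps`, while `discrete_sigmas` is a geometric ladder that `step_pred` indexes into for adjacent noise levels. Below is a minimal sketch in plain `torch`/`numpy`; `sigma_max=50` is chosen only for illustration (the scheduler default in this patch series is 1348, matching the FFHQ 1024x1024 checkpoint).

```python
import numpy as np
import torch

# illustrative values; the patched scheduler defaults to sigma_min=0.01, sigma_max=1348
sigma_min, sigma_max, sampling_eps = 0.01, 50.0, 1e-5
num_inference_steps = 5

# what set_timesteps builds: t runs from 1 down to sampling_eps
timesteps = torch.linspace(1, sampling_eps, num_inference_steps)

# what set_sigmas builds: sigma(t) = sigma_min * (sigma_max / sigma_min) ** t
sigmas = sigma_min * (sigma_max / sigma_min) ** timesteps

# ...and the geometric ladder that step_pred uses to look up adjacent noise levels
discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))

print(sigmas)           # approx. [50.0, 5.95, 0.71, 0.084, 0.01]
print(discrete_sigmas)  # approx. [0.01, 0.084, 0.71, 5.95, 50.0]
```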
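
The predictor-corrector loop that `ScoreSdeVePipeline.__call__` runs can also be driven by hand against the scheduler API added here. The sketch below assumes the patched `ScoreSdeVeScheduler`; `DummyScore` is a hypothetical stand-in for the `NCSNpp` score model so the loop runs without downloading a checkpoint, and `sigma_max=50` is again only illustrative.

```python
import torch

from diffusers import ScoreSdeVeScheduler


class DummyScore(torch.nn.Module):
    # hypothetical stand-in for NCSNpp: maps (x, sigma_t) to a score-shaped tensor
    def forward(self, x, sigma_t):
        return -x / sigma_t[:, None, None, None] ** 2


model = DummyScore()
scheduler = ScoreSdeVeScheduler(snr=0.15, sigma_min=0.01, sigma_max=50)

num_inference_steps = 10
shape = (1, 3, 32, 32)

# start from pure noise at the largest noise level
x = torch.randn(*shape) * scheduler.config.sigma_max

scheduler.set_timesteps(num_inference_steps)
scheduler.set_sigmas(num_inference_steps)

for i, t in enumerate(scheduler.timesteps):
    sigma_t = scheduler.sigmas[i] * torch.ones(shape[0])

    # corrector: a Langevin step at the current noise level
    with torch.no_grad():
        score = model(x, sigma_t)
    x = scheduler.step_correct(score, x)

    # predictor: reverse diffusion step towards the next, smaller noise level
    with torch.no_grad():
        score = model(x, sigma_t)
    x, x_mean = scheduler.step_pred(score, x, t)

# as in the pipeline, the denoised mean of the last step is taken as the sample
sample = x_mean
```

Returning `x_mean` rather than `x` mirrors the pipeline: the final predictor step's mean drops the last injected noise, which is the usual convention for SDE samplers.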
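
The variance-preserving predictor in `ScoreSdeVpScheduler.step_pred` (after the "finish vp" commit) is a single Euler-Maruyama step of the reverse-time SDE: the network output is rescaled into a score by the marginal standard deviation, then combined with the VP drift and diffusion terms. The sketch below writes the same update as a free function so the per-step math is visible in one place; `vp_predictor_step` is a hypothetical helper for exposition, not part of the patch, and `beta_min`/`beta_max` mirror the scheduler defaults of 0.1 and 20.

```python
import torch


def vp_predictor_step(model_out, x, t, num_inference_steps, beta_min=0.1, beta_max=20.0):
    # marginal std of the VP forward process at time t
    log_mean_coeff = -0.25 * t**2 * (beta_max - beta_min) - 0.5 * t * beta_min
    std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))

    # rescale the network output into a score estimate
    score = -model_out / std[:, None, None, None]

    # Euler-Maruyama step of the reverse-time SDE, integrating t from 1 towards 0
    dt = -1.0 / num_inference_steps
    beta_t = beta_min + t * (beta_max - beta_min)
    drift = -0.5 * beta_t[:, None, None, None] * x - beta_t[:, None, None, None] * score
    diffusion = torch.sqrt(beta_t)

    x_mean = x + drift * dt
    x = x_mean + diffusion[:, None, None, None] * (-dt) ** 0.5 * torch.randn_like(x)
    return x, x_mean


# one step at t = 1.0 with random stand-ins for the sample and the model output
x = torch.randn(1, 3, 32, 32)
t = torch.ones(1)
x, x_mean = vp_predictor_step(torch.randn_like(x), x, t, num_inference_steps=1000)
```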