diff --git a/README.md b/README.md index 6c2c9799..7f2704e5 100644 --- a/README.md +++ b/README.md @@ -226,6 +226,56 @@ image_pil = PIL.Image.fromarray(image_processed[0]) image_pil.save("test.png") ``` +#### **Example 1024x1024 image generation with SDE VE** + +See [paper](https://arxiv.org/abs/2011.13456) for more information on SDE VE. + +```python +from diffusers import DiffusionPipeline +import torch +import PIL.Image +import numpy as np + +torch.manual_seed(32) + +score_sde_sv = DiffusionPipeline.from_pretrained("fusing/ffhq_ncsnpp") + +# Note this might take up to 3 minutes on a GPU +image = score_sde_sv(num_inference_steps=2000) + +image = image.permute(0, 2, 3, 1).cpu().numpy() +image = np.clip(image * 255, 0, 255).astype(np.uint8) +image_pil = PIL.Image.fromarray(image[0]) + +# save image +image_pil.save("test.png") +``` +#### **Example 32x32 image generation with SDE VP** + +See [paper](https://arxiv.org/abs/2011.13456) for more information on SDE VE. + +```python +from diffusers import DiffusionPipeline +import torch +import PIL.Image +import numpy as np + +torch.manual_seed(32) + +score_sde_sv = DiffusionPipeline.from_pretrained("fusing/cifar10-ddpmpp-deep-vp") + +# Note this might take up to 3 minutes on a GPU +image = score_sde_sv(num_inference_steps=1000) + +image = image.permute(0, 2, 3, 1).cpu().numpy() +image = np.clip(image * 255, 0, 255).astype(np.uint8) +image_pil = PIL.Image.fromarray(image[0]) + +# save image +image_pil.save("test.png") +``` + + #### **Text to Image generation with Latent Diffusion** _Note: To use latent diffusion install transformers from [this branch](https://github.com/patil-suraj/transformers/tree/ldm-bert)._ diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index ac68a6c3..213b9a5b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -9,8 +9,16 @@ __version__ = "0.0.4" from .modeling_utils import ModelMixin from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel from .pipeline_utils import DiffusionPipeline -from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline -from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin, VeSdeScheduler +from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline, ScoreSdeVePipeline, ScoreSdeVpPipeline +from .schedulers import ( + DDIMScheduler, + DDPMScheduler, + GradTTSScheduler, + PNDMScheduler, + SchedulerMixin, + ScoreSdeVeScheduler, + ScoreSdeVpScheduler, +) if is_transformers_available(): diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index d46782c7..299f96c9 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -766,7 +766,6 @@ class NCSNpp(ModelMixin, ConfigMixin): continuous=continuous, ) self.act = act = get_act(nonlinearity) - # self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config))) self.nf = nf self.num_res_blocks = num_res_blocks @@ -939,7 +938,7 @@ class NCSNpp(ModelMixin, ConfigMixin): self.all_modules = nn.ModuleList(modules) - def forward(self, x, time_cond): + def forward(self, x, time_cond, sigmas=None): # timestep/noise_level embedding; only for continuous training modules = self.all_modules m_idx = 0 @@ -952,7 +951,7 @@ class NCSNpp(ModelMixin, ConfigMixin): elif self.embedding_type == "positional": # Sinusoidal positional embeddings. timesteps = time_cond - used_sigmas = self.sigmas[time_cond.long()] + used_sigmas = sigmas temb = get_timestep_embedding(timesteps, self.nf) else: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index e724149a..5d7b1f14 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -3,9 +3,11 @@ from .pipeline_bddm import BDDMPipeline from .pipeline_ddim import DDIMPipeline from .pipeline_ddpm import DDPMPipeline from .pipeline_pndm import PNDMPipeline +from .pipeline_score_sde_ve import ScoreSdeVePipeline +from .pipeline_score_sde_vp import ScoreSdeVpPipeline -# from .pipeline_score_sde import NCSNppPipeline +# from .pipeline_score_sde import ScoreSdeVePipeline if is_transformers_available(): diff --git a/src/diffusers/pipelines/pipeline_score_sde.py b/src/diffusers/pipelines/pipeline_score_sde.py deleted file mode 100755 index 5b3cb5bc..00000000 --- a/src/diffusers/pipelines/pipeline_score_sde.py +++ /dev/null @@ -1,94 +0,0 @@ -#!/usr/bin/env python3 -import numpy as np -import torch - -import PIL -from diffusers import DiffusionPipeline - - -# from configs.ve import ffhq_ncsnpp_continuous as configs -# from configs.ve import cifar10_ncsnpp_continuous as configs - -# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth" -# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth" -# Note usually we need to restore ema etc... -# ema restored checkpoint used from below -torch.backends.cuda.matmul.allow_tf32 = False -torch.manual_seed(0) - - -class NCSNppPipeline(DiffusionPipeline): - def __init__(self, model, scheduler): - super().__init__() - self.register_modules(model=model, scheduler=scheduler) - - def __call__(self, generator=None): - N = self.scheduler.config.N - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - - img_size = self.model.config.image_size - channels = self.model.config.num_channels - shape = (1, channels, img_size, img_size) - - model = torch.nn.DataParallel(self.model.to(device)) - - centered = False - n_steps = 1 - - # Initial sample - x = torch.randn(*shape) * self.scheduler.config.sigma_max - x = x.to(device) - - for i in range(N): - sigma_t = self.scheduler.get_sigma_t(i) * torch.ones(shape[0], device=device) - - for _ in range(n_steps): - with torch.no_grad(): - result = model(x, sigma_t) - x = self.scheduler.step_correct(result, x) - - with torch.no_grad(): - result = model(x, sigma_t) - - x, x_mean = self.scheduler.step_pred(result, x, i) - - x = x_mean - - if centered: - x = (x + 1.0) / 2.0 - - return x - - -pipeline = NCSNppPipeline.from_pretrained("/home/patrick/ffhq_ncsnpp") -x = pipeline() - - -# for 5 cifar10 -# x_sum = 106071.9922 -# x_mean = 34.52864456176758 - -# for 1000 cifar10 -# x_sum = 461.9700 -# x_mean = 0.1504 - -# for N=2 for 1024 -x_sum = 3382810112.0 -x_mean = 1075.366455078125 - - -def check_x_sum_x_mean(x, x_sum, x_mean): - assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}" - assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" - - -check_x_sum_x_mean(x, x_sum, x_mean) - - -def save_image(x): - image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8) - image_pil = PIL.Image.fromarray(image_processed[0]) - image_pil.save("../images/hey.png") - - -# save_image(x) diff --git a/src/diffusers/pipelines/pipeline_score_sde_ve.py b/src/diffusers/pipelines/pipeline_score_sde_ve.py new file mode 100644 index 00000000..1dfd304d --- /dev/null +++ b/src/diffusers/pipelines/pipeline_score_sde_ve.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +import torch + +from diffusers import DiffusionPipeline + + +# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names +class ScoreSdeVePipeline(DiffusionPipeline): + def __init__(self, model, scheduler): + super().__init__() + self.register_modules(model=model, scheduler=scheduler) + + def __call__(self, num_inference_steps=2000, generator=None): + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + img_size = self.model.config.image_size + channels = self.model.config.num_channels + shape = (1, channels, img_size, img_size) + + model = self.model.to(device) + + # TODO(Patrick) move to scheduler config + n_steps = 1 + + x = torch.randn(*shape) * self.scheduler.config.sigma_max + x = x.to(device) + + self.scheduler.set_timesteps(num_inference_steps) + self.scheduler.set_sigmas(num_inference_steps) + + for i, t in enumerate(self.scheduler.timesteps): + sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device) + + for _ in range(n_steps): + with torch.no_grad(): + result = self.model(x, sigma_t) + x = self.scheduler.step_correct(result, x) + + with torch.no_grad(): + result = model(x, sigma_t) + + x, x_mean = self.scheduler.step_pred(result, x, t) + + return x_mean diff --git a/src/diffusers/pipelines/pipeline_score_sde_vp.py b/src/diffusers/pipelines/pipeline_score_sde_vp.py new file mode 100644 index 00000000..29551d9a --- /dev/null +++ b/src/diffusers/pipelines/pipeline_score_sde_vp.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +import torch + +from diffusers import DiffusionPipeline + + +# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names +class ScoreSdeVpPipeline(DiffusionPipeline): + def __init__(self, model, scheduler): + super().__init__() + self.register_modules(model=model, scheduler=scheduler) + + def __call__(self, num_inference_steps=1000, generator=None): + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + img_size = self.model.config.image_size + channels = self.model.config.num_channels + shape = (1, channels, img_size, img_size) + + model = self.model.to(device) + + x = torch.randn(*shape).to(device) + + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.scheduler.timesteps: + t = t * torch.ones(shape[0], device=device) + scaled_t = t * (num_inference_steps - 1) + + with torch.no_grad(): + result = model(x, scaled_t) + + x, x_mean = self.scheduler.step_pred(result, x, t) + + x_mean = (x_mean + 1.0) / 2.0 + + return x_mean diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index ea306266..ad66fe59 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -20,5 +20,6 @@ from .scheduling_ddim import DDIMScheduler from .scheduling_ddpm import DDPMScheduler from .scheduling_grad_tts import GradTTSScheduler from .scheduling_pndm import PNDMScheduler +from .scheduling_sde_ve import ScoreSdeVeScheduler +from .scheduling_sde_vp import ScoreSdeVpScheduler from .scheduling_utils import SchedulerMixin -from .scheduling_ve_sde import VeSdeScheduler diff --git a/src/diffusers/schedulers/scheduling_ve_sde.py b/src/diffusers/schedulers/scheduling_sde_ve.py similarity index 59% rename from src/diffusers/schedulers/scheduling_ve_sde.py rename to src/diffusers/schedulers/scheduling_sde_ve.py index 6f188272..79936105 100644 --- a/src/diffusers/schedulers/scheduling_ve_sde.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -1,4 +1,4 @@ -# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved. +# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim +# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch + +# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit import numpy as np import torch @@ -21,34 +23,42 @@ from ..configuration_utils import ConfigMixin from .scheduling_utils import SchedulerMixin -class VeSdeScheduler(SchedulerMixin, ConfigMixin): - def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, N=2, sampling_eps=1e-5, tensor_format="np"): +class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): + def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, sampling_eps=1e-5, tensor_format="np"): super().__init__() self.register_to_config( snr=snr, sigma_min=sigma_min, sigma_max=sigma_max, - N=N, sampling_eps=sampling_eps, ) - # (PVP) - clean up with .config. - self.sigma_min = sigma_min - self.sigma_max = sigma_max - self.snr = snr - self.N = N - self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N)) - self.timesteps = torch.linspace(1, sampling_eps, N) - def get_sigma_t(self, t): - return self.sigma_min * (self.sigma_max / self.sigma_min) ** self.timesteps[t] + self.sigmas = None + self.discrete_sigmas = None + self.timesteps = None + + def set_timesteps(self, num_inference_steps): + self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) + + def set_sigmas(self, num_inference_steps): + if self.timesteps is None: + self.set_timesteps(num_inference_steps) + + self.discrete_sigmas = torch.exp( + torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps) + ) + self.sigmas = torch.tensor( + [self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps] + ) def step_pred(self, result, x, t): - t = self.timesteps[t] * torch.ones(x.shape[0], device=x.device) + # TODO(Patrick) better comments + non-PyTorch + t = t * torch.ones(x.shape[0], device=x.device) + timestep = (t * (len(self.timesteps) - 1)).long() - timestep = (t * (self.N - 1)).long() sigma = self.discrete_sigmas.to(t.device)[timestep] adjacent_sigma = torch.where( - timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(t.device) + timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(timestep.device) ) f = torch.zeros_like(x) G = torch.sqrt(sigma**2 - adjacent_sigma**2) @@ -61,10 +71,11 @@ class VeSdeScheduler(SchedulerMixin, ConfigMixin): return x, x_mean def step_correct(self, result, x): + # TODO(Patrick) better comments + non-PyTorch noise = torch.randn_like(x) grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean() noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() - step_size = (self.snr * noise_norm / grad_norm) ** 2 * 2 + step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2 step_size = step_size * torch.ones(x.shape[0], device=x.device) x_mean = x + step_size[:, None, None, None] * result diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py new file mode 100644 index 00000000..dda32a27 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -0,0 +1,64 @@ +# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch + +# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin +from .scheduling_utils import SchedulerMixin + + +class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): + def __init__(self, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"): + super().__init__() + self.register_to_config( + beta_min=beta_min, + beta_max=beta_max, + sampling_eps=sampling_eps, + ) + + self.sigmas = None + self.discrete_sigmas = None + self.timesteps = None + + def set_timesteps(self, num_inference_steps): + self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) + + def step_pred(self, result, x, t): + # TODO(Patrick) better comments + non-PyTorch + # postprocess model result + log_mean_coeff = ( + -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min + ) + std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff)) + result = -result / std[:, None, None, None] + + # compute + dt = -1.0 / len(self.timesteps) + + beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min) + drift = -0.5 * beta_t[:, None, None, None] * x + diffusion = torch.sqrt(beta_t) + drift = drift - diffusion[:, None, None, None] ** 2 * result + x_mean = x + drift * dt + + # add noise + z = torch.randn_like(x) + x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * z + + return x, x_mean diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index db4ed6eb..6c5c115f 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -33,8 +33,13 @@ from diffusers import ( GradTTSPipeline, GradTTSScheduler, LatentDiffusionPipeline, + NCSNpp, PNDMPipeline, PNDMScheduler, + ScoreSdeVePipeline, + ScoreSdeVeScheduler, + ScoreSdeVpPipeline, + ScoreSdeVpScheduler, UNetGradTTSModel, UNetLDMModel, UNetModel, @@ -721,6 +726,40 @@ class PipelineTesterMixin(unittest.TestCase): ) assert (mel_spec[0, :3, :3].cpu().flatten() - expected_slice).abs().max() < 1e-2 + @slow + def test_score_sde_ve_pipeline(self): + torch.manual_seed(0) + + model = NCSNpp.from_pretrained("fusing/ffhq_ncsnpp") + scheduler = ScoreSdeVeScheduler.from_config("fusing/ffhq_ncsnpp") + + sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler) + + image = sde_ve(num_inference_steps=2) + + expected_image_sum = 3382810112.0 + expected_image_mean = 1075.366455078125 + + assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2 + assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4 + + @slow + def test_score_sde_vp_pipeline(self): + + model = NCSNpp.from_pretrained("fusing/cifar10-ddpmpp-vp") + scheduler = ScoreSdeVpScheduler.from_config("fusing/cifar10-ddpmpp-vp") + + sde_vp = ScoreSdeVpPipeline(model=model, scheduler=scheduler) + + torch.manual_seed(0) + image = sde_vp(num_inference_steps=10) + + expected_image_sum = 4183.2012 + expected_image_mean = 1.3617 + + assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2 + assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4 + def test_module_from_pipeline(self): model = DiffWave(num_res_layers=4) noise_scheduler = DDPMScheduler(timesteps=12)