From dd10da76a78e9566d12ddf1eb5aac90021b7e51d Mon Sep 17 00:00:00 2001 From: Anton Lozhkov Date: Tue, 9 Aug 2022 15:58:30 +0200 Subject: [PATCH] Add an alternative Karras et al. stochastic scheduler for VE models (#160) * karras + VE, not flexible yet * Fix inputs incompatibility with the original unet * Roll back sigma scaling * Apply suggestions from code review * Old comment * Fix doc --- src/diffusers/__init__.py | 11 +- src/diffusers/pipelines/__init__.py | 1 + .../pipelines/stochatic_karras_ve/__init__.py | 1 + .../pipeline_stochastic_karras_ve.py | 80 +++++++++++ src/diffusers/schedulers/__init__.py | 1 + src/diffusers/schedulers/scheduling_ddpm.py | 2 +- .../schedulers/scheduling_karras_ve.py | 127 ++++++++++++++++++ tests/test_modeling_utils.py | 18 +++ 8 files changed, 238 insertions(+), 3 deletions(-) create mode 100644 src/diffusers/pipelines/stochatic_karras_ve/__init__.py create mode 100644 src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py create mode 100644 src/diffusers/schedulers/scheduling_karras_ve.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 32af42b5..f8313509 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -18,8 +18,15 @@ from .optimization import ( get_scheduler, ) from .pipeline_utils import DiffusionPipeline -from .pipelines import DDIMPipeline, DDPMPipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline -from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler +from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline +from .schedulers import ( + DDIMScheduler, + DDPMScheduler, + KarrasVeScheduler, + PNDMScheduler, + SchedulerMixin, + ScoreSdeVeScheduler, +) from .training_utils import EMAModel diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 50855568..c1b2068a 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -4,6 +4,7 @@ from .ddpm import DDPMPipeline from .latent_diffusion_uncond import LDMPipeline from .pndm import PNDMPipeline from .score_sde_ve import ScoreSdeVePipeline +from .stochatic_karras_ve import KarrasVePipeline if is_transformers_available(): diff --git a/src/diffusers/pipelines/stochatic_karras_ve/__init__.py b/src/diffusers/pipelines/stochatic_karras_ve/__init__.py new file mode 100644 index 00000000..5a63c1d2 --- /dev/null +++ b/src/diffusers/pipelines/stochatic_karras_ve/__init__.py @@ -0,0 +1 @@ +from .pipeline_stochastic_karras_ve import KarrasVePipeline diff --git a/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py b/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py new file mode 100644 index 00000000..27cb6a0e --- /dev/null +++ b/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python3 +import torch + +from tqdm.auto import tqdm + +from ...models import UNet2DModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import KarrasVeScheduler + + +class KarrasVePipeline(DiffusionPipeline): + """ + Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. + Use Algorithm 2 and the VE column of Table 1 from [1] for reference. + + [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364 + [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456 + """ + + unet: UNet2DModel + scheduler: KarrasVeScheduler + + def __init__(self, unet, scheduler): + super().__init__() + scheduler = scheduler.set_format("pt") + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__(self, batch_size=1, num_inference_steps=50, generator=None, torch_device=None, output_type="pil"): + if torch_device is None: + torch_device = "cuda" if torch.cuda.is_available() else "cpu" + + img_size = self.unet.config.sample_size + shape = (batch_size, 3, img_size, img_size) + + model = self.unet.to(torch_device) + + # sample x_0 ~ N(0, sigma_0^2 * I) + sample = torch.randn(*shape) * self.scheduler.config.sigma_max + sample = sample.to(torch_device) + + self.scheduler.set_timesteps(num_inference_steps) + + for t in tqdm(self.scheduler.timesteps): + # here sigma_t == t_i from the paper + sigma = self.scheduler.schedule[t] + sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0 + + # 1. Select temporarily increased noise level sigma_hat + # 2. Add new noise to move from sample_i to sample_hat + sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator) + + # 3. Predict the noise residual given the noise magnitude `sigma_hat` + # The model inputs and output are adjusted by following eq. (213) in [1]. + model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2)["sample"] + + # 4. Evaluate dx/dt at sigma_hat + # 5. Take Euler step from sigma to sigma_prev + step_output = self.scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat) + + if sigma_prev != 0: + # 6. Apply 2nd order correction + # The model inputs and output are adjusted by following eq. (213) in [1]. + model_output = (sigma_prev / 2) * model((step_output["prev_sample"] + 1) / 2, sigma_prev / 2)["sample"] + step_output = self.scheduler.step_correct( + model_output, + sigma_hat, + sigma_prev, + sample_hat, + step_output["prev_sample"], + step_output["derivative"], + ) + sample = step_output["prev_sample"] + + sample = (sample / 2 + 0.5).clamp(0, 1) + sample = sample.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + sample = self.numpy_to_pil(sample) + + return {"sample": sample} diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 57a5c994..42a536aa 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -18,6 +18,7 @@ from .scheduling_ddim import DDIMScheduler from .scheduling_ddpm import DDPMScheduler +from .scheduling_karras_ve import KarrasVeScheduler from .scheduling_pndm import PNDMScheduler from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_sde_vp import ScoreSdeVpScheduler diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 3e20d706..9783f9a1 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -134,7 +134,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): generator=None, ): t = timestep - + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) else: diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py new file mode 100644 index 00000000..9741189c --- /dev/null +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -0,0 +1,127 @@ +# Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin + + +class KarrasVeScheduler(SchedulerMixin, ConfigMixin): + """ + Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. + Use Algorithm 2 and the VE column of Table 1 from [1] for reference. + + [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364 + [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456 + """ + + @register_to_config + def __init__( + self, + sigma_min=0.02, + sigma_max=100, + s_noise=1.007, + s_churn=80, + s_min=0.05, + s_max=50, + tensor_format="pt", + ): + """ + For more details on the parameters, see the original paper's Appendix E.: + "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. + The grid search values used to find the optimal {s_noise, s_churn, s_min, s_max} for a specific model + are described in Table 5 of the paper. + + Args: + sigma_min (`float`): minimum noise magnitude + sigma_max (`float`): maximum noise magnitude + s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. + A reasonable range is [1.000, 1.011]. + s_churn (`float`): the parameter controlling the overall amount of stochasticity. + A reasonable range is [0, 100]. + s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity). + A reasonable range is [0, 10]. + s_max (`float`): the end value of the sigma range where we add noise. + A reasonable range is [0.2, 80]. + """ + # setable values + self.num_inference_steps = None + self.timesteps = None + self.schedule = None # sigma(t_i) + + self.tensor_format = tensor_format + self.set_format(tensor_format=tensor_format) + + def set_timesteps(self, num_inference_steps): + self.num_inference_steps = num_inference_steps + self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() + self.schedule = [ + (self.sigma_max * (self.sigma_min**2 / self.sigma_max**2) ** (i / (num_inference_steps - 1))) + for i in self.timesteps + ] + self.schedule = np.array(self.schedule, dtype=np.float32) + + self.set_format(tensor_format=self.tensor_format) + + def add_noise_to_input(self, sample, sigma, generator=None): + """ + Explicit Langevin-like "churn" step of adding noise to the sample according to + a factor gamma_i ≥ 0 to reach a higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. + """ + if self.s_min <= sigma <= self.s_max: + gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1) + else: + gamma = 0 + + # sample eps ~ N(0, S_noise^2 * I) + eps = self.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device) + sigma_hat = sigma + gamma * sigma + sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps) + + return sample_hat, sigma_hat + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + sigma_hat: float, + sigma_prev: float, + sample_hat: Union[torch.FloatTensor, np.ndarray], + ): + pred_original_sample = sample_hat + sigma_hat * model_output + derivative = (sample_hat - pred_original_sample) / sigma_hat + sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative + + return {"prev_sample": sample_prev, "derivative": derivative} + + def step_correct( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + sigma_hat: float, + sigma_prev: float, + sample_hat: Union[torch.FloatTensor, np.ndarray], + sample_prev: Union[torch.FloatTensor, np.ndarray], + derivative: Union[torch.FloatTensor, np.ndarray], + ): + pred_original_sample = sample_prev + sigma_prev * model_output + derivative_corr = (sample_prev - pred_original_sample) / sigma_prev + sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr) + return {"prev_sample": sample_prev, "derivative": derivative_corr} + + def add_noise(self, original_samples, noise, timesteps): + raise NotImplementedError() diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 9d783e60..072109e8 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -29,6 +29,8 @@ from diffusers import ( DDIMScheduler, DDPMPipeline, DDPMScheduler, + KarrasVePipeline, + KarrasVeScheduler, LDMPipeline, LDMTextToImagePipeline, PNDMPipeline, @@ -909,3 +911,19 @@ class PipelineTesterMixin(unittest.TestCase): # the values aren't exactly equal, but the images look the same visually assert np.abs(ddpm_images - ddim_images).max() < 1e-1 + + @slow + def test_karras_ve_pipeline(self): + model_id = "google/ncsnpp-celebahq-256" + model = UNet2DModel.from_pretrained(model_id) + scheduler = KarrasVeScheduler(tensor_format="pt") + + pipe = KarrasVePipeline(unet=model, scheduler=scheduler) + + generator = torch.manual_seed(0) + image = pipe(num_inference_steps=20, generator=generator, output_type="numpy")["sample"] + + image_slice = image[0, -3:, -3:, -1] + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.26815, 0.1581, 0.2658, 0.23248, 0.1550, 0.2539, 0.1131, 0.1024, 0.0837]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2