From f3937bc8f3667772c9f1428b66f0c44b6087b04d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 31 Aug 2022 19:29:38 +0200 Subject: [PATCH] [Refactor] Remove set_seed (#289) * [Refactor] Remove set_seed and class attributes * apply anton's suggestions * fix * Apply suggestions from code review Co-authored-by: Pedro Cuenca * up * update * make style * Apply suggestions from code review Co-authored-by: Anton Lozhkov * make fix-copies * make style * make style and new copies Co-authored-by: Pedro Cuenca Co-authored-by: Anton Lozhkov --- src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 2 +- .../score_sde_ve/pipeline_score_sde_ve.py | 6 ++-- .../pipeline_stochastic_karras_ve.py | 1 + src/diffusers/schedulers/scheduling_sde_ve.py | 28 +++++++++++-------- src/diffusers/utils/dummy_scipy_objects.py | 1 + .../utils/dummy_transformers_objects.py | 1 + utils/check_dummies.py | 2 +- 7 files changed, 25 insertions(+), 16 deletions(-) diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index 27c156de..5d735a39 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -56,7 +56,7 @@ class DDPMPipeline(DiffusionPipeline): model_output = self.unet(image, t)["sample"] # 2. 
compute previous image: x_t -> t_t-1 - image = self.scheduler.step(model_output, t, image)["prev_sample"] + image = self.scheduler.step(model_output, t, image, generator=generator)["prev_sample"] image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py index 7d72ddf7..0ab92eff 100644 --- a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py @@ -30,7 +30,7 @@ class ScoreSdeVePipeline(DiffusionPipeline): model = self.unet - sample = torch.randn(*shape) * self.scheduler.config.sigma_max + sample = torch.randn(*shape, generator=generator) * self.scheduler.config.sigma_max sample = sample.to(self.device) self.scheduler.set_timesteps(num_inference_steps) @@ -42,11 +42,11 @@ class ScoreSdeVePipeline(DiffusionPipeline): # correction step for _ in range(self.scheduler.correct_steps): model_output = self.unet(sample, sigma_t)["sample"] - sample = self.scheduler.step_correct(model_output, sample)["prev_sample"] + sample = self.scheduler.step_correct(model_output, sample, generator=generator)["prev_sample"] # prediction step model_output = model(sample, sigma_t)["sample"] - output = self.scheduler.step_pred(model_output, t, sample) + output = self.scheduler.step_pred(model_output, t, sample, generator=generator) sample, sample_mean = output["prev_sample"], output["prev_sample_mean"] diff --git a/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py b/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py index 3bd95dd5..007395a1 100644 --- a/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py +++ b/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py @@ -19,6 +19,7 @@ class KarrasVePipeline(DiffusionPipeline): differential equations." 
https://arxiv.org/abs/2011.13456 """ + # add type hints for linting unet: UNet2DModel scheduler: KarrasVeScheduler diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 1d6e05d9..e3fec035 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -14,8 +14,8 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch -# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit -from typing import Union +import warnings +from typing import Optional, Union import numpy as np import torch @@ -98,6 +98,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") def set_seed(self, seed): + warnings.warn( + "The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a" + " generator instead.", + DeprecationWarning, + ) tensor_format = getattr(self, "tensor_format", "pt") if tensor_format == "np": np.random.seed(seed) @@ -111,14 +116,14 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray], - seed=None, + generator: Optional[torch.Generator] = None, + **kwargs, ): """ Predict the sample at the previous timestep by reversing the SDE. 
""" - if seed is not None: - self.set_seed(seed) - # TODO(Patrick) non-PyTorch + if "seed" in kwargs and kwargs["seed"] is not None: + self.set_seed(kwargs["seed"]) if self.timesteps is None: raise ValueError( @@ -140,7 +145,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): drift = drift - diffusion[:, None, None, None] ** 2 * model_output # equation 6: sample noise for the diffusion term of - noise = self.randn_like(sample) + noise = self.randn_like(sample, generator=generator) prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep # TODO is the variable diffusion the correct scaling term for the noise? prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g @@ -151,14 +156,15 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): self, model_output: Union[torch.FloatTensor, np.ndarray], sample: Union[torch.FloatTensor, np.ndarray], - seed=None, + generator: Optional[torch.Generator] = None, + **kwargs, ): """ Correct the predicted sample based on the output model_output of the network. This is often run repeatedly after making the prediction for the previous timestep. """ - if seed is not None: - self.set_seed(seed) + if "seed" in kwargs and kwargs["seed"] is not None: + self.set_seed(kwargs["seed"]) if self.timesteps is None: raise ValueError( @@ -167,7 +173,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. 
of z" # sample noise for correction - noise = self.randn_like(sample) + noise = self.randn_like(sample, generator=generator) # compute step size from the model_output, the noise, and the snr grad_norm = self.norm(model_output) diff --git a/src/diffusers/utils/dummy_scipy_objects.py b/src/diffusers/utils/dummy_scipy_objects.py index 889baf67..3706c575 100644 --- a/src/diffusers/utils/dummy_scipy_objects.py +++ b/src/diffusers/utils/dummy_scipy_objects.py @@ -1,5 +1,6 @@ # This file is autogenerated by the command `make fix-copies`, do not edit. # flake8: noqa + from ..utils import DummyObject, requires_backends diff --git a/src/diffusers/utils/dummy_transformers_objects.py b/src/diffusers/utils/dummy_transformers_objects.py index 4c216b62..e05eb814 100644 --- a/src/diffusers/utils/dummy_transformers_objects.py +++ b/src/diffusers/utils/dummy_transformers_objects.py @@ -1,5 +1,6 @@ # This file is autogenerated by the command `make fix-copies`, do not edit. # flake8: noqa + from ..utils import DummyObject, requires_backends diff --git a/utils/check_dummies.py b/utils/check_dummies.py index 928f9ac1..60c954ad 100644 --- a/utils/check_dummies.py +++ b/utils/check_dummies.py @@ -107,7 +107,7 @@ def create_dummy_files(): for backend, objects in backend_specific_objects.items(): backend_name = "[" + ", ".join(f'"{b}"' for b in backend.split("_and_")) + "]" dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n" - dummy_file += "# flake8: noqa\n" + dummy_file += "# flake8: noqa\n\n" dummy_file += "from ..utils import DummyObject, requires_backends\n\n" dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects]) dummy_files[backend] = dummy_file