[Refactor] Remove set_seed (#289)

* [Refactor] Remove set_seed and class attributes

* apply anton's suggestiosn

* fix

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* up

* update

* make style

* Apply suggestions from code review

Co-authored-by: Anton Lozhkov <anton@huggingface.co>

* make fix-copies

* make style

* make style and new copies

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
This commit is contained in:
Patrick von Platen 2022-08-31 19:29:38 +02:00 committed by GitHub
parent 384fcac6df
commit f3937bc8f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 25 additions and 16 deletions

View File

@ -56,7 +56,7 @@ class DDPMPipeline(DiffusionPipeline):
model_output = self.unet(image, t)["sample"]
# 2. compute previous image: x_t -> t_t-1
image = self.scheduler.step(model_output, t, image)["prev_sample"]
image = self.scheduler.step(model_output, t, image, generator=generator)["prev_sample"]
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()

View File

@ -30,7 +30,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
model = self.unet
sample = torch.randn(*shape) * self.scheduler.config.sigma_max
sample = torch.randn(*shape, generator=generator) * self.scheduler.config.sigma_max
sample = sample.to(self.device)
self.scheduler.set_timesteps(num_inference_steps)
@ -42,11 +42,11 @@ class ScoreSdeVePipeline(DiffusionPipeline):
# correction step
for _ in range(self.scheduler.correct_steps):
model_output = self.unet(sample, sigma_t)["sample"]
sample = self.scheduler.step_correct(model_output, sample)["prev_sample"]
sample = self.scheduler.step_correct(model_output, sample, generator=generator)["prev_sample"]
# prediction step
model_output = model(sample, sigma_t)["sample"]
output = self.scheduler.step_pred(model_output, t, sample)
output = self.scheduler.step_pred(model_output, t, sample, generator=generator)
sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]

View File

@ -19,6 +19,7 @@ class KarrasVePipeline(DiffusionPipeline):
differential equations." https://arxiv.org/abs/2011.13456
"""
# add type hints for linting
unet: UNet2DModel
scheduler: KarrasVeScheduler

View File

@ -14,8 +14,8 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit
from typing import Union
import warnings
from typing import Optional, Union
import numpy as np
import torch
@ -98,6 +98,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def set_seed(self, seed):
warnings.warn(
"The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a"
" generator instead.",
DeprecationWarning,
)
tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np":
np.random.seed(seed)
@ -111,14 +116,14 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
seed=None,
generator: Optional[torch.Generator] = None,
**kwargs,
):
"""
Predict the sample at the previous timestep by reversing the SDE.
"""
if seed is not None:
self.set_seed(seed)
# TODO(Patrick) non-PyTorch
if "seed" in kwargs and kwargs["seed"] is not None:
self.set_seed(kwargs["seed"])
if self.timesteps is None:
raise ValueError(
@ -140,7 +145,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
drift = drift - diffusion[:, None, None, None] ** 2 * model_output
# equation 6: sample noise for the diffusion term of
noise = self.randn_like(sample)
noise = self.randn_like(sample, generator=generator)
prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep
# TODO is the variable diffusion the correct scaling term for the noise?
prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g
@ -151,14 +156,15 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
self,
model_output: Union[torch.FloatTensor, np.ndarray],
sample: Union[torch.FloatTensor, np.ndarray],
seed=None,
generator: Optional[torch.Generator] = None,
**kwargs,
):
"""
Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
after making the prediction for the previous timestep.
"""
if seed is not None:
self.set_seed(seed)
if "seed" in kwargs and kwargs["seed"] is not None:
self.set_seed(kwargs["seed"])
if self.timesteps is None:
raise ValueError(
@ -167,7 +173,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
# For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
# sample noise for correction
noise = self.randn_like(sample)
noise = self.randn_like(sample, generator=generator)
# compute step size from the model_output, the noise, and the snr
grad_norm = self.norm(model_output)

View File

@ -1,5 +1,6 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from ..utils import DummyObject, requires_backends

View File

@ -1,5 +1,6 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from ..utils import DummyObject, requires_backends

View File

@ -107,7 +107,7 @@ def create_dummy_files():
for backend, objects in backend_specific_objects.items():
backend_name = "[" + ", ".join(f'"{b}"' for b in backend.split("_and_")) + "]"
dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
dummy_file += "# flake8: noqa\n"
dummy_file += "# flake8: noqa\n\n"
dummy_file += "from ..utils import DummyObject, requires_backends\n\n"
dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects])
dummy_files[backend] = dummy_file