[Refactor] Remove set_seed (#289)
* [Refactor] Remove set_seed and class attributes
* apply anton's suggestions
* fix
* Apply suggestions from code review
  Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* up
* update
* make style
* Apply suggestions from code review
  Co-authored-by: Anton Lozhkov <anton@huggingface.co>
* make fix-copies
* make style
* make style and new copies

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
This commit is contained in:
parent 384fcac6df
commit f3937bc8f3
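In practice, the refactor replaces scheduler-level seeding with an explicit torch.Generator that the caller supplies and that gets forwarded to every stochastic call, as the hunks below show. A minimal usage sketch follows; the checkpoint name and the dict-style "sample" output are assumptions about the pipeline API at this point in the codebase, not part of this diff:

    import torch
    from diffusers import DDPMPipeline

    # seed via an explicit generator instead of a scheduler-level set_seed(...) call
    generator = torch.Generator().manual_seed(0)

    pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")  # assumed checkpoint
    images = pipe(generator=generator)["sample"]                   # assumed output format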
@@ -56,7 +56,7 @@ class DDPMPipeline(DiffusionPipeline):
             model_output = self.unet(image, t)["sample"]

             # 2. compute previous image: x_t -> x_t-1
-            image = self.scheduler.step(model_output, t, image)["prev_sample"]
+            image = self.scheduler.step(model_output, t, image, generator=generator)["prev_sample"]

         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
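Internally, the generator is simply forwarded to the scheduler call that draws the sampling noise. A hand-rolled sketch of the loop in the hunk above (not the pipeline's full __call__; the unet/scheduler call conventions are taken from the diff, scheduler.timesteps is assumed to be populated):

    import torch

    def denoise(unet, scheduler, image, generator=None):
        # mirrors the DDPM loop above: the caller-supplied generator reaches
        # scheduler.step, which is where the stochastic noise is drawn
        for t in scheduler.timesteps:
            model_output = unet(image, t)["sample"]
            image = scheduler.step(model_output, t, image, generator=generator)["prev_sample"]
        return image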
@@ -30,7 +30,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):

         model = self.unet

-        sample = torch.randn(*shape) * self.scheduler.config.sigma_max
+        sample = torch.randn(*shape, generator=generator) * self.scheduler.config.sigma_max
         sample = sample.to(self.device)

         self.scheduler.set_timesteps(num_inference_steps)
@@ -42,11 +42,11 @@ class ScoreSdeVePipeline(DiffusionPipeline):
             # correction step
             for _ in range(self.scheduler.correct_steps):
                 model_output = self.unet(sample, sigma_t)["sample"]
-                sample = self.scheduler.step_correct(model_output, sample)["prev_sample"]
+                sample = self.scheduler.step_correct(model_output, sample, generator=generator)["prev_sample"]

             # prediction step
             model_output = model(sample, sigma_t)["sample"]
-            output = self.scheduler.step_pred(model_output, t, sample)
+            output = self.scheduler.step_pred(model_output, t, sample, generator=generator)

             sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]

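Taken together, the two ScoreSdeVePipeline hunks thread the same generator through both halves of the predictor-corrector loop. A condensed sketch of the whole loop (the sigma_t construction is an assumption and is not part of this diff; everything else follows the hunks):

    import torch

    def sde_ve_sample(unet, scheduler, shape, generator=None):
        # the initial sample and every later noise draw use the caller-supplied generator
        sample = torch.randn(*shape, generator=generator) * scheduler.config.sigma_max
        for i, t in enumerate(scheduler.timesteps):
            sigma_t = scheduler.sigmas[i] * torch.ones(shape[0])  # assumed helper, not shown in the diff
            # correction step
            for _ in range(scheduler.correct_steps):
                model_output = unet(sample, sigma_t)["sample"]
                sample = scheduler.step_correct(model_output, sample, generator=generator)["prev_sample"]
            # prediction step
            model_output = unet(sample, sigma_t)["sample"]
            output = scheduler.step_pred(model_output, t, sample, generator=generator)
            sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
        return sample_mean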
@@ -19,6 +19,7 @@ class KarrasVePipeline(DiffusionPipeline):
     differential equations." https://arxiv.org/abs/2011.13456
     """

+    # add type hints for linting
     unet: UNet2DModel
     scheduler: KarrasVeScheduler

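A minimal illustration (not library code) of why the class-level hints are there: pipeline components are attached dynamically via setattr-based registration (register_modules in diffusers), which static checkers cannot follow, so bare class-level annotations restore the attribute types:

    from diffusers import KarrasVeScheduler, UNet2DModel

    class ExamplePipeline:
        # bare annotations: declare attribute types for linters without assigning values
        unet: UNet2DModel
        scheduler: KarrasVeScheduler

        def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
            # stand-in for the dynamic, setattr-based registration the real pipelines perform
            for name, module in {"unet": unet, "scheduler": scheduler}.items():
                setattr(self, name, module)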
@@ -14,8 +14,8 @@

 # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch

 # TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit
-from typing import Union
+import warnings
+from typing import Optional, Union

 import numpy as np
 import torch
@@ -98,6 +98,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
             raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

     def set_seed(self, seed):
+        warnings.warn(
+            "The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a"
+            " generator instead.",
+            DeprecationWarning,
+        )
         tensor_format = getattr(self, "tensor_format", "pt")
         if tensor_format == "np":
             np.random.seed(seed)
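For callers, the deprecation shim behaves as sketched below (hypothetical snippet; the scheduler is constructed with its defaults, which is an assumption, not something this diff shows):

    import warnings
    import torch
    from diffusers import ScoreSdeVeScheduler

    scheduler = ScoreSdeVeScheduler()  # placeholder; assumes the default config is usable

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        scheduler.set_seed(0)  # old API: still works until 0.4.0, but now warns
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

    # preferred replacement: hand an explicit torch.Generator to the stochastic methods
    generator = torch.Generator().manual_seed(0)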
@@ -111,14 +116,14 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         model_output: Union[torch.FloatTensor, np.ndarray],
         timestep: int,
         sample: Union[torch.FloatTensor, np.ndarray],
-        seed=None,
+        generator: Optional[torch.Generator] = None,
         **kwargs,
     ):
         """
         Predict the sample at the previous timestep by reversing the SDE.
         """
-        if seed is not None:
-            self.set_seed(seed)
-        # TODO(Patrick) non-PyTorch
+        if "seed" in kwargs and kwargs["seed"] is not None:
+            self.set_seed(kwargs["seed"])
+
         if self.timesteps is None:
             raise ValueError(
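Because the old keyword now falls through to **kwargs, existing callers keep working while new code passes a generator. Roughly (a sketch; model_output, t, sample, and scheduler are assumed to be defined as in the surrounding code):

    import torch

    # legacy call: seed lands in **kwargs and is routed through the deprecated set_seed
    output = scheduler.step_pred(model_output, t, sample, seed=0)

    # new call: explicit generator, no global NumPy/torch seeding involved
    output = scheduler.step_pred(
        model_output, t, sample, generator=torch.Generator().manual_seed(0)
    )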
@@ -140,7 +145,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         drift = drift - diffusion[:, None, None, None] ** 2 * model_output

         # equation 6: sample noise for the diffusion term of
-        noise = self.randn_like(sample)
+        noise = self.randn_like(sample, generator=generator)
         prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
         # TODO is the variable diffusion the correct scaling term for the noise?
         prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise  # add impact of diffusion field g
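The plumbing is needed because torch.randn_like itself does not accept a generator. For intuition only, a generator-aware helper presumably reduces to something like the following for the PyTorch tensor format; the actual self.randn_like in this codebase may differ:

    import torch

    def randn_like(x: torch.Tensor, generator: torch.Generator = None) -> torch.Tensor:
        # torch.randn (unlike torch.randn_like) takes a generator; the generator's
        # device must be compatible with the target device
        return torch.randn(x.shape, dtype=x.dtype, device=x.device, generator=generator)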
@@ -151,14 +156,15 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         self,
         model_output: Union[torch.FloatTensor, np.ndarray],
         sample: Union[torch.FloatTensor, np.ndarray],
-        seed=None,
+        generator: Optional[torch.Generator] = None,
         **kwargs,
     ):
         """
         Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
         after making the prediction for the previous timestep.
         """
-        if seed is not None:
-            self.set_seed(seed)
+        if "seed" in kwargs and kwargs["seed"] is not None:
+            self.set_seed(kwargs["seed"])
+
         if self.timesteps is None:
             raise ValueError(
@@ -167,7 +173,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):

         # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
         # sample noise for correction
-        noise = self.randn_like(sample)
+        noise = self.randn_like(sample, generator=generator)

         # compute step size from the model_output, the noise, and the snr
         grad_norm = self.norm(model_output)
@@ -1,5 +1,6 @@
 # This file is autogenerated by the command `make fix-copies`, do not edit.
 # flake8: noqa
+
 from ..utils import DummyObject, requires_backends

@@ -1,5 +1,6 @@
 # This file is autogenerated by the command `make fix-copies`, do not edit.
 # flake8: noqa
+
 from ..utils import DummyObject, requires_backends

@@ -107,7 +107,7 @@ def create_dummy_files():
     for backend, objects in backend_specific_objects.items():
         backend_name = "[" + ", ".join(f'"{b}"' for b in backend.split("_and_")) + "]"
         dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
-        dummy_file += "# flake8: noqa\n"
+        dummy_file += "# flake8: noqa\n\n"
         dummy_file += "from ..utils import DummyObject, requires_backends\n\n"
         dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects])
         dummy_files[backend] = dummy_file
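This check_dummies change is what produces the extra blank line in the two autogenerated dummy modules above; reconstructed from the template strings in the hunk, the rendered header is:

    header = (
        "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
        "# flake8: noqa\n\n"
        "from ..utils import DummyObject, requires_backends\n\n"
    )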