[Scheduler] Move predict epsilon to init (#1155)

* [Scheduler] Move predict epsilon to init

* up

* uP

* uP

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* up

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Patrick von Platen authored on 2022-11-08 18:08:08 +01:00 (committed by GitHub)
parent 5786b0e2f7
commit 249d9bc0e7
10 changed files with 193 additions and 25 deletions
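
The net effect of the PR, as a usage sketch (`<model_id>` is a placeholder taken from the deprecation message in the diff; the old per-call keyword keeps working until 0.10.0):

    from diffusers import DDPMScheduler

    # New: the prediction mode is part of the scheduler config, fixed at construction time.
    scheduler = DDPMScheduler(num_train_timesteps=1000, predict_epsilon=True)

    # Or when loading a stored config; extra kwargs now override the stored values:
    # scheduler = DDPMScheduler.from_config("<model_id>", predict_epsilon=True)

    # Old (deprecated, scheduled for removal in 0.10.0): passing the flag per call.
    # scheduler.step(model_output, t, sample, predict_epsilon=True)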

View File

@@ -1,4 +1,5 @@
 import argparse
+import inspect
 import math
 import os
 from pathlib import Path
@@ -190,10 +191,10 @@ def parse_args():
     )
     parser.add_argument(
-        "--predict_mode",
-        type=str,
-        default="eps",
-        help="What the model should predict. 'eps' to predict error, 'x0' to directly predict reconstruction",
+        "--predict_epsilon",
+        action="store_true",
+        default=True,
+        help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
     )
     parser.add_argument("--ddpm_num_steps", type=int, default=1000)
@@ -252,7 +253,17 @@ def main(args):
             "UpBlock2D",
         ),
     )
-    noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
+    accepts_predict_epsilon = "predict_epsilon" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
+
+    if accepts_predict_epsilon:
+        noise_scheduler = DDPMScheduler(
+            num_train_timesteps=args.ddpm_num_steps,
+            beta_schedule=args.ddpm_beta_schedule,
+            predict_epsilon=args.predict_epsilon,
+        )
+    else:
+        noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
     optimizer = torch.optim.AdamW(
         model.parameters(),
         lr=args.learning_rate,
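
The `inspect.signature` check above keeps the example script working against older diffusers releases whose `DDPMScheduler.__init__` does not yet accept `predict_epsilon`. A minimal standalone sketch of the same pattern (the helper names are invented for illustration):

    import inspect

    def supports_kwarg(fn, name):
        # True if `fn` declares a parameter with this name
        return name in inspect.signature(fn).parameters

    def construct_compat(cls, **kwargs):
        # drop kwargs the installed version's __init__ does not declare
        accepted = {k: v for k, v in kwargs.items() if supports_kwarg(cls.__init__, k)}
        return cls(**accepted)

    # e.g. construct_compat(DDPMScheduler, num_train_timesteps=1000, predict_epsilon=True)
    # silently drops predict_epsilon on versions that predate it.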
@@ -351,9 +362,9 @@ def main(args):
                 # Predict the noise residual
                 model_output = model(noisy_images, timesteps).sample

-                if args.predict_mode == "eps":
+                if args.predict_epsilon:
                     loss = F.mse_loss(model_output, noise)  # this could have different weights!
-                elif args.predict_mode == "x0":
+                else:
                     alpha_t = _extract_into_tensor(
                         noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
                     )
@@ -401,7 +412,6 @@ def main(args):
                        generator=generator,
                        batch_size=args.eval_batch_size,
                        output_type="numpy",
-                        predict_epsilon=args.predict_mode == "eps",
                    ).images

                    # denormalize the images and save to tensorboard

View File

@@ -334,6 +334,11 @@ class ConfigMixin:
         # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
         init_dict = {}
         for key in expected_keys:
+            # if config param is passed to kwarg and is present in config dict
+            # it should overwrite existing config dict key
+            if key in kwargs and key in config_dict:
+                config_dict[key] = kwargs.pop(key)
+
             if key in kwargs:
                 # overwrite key
                 init_dict[key] = kwargs.pop(key)
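
With this change, keyword arguments passed to `from_config` take precedence over values stored in the config file. A short usage sketch (the model id mirrors `test_overwrite_config_on_load` later in this diff; loading it requires network access):

    from diffusers import DDPMScheduler

    # predict_epsilon=False overwrites whatever the hosted scheduler config stores
    scheduler = DDPMScheduler.from_config("google/ddpm-celebahq-256", predict_epsilon=False)
    assert scheduler.config.predict_epsilon is False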

View File

@@ -18,7 +18,9 @@ from typing import Optional, Tuple, Union

 import torch

+from ...configuration_utils import FrozenDict
 from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...utils import deprecate


 class DDPMPipeline(DiffusionPipeline):
@@ -45,7 +47,6 @@ class DDPMPipeline(DiffusionPipeline):
         num_inference_steps: int = 1000,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        predict_epsilon: bool = True,
         **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
         r"""
@@ -69,6 +70,16 @@ class DDPMPipeline(DiffusionPipeline):
             `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
             generated images.
         """
+        message = (
+            "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
+            " DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
+        )
+        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+
+        if predict_epsilon is not None:
+            new_config = dict(self.scheduler.config)
+            new_config["predict_epsilon"] = predict_epsilon
+            self.scheduler._internal_dict = FrozenDict(new_config)
+
         # Sample gaussian noise to begin loop
         image = torch.randn(
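
For context, a simplified sketch of what a `deprecate(..., take_from=kwargs)` helper along these lines can do; the real `diffusers.utils.deprecate` differs in details (for one, it errors out once the installed version reaches the removal version):

    import warnings

    def deprecate(name, removal_version, message, take_from=None):
        # Pop the deprecated kwarg if the caller still passes it, warn, and hand
        # the value back so the caller can honor it one last time.
        value = take_from.pop(name, None) if take_from is not None else None
        if value is not None:
            warnings.warn(
                f"`{name}` is deprecated and will be removed in {removal_version}. {message}",
                FutureWarning,
            )
        return value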

View File

@@ -21,8 +21,8 @@ from typing import Optional, Tuple, Union
 import numpy as np
 import torch

-from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
+from ..utils import BaseOutput, deprecate
 from .scheduling_utils import SchedulerMixin
@@ -99,6 +99,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
         clip_sample (`bool`, default `True`):
             option to clip predicted sample between -1 and 1 for numerical stability.
+        predict_epsilon (`bool`):
+            optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise.
     """
@@ -121,6 +123,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         trained_betas: Optional[np.ndarray] = None,
         variance_type: str = "fixed_small",
         clip_sample: bool = True,
+        predict_epsilon: bool = True,
     ):
         if trained_betas is not None:
             self.betas = torch.from_numpy(trained_betas)
@@ -221,9 +224,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         model_output: torch.FloatTensor,
         timestep: int,
         sample: torch.FloatTensor,
-        predict_epsilon=True,
         generator=None,
         return_dict: bool = True,
+        **kwargs,
     ) -> Union[DDPMSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
@@ -234,8 +237,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             timestep (`int`): current discrete timestep in the diffusion chain.
             sample (`torch.FloatTensor`):
                 current instance of sample being created by diffusion process.
-            predict_epsilon (`bool`):
-                optional flag to use when model predicts the samples directly instead of the noise, epsilon.
             generator: random number generator.
             return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
@@ -245,6 +246,16 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             returning a tuple, the first element is the sample tensor.

         """
+        message = (
+            "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
+            " DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
+        )
+        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+
+        if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
+            new_config = dict(self.config)
+            new_config["predict_epsilon"] = predict_epsilon
+            self._internal_dict = FrozenDict(new_config)
+
         t = timestep

         if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
@@ -260,7 +271,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):

         # 2. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-        if predict_epsilon:
+        if self.config.predict_epsilon:
             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
         else:
             pred_original_sample = model_output
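
The branch above inverts the forward process: when the model output is epsilon, x_0 is recovered as (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t), formula (15) of the DDPM paper; otherwise the model output already is x_0. A quick numerical check of that inversion (a standalone sketch, not part of the diff):

    import torch

    alpha_prod_t = torch.tensor(0.7)  # alpha_bar_t
    beta_prod_t = 1 - alpha_prod_t

    x0 = torch.randn(4)
    eps = torch.randn(4)
    # forward process: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps
    x_t = alpha_prod_t**0.5 * x0 + beta_prod_t**0.5 * eps

    # a model that predicts epsilon lets us recover x0 exactly
    x0_hat = (x_t - beta_prod_t**0.5 * eps) / alpha_prod_t**0.5
    assert torch.allclose(x0_hat, x0, atol=1e-6)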

View File

@@ -22,7 +22,8 @@ import flax
 import jax.numpy as jnp
 from jax import random

-from ..configuration_utils import ConfigMixin, register_to_config
+from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
+from ..utils import deprecate
 from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
@@ -97,7 +98,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
             `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
         clip_sample (`bool`, default `True`):
             option to clip predicted sample between -1 and 1 for numerical stability.
-        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+        predict_epsilon (`bool`):
+            optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise.
     """
@@ -115,6 +117,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
         trained_betas: Optional[jnp.ndarray] = None,
         variance_type: str = "fixed_small",
         clip_sample: bool = True,
+        predict_epsilon: bool = True,
     ):
         if trained_betas is not None:
             self.betas = jnp.asarray(trained_betas)
@@ -196,6 +199,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
         key: random.KeyArray,
         predict_epsilon: bool = True,
         return_dict: bool = True,
+        **kwargs,
     ) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
@@ -208,8 +212,6 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
             sample (`jnp.ndarray`):
                 current instance of sample being created by diffusion process.
             key (`random.KeyArray`): a PRNG key.
-            predict_epsilon (`bool`):
-                optional flag to use when model predicts the samples directly instead of the noise, epsilon.
             return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class

         Returns:
@@ -217,6 +219,16 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
             `tuple`. When returning a tuple, the first element is the sample tensor.

         """
+        message = (
+            "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
+            " DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
+        )
+        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+
+        if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
+            new_config = dict(self.config)
+            new_config["predict_epsilon"] = predict_epsilon
+            self._internal_dict = FrozenDict(new_config)
+
         t = timestep

         if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
@@ -232,7 +244,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):

         # 2. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-        if predict_epsilon:
+        if self.config.predict_epsilon:
             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
         else:
             pred_original_sample = model_output

View File

@@ -42,7 +42,6 @@ class CustomLocalPipeline(DiffusionPipeline):
         self,
         batch_size: int = 1,
         generator: Optional[torch.Generator] = None,
-        eta: float = 0.0,
         num_inference_steps: int = 50,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
@@ -89,7 +88,7 @@ class CustomLocalPipeline(DiffusionPipeline):
            # 2. predict previous mean of image x_t-1 and add variance depending on eta
            # eta corresponds to η in paper and should be between [0, 1]
            # do x_t -> x_t-1
-            image = self.scheduler.step(model_output, t, image, eta).prev_sample
+            image = self.scheduler.step(model_output, t, image).prev_sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

View File

@@ -19,6 +19,7 @@ import numpy as np
 import torch

 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
+from diffusers.utils import deprecate
 from diffusers.utils.testing_utils import require_torch, slow, torch_device

 from ...test_pipelines_common import PipelineTesterMixin
@@ -28,8 +29,74 @@ torch.backends.cuda.matmul.allow_tf32 = False

 class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
-    # FIXME: add fast tests
-    pass
+    @property
+    def dummy_uncond_unet(self):
+        torch.manual_seed(0)
+        model = UNet2DModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=3,
+            out_channels=3,
+            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
+            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
+        )
+        return model
+
+    def test_inference(self):
+        unet = self.dummy_uncond_unet
+        scheduler = DDPMScheduler()
+
+        ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
+        ddpm.to(torch_device)
+        ddpm.set_progress_bar_config(disable=None)
+
+        # Warmup pass when using mps (see #372)
+        if torch_device == "mps":
+            _ = ddpm(num_inference_steps=1)
+
+        generator = torch.manual_seed(0)
+        image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
+
+        generator = torch.manual_seed(0)
+        image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
+
+        image_slice = image[0, -3:, -3:, -1]
+        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array(
+            [5.589e-01, 7.089e-01, 2.632e-01, 6.841e-01, 1.000e-04, 9.999e-01, 1.973e-01, 1.000e-04, 8.010e-02]
+        )
+        tolerance = 1e-2 if torch_device != "mps" else 3e-2
+        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
+
+    def test_inference_predict_epsilon(self):
+        deprecate("remove this test", "0.10.0", "remove")
+        unet = self.dummy_uncond_unet
+        scheduler = DDPMScheduler(predict_epsilon=False)
+
+        ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
+        ddpm.to(torch_device)
+        ddpm.set_progress_bar_config(disable=None)
+
+        # Warmup pass when using mps (see #372)
+        if torch_device == "mps":
+            _ = ddpm(num_inference_steps=1)
+
+        generator = torch.manual_seed(0)
+        image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
+
+        generator = torch.manual_seed(0)
+        image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
+
+        image_slice = image[0, -3:, -3:, -1]
+        image_eps_slice = image_eps[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 32, 32, 3)
+        tolerance = 1e-2 if torch_device != "mps" else 3e-2
+        assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance
+
+
 @slow

tests/test_config.py — 24 changes (mode changed from Executable file to Normal file)
View File

@@ -21,6 +21,7 @@ import unittest
 import diffusers
 from diffusers import (
     DDIMScheduler,
+    DDPMScheduler,
     DPMSolverMultistepScheduler,
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
@@ -291,6 +292,29 @@ class ConfigTester(unittest.TestCase):
         # no warning should be thrown
         assert cap_logger.out == ""

+    def test_overwrite_config_on_load(self):
+        logger = logging.get_logger("diffusers.configuration_utils")
+
+        with CaptureLogger(logger) as cap_logger:
+            ddpm = DDPMScheduler.from_config(
+                "hf-internal-testing/tiny-stable-diffusion-torch",
+                subfolder="scheduler",
+                predict_epsilon=False,
+                beta_end=8,
+            )
+
+        with CaptureLogger(logger) as cap_logger_2:
+            ddpm_2 = DDPMScheduler.from_config("google/ddpm-celebahq-256", beta_start=88)
+
+        assert ddpm.__class__ == DDPMScheduler
+        assert ddpm.config.predict_epsilon is False
+        assert ddpm.config.beta_end == 8
+        assert ddpm_2.config.beta_start == 88
+
+        # no warning should be thrown
+        assert cap_logger.out == ""
+        assert cap_logger_2.out == ""
+
     def test_load_dpmsolver(self):
         logger = logging.get_logger("diffusers.configuration_utils")

View File

@@ -107,6 +107,7 @@ class CustomPipelineTests(unittest.TestCase):
         images, output_str = pipeline(num_inference_steps=2, output_type="np")

         assert images[0].shape == (1, 32, 32, 3)
+
         # compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102
         assert output_str == "This is a test"

View File

@@ -33,7 +33,7 @@ from diffusers import (
     ScoreSdeVeScheduler,
     VQDiffusionScheduler,
 )
-from diffusers.utils import torch_device
+from diffusers.utils import deprecate, torch_device


 torch.backends.cuda.matmul.allow_tf32 = False
@@ -393,6 +393,34 @@ class DDPMSchedulerTest(SchedulerCommonTest):
         for clip_sample in [True, False]:
             self.check_over_configs(clip_sample=clip_sample)

+    def test_predict_epsilon(self):
+        for predict_epsilon in [True, False]:
+            self.check_over_configs(predict_epsilon=predict_epsilon)
+
+    def test_deprecated_epsilon(self):
+        deprecate("remove this test", "0.10.0", "remove")
+
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+
+        sample = self.dummy_sample_deter
+        residual = 0.1 * self.dummy_sample_deter
+        time_step = 4
+
+        scheduler = scheduler_class(**scheduler_config)
+        scheduler_eps = scheduler_class(predict_epsilon=False, **scheduler_config)
+
+        kwargs = {}
+        if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
+            kwargs["generator"] = torch.Generator().manual_seed(0)
+        output = scheduler.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample
+
+        kwargs = {}
+        if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
+            kwargs["generator"] = torch.Generator().manual_seed(0)
+        output_eps = scheduler_eps.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample
+
+        assert (output - output_eps).abs().sum() < 1e-5
+
     def test_time_indices(self):
         for t in [0, 500, 999]:
             self.check_over_forward(time_step=t)