[Scheduler] Move predict epsilon to init (#1155)
* [Scheduler] Move predict epsilon to init
* up
* uP
* uP
* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* up

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Parent: 5786b0e2f7 · Commit: 249d9bc0e7
```diff
@@ -1,4 +1,5 @@
 import argparse
+import inspect
 import math
 import os
 from pathlib import Path
```
```diff
@@ -190,10 +191,10 @@ def parse_args():
    )

    parser.add_argument(
-        "--predict_mode",
-        type=str,
-        default="eps",
-        help="What the model should predict. 'eps' to predict error, 'x0' to directly predict reconstruction",
+        "--predict_epsilon",
+        action="store_true",
+        default=True,
+        help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
    )

    parser.add_argument("--ddpm_num_steps", type=int, default=1000)
```
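Worth noting while reviewing: combining `action="store_true"` with `default=True` means the parsed value is `True` whether or not the flag is passed, so `False` is unreachable from the command line. A standalone sketch of that parsing behavior (not part of the commit):

```python
# Minimal sketch: with action="store_true" and default=True, the flag
# cannot be switched off from the CLI.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--predict_epsilon", action="store_true", default=True)

assert parser.parse_args([]).predict_epsilon is True                      # flag omitted
assert parser.parse_args(["--predict_epsilon"]).predict_epsilon is True   # flag given
```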
```diff
@@ -252,7 +253,17 @@ def main(args):
            "UpBlock2D",
        ),
    )
-    noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
+
+    accepts_predict_epsilon = "predict_epsilon" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
+
+    if accepts_predict_epsilon:
+        noise_scheduler = DDPMScheduler(
+            num_train_timesteps=args.ddpm_num_steps,
+            beta_schedule=args.ddpm_beta_schedule,
+            predict_epsilon=args.predict_epsilon,
+        )
+    else:
+        noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
```
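The `inspect.signature` probe above is a general backward-compatibility pattern: detect whether the installed library version accepts a keyword before forwarding it. A generic sketch of the same idea (`new_option` is a placeholder name, not a diffusers API):

```python
# Feature-detection sketch: only forward a keyword argument if the target
# class's __init__ actually declares it, so callers work across versions.
import inspect


def construct(cls, new_option=None, **required_kwargs):
    accepts = "new_option" in inspect.signature(cls.__init__).parameters
    if accepts and new_option is not None:
        return cls(new_option=new_option, **required_kwargs)
    return cls(**required_kwargs)
```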
```diff
@@ -351,9 +362,9 @@ def main(args):
                # Predict the noise residual
                model_output = model(noisy_images, timesteps).sample

-                if args.predict_mode == "eps":
+                if args.predict_epsilon:
                    loss = F.mse_loss(model_output, noise)  # this could have different weights!
-                elif args.predict_mode == "x0":
+                else:
                    alpha_t = _extract_into_tensor(
                        noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
                    )
```
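The hunk cuts off before the `x0` loss itself; a plausible completion, assuming the `alpha_t` computed above feeds an SNR-style weighting of the MSE against the clean images (an illustration, not necessarily the exact code in this commit):

```python
import torch
import torch.nn.functional as F


def x0_loss(model_output: torch.Tensor, clean_images: torch.Tensor, alpha_t: torch.Tensor) -> torch.Tensor:
    # Weight the per-pixel MSE by the signal-to-noise ratio alpha_bar_t / (1 - alpha_bar_t);
    # alpha_t is shaped (batch, 1, 1, 1) so it broadcasts over the image dimensions.
    snr_weights = alpha_t / (1 - alpha_t)
    loss = snr_weights * F.mse_loss(model_output, clean_images, reduction="none")
    return loss.mean()
```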
```diff
@@ -401,7 +412,6 @@ def main(args):
                    generator=generator,
                    batch_size=args.eval_batch_size,
                    output_type="numpy",
-                    predict_epsilon=args.predict_mode == "eps",
                ).images

                # denormalize the images and save to tensorboard
```
```diff
@@ -334,6 +334,11 @@ class ConfigMixin:
        # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
        init_dict = {}
        for key in expected_keys:
+            # if config param is passed to kwarg and is present in config dict
+            # it should overwrite existing config dict key
+            if key in kwargs and key in config_dict:
+                config_dict[key] = kwargs.pop(key)
+
            if key in kwargs:
                # overwrite key
                init_dict[key] = kwargs.pop(key)
```
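This `ConfigMixin` change is what makes the rest of the PR work: a kwarg passed to `from_config` now overwrites the value loaded from the stored config before it reaches `__init__`. Usage mirrors the test added later in this commit:

```python
# Usage sketch, mirroring the new test below: a keyword passed at load time
# overrides the stored config value and is visible on scheduler.config.
from diffusers import DDPMScheduler

ddpm = DDPMScheduler.from_config("google/ddpm-celebahq-256", predict_epsilon=False)
assert ddpm.config.predict_epsilon is False
```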
```diff
@@ -18,7 +18,9 @@ from typing import Optional, Tuple, Union

 import torch

+from ...configuration_utils import FrozenDict
 from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...utils import deprecate


 class DDPMPipeline(DiffusionPipeline):
```
```diff
@@ -45,7 +47,6 @@ class DDPMPipeline(DiffusionPipeline):
        num_inference_steps: int = 1000,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
-        predict_epsilon: bool = True,
        **kwargs,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
```
```diff
@@ -69,6 +70,16 @@ class DDPMPipeline(DiffusionPipeline):
            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
            generated images.
        """
+        message = (
+            "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
+            " DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
+        )
+        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+        if predict_epsilon is not None:
+            new_config = dict(self.scheduler.config)
+            new_config["predict_epsilon"] = predict_epsilon
+            self.scheduler._internal_dict = FrozenDict(new_config)
+
        # Sample gaussian noise to begin loop
        image = torch.randn(
```
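For callers, the old pipeline-level kwarg keeps working through the deprecation window: `deprecate(..., take_from=kwargs)` pops the value (returning `None` when the caller never passed it), and the pipeline writes it back into the scheduler's frozen config rather than threading it through `step()`. A hedged usage sketch (the model id is illustrative):

```python
# Calling the pipeline with the deprecated kwarg still runs, emits a
# deprecation warning, and persists the value on scheduler.config.
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")
images = pipe(num_inference_steps=50, predict_epsilon=True).images
assert pipe.scheduler.config.predict_epsilon is True
```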
```diff
@@ -21,8 +21,8 @@ from typing import Optional, Tuple, Union
 import numpy as np
 import torch

-from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
+from ..utils import BaseOutput, deprecate
 from .scheduling_utils import SchedulerMixin


```
```diff
@@ -99,6 +99,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
            `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
        clip_sample (`bool`, default `True`):
            option to clip predicted sample between -1 and 1 for numerical stability.
+        predict_epsilon (`bool`):
+            optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise.
    """

```
```diff
@@ -121,6 +123,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
        trained_betas: Optional[np.ndarray] = None,
        variance_type: str = "fixed_small",
        clip_sample: bool = True,
+        predict_epsilon: bool = True,
    ):
        if trained_betas is not None:
            self.betas = torch.from_numpy(trained_betas)
```
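This is the headline change: `predict_epsilon` becomes a constructor argument registered to the config, so the prediction mode is a property of the scheduler rather than of each `step()` call. A short usage sketch:

```python
# The prediction mode is now declared once, at construction time.
from diffusers import DDPMScheduler

eps_scheduler = DDPMScheduler(num_train_timesteps=1000, predict_epsilon=True)   # model predicts noise
x0_scheduler = DDPMScheduler(num_train_timesteps=1000, predict_epsilon=False)   # model predicts the sample
assert eps_scheduler.config.predict_epsilon and not x0_scheduler.config.predict_epsilon
```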
```diff
@@ -221,9 +224,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
-        predict_epsilon=True,
        generator=None,
        return_dict: bool = True,
+        **kwargs,
    ) -> Union[DDPMSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
```
```diff
@@ -234,8 +237,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
-            predict_epsilon (`bool`):
-                optional flag to use when model predicts the samples directly instead of the noise, epsilon.
            generator: random number generator.
            return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class

```
```diff
@@ -245,6 +246,16 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
            returning a tuple, the first element is the sample tensor.

        """
+        message = (
+            "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
+            " DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
+        )
+        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+        if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
+            new_config = dict(self.config)
+            new_config["predict_epsilon"] = predict_epsilon
+            self._internal_dict = FrozenDict(new_config)
+
        t = timestep

        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
```
```diff
@@ -260,7 +271,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-        if predict_epsilon:
+        if self.config.predict_epsilon:
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        else:
            pred_original_sample = model_output
```
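For reference, the epsilon branch implements formula (15) of the DDPM paper, with `alpha_prod_t` = ᾱ_t and `beta_prod_t` = 1 − ᾱ_t:

```latex
% Predicted x_0 recovered from the noise prediction (DDPM, Eq. 15):
\hat{x}_0 = \frac{x_t - \sqrt{1 - \bar{\alpha}_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}
```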
```diff
@@ -22,7 +22,8 @@ import flax
 import jax.numpy as jnp
 from jax import random

-from ..configuration_utils import ConfigMixin, register_to_config
+from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
+from ..utils import deprecate
 from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


```
```diff
@@ -97,7 +98,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
            `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
        clip_sample (`bool`, default `True`):
            option to clip predicted sample between -1 and 1 for numerical stability.
-        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+        predict_epsilon (`bool`):
+            optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise.
+
    """

```
```diff
@@ -115,6 +117,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
        trained_betas: Optional[jnp.ndarray] = None,
        variance_type: str = "fixed_small",
        clip_sample: bool = True,
+        predict_epsilon: bool = True,
    ):
        if trained_betas is not None:
            self.betas = jnp.asarray(trained_betas)
```
|
```diff
@@ -196,6 +199,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
        key: random.KeyArray,
        predict_epsilon: bool = True,
        return_dict: bool = True,
+        **kwargs,
    ) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
```
```diff
@@ -208,8 +212,6 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            key (`random.KeyArray`): a PRNG key.
-            predict_epsilon (`bool`):
-                optional flag to use when model predicts the samples directly instead of the noise, epsilon.
            return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class

        Returns:
```
```diff
@@ -217,6 +219,16 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
            `tuple`. When returning a tuple, the first element is the sample tensor.

        """
+        message = (
+            "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
+            " DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
+        )
+        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+        if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
+            new_config = dict(self.config)
+            new_config["predict_epsilon"] = predict_epsilon
+            self._internal_dict = FrozenDict(new_config)
+
        t = timestep

        if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
```
```diff
@@ -232,7 +244,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-        if predict_epsilon:
+        if self.config.predict_epsilon:
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        else:
            pred_original_sample = model_output
```
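The Flax scheduler mirrors the PyTorch behavior: the prediction mode lives in the config, and `step()` consults `self.config.predict_epsilon`. A minimal construction sketch (requires the `flax` extras):

```python
# The Flax scheduler takes the same new init argument and exposes it on
# .config, matching the PyTorch scheduler above.
from diffusers import FlaxDDPMScheduler

scheduler = FlaxDDPMScheduler(num_train_timesteps=1000, predict_epsilon=False)
assert scheduler.config.predict_epsilon is False
```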
```diff
@@ -42,7 +42,6 @@ class CustomLocalPipeline(DiffusionPipeline):
        self,
        batch_size: int = 1,
        generator: Optional[torch.Generator] = None,
-        eta: float = 0.0,
        num_inference_steps: int = 50,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
```
```diff
@@ -89,7 +88,7 @@ class CustomLocalPipeline(DiffusionPipeline):
        # 2. predict previous mean of image x_t-1 and add variance depending on eta
        # eta corresponds to η in paper and should be between [0, 1]
        # do x_t -> x_t-1
-        image = self.scheduler.step(model_output, t, image, eta).prev_sample
+        image = self.scheduler.step(model_output, t, image).prev_sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
```
```diff
@@ -19,6 +19,7 @@ import numpy as np
 import torch

 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
+from diffusers.utils import deprecate
 from diffusers.utils.testing_utils import require_torch, slow, torch_device

 from ...test_pipelines_common import PipelineTesterMixin
```
```diff
@@ -28,8 +29,74 @@ torch.backends.cuda.matmul.allow_tf32 = False


 class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
-    # FIXME: add fast tests
-    pass
+    @property
+    def dummy_uncond_unet(self):
+        torch.manual_seed(0)
+        model = UNet2DModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=3,
+            out_channels=3,
+            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
+            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
+        )
+        return model
+
+    def test_inference(self):
+        unet = self.dummy_uncond_unet
+        scheduler = DDPMScheduler()
+
+        ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
+        ddpm.to(torch_device)
+        ddpm.set_progress_bar_config(disable=None)
+
+        # Warmup pass when using mps (see #372)
+        if torch_device == "mps":
+            _ = ddpm(num_inference_steps=1)
+
+        generator = torch.manual_seed(0)
+        image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
+
+        generator = torch.manual_seed(0)
+        image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
+
+        image_slice = image[0, -3:, -3:, -1]
+        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array(
+            [5.589e-01, 7.089e-01, 2.632e-01, 6.841e-01, 1.000e-04, 9.999e-01, 1.973e-01, 1.000e-04, 8.010e-02]
+        )
+        tolerance = 1e-2 if torch_device != "mps" else 3e-2
+        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
+
+    def test_inference_predict_epsilon(self):
+        deprecate("remove this test", "0.10.0", "remove")
+        unet = self.dummy_uncond_unet
+        scheduler = DDPMScheduler(predict_epsilon=False)
+
+        ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
+        ddpm.to(torch_device)
+        ddpm.set_progress_bar_config(disable=None)
+
+        # Warmup pass when using mps (see #372)
+        if torch_device == "mps":
+            _ = ddpm(num_inference_steps=1)
+
+        generator = torch.manual_seed(0)
+        image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
+
+        generator = torch.manual_seed(0)
+        image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
+
+        image_slice = image[0, -3:, -3:, -1]
+        image_eps_slice = image_eps[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 32, 32, 3)
+        tolerance = 1e-2 if torch_device != "mps" else 3e-2
+        assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance


 @slow
```
```diff
@@ -21,6 +21,7 @@ import unittest
 import diffusers
 from diffusers import (
     DDIMScheduler,
+    DDPMScheduler,
     DPMSolverMultistepScheduler,
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
```
```diff
@@ -291,6 +292,29 @@ class ConfigTester(unittest.TestCase):
        # no warning should be thrown
        assert cap_logger.out == ""

+    def test_overwrite_config_on_load(self):
+        logger = logging.get_logger("diffusers.configuration_utils")
+
+        with CaptureLogger(logger) as cap_logger:
+            ddpm = DDPMScheduler.from_config(
+                "hf-internal-testing/tiny-stable-diffusion-torch",
+                subfolder="scheduler",
+                predict_epsilon=False,
+                beta_end=8,
+            )
+
+        with CaptureLogger(logger) as cap_logger_2:
+            ddpm_2 = DDPMScheduler.from_config("google/ddpm-celebahq-256", beta_start=88)
+
+        assert ddpm.__class__ == DDPMScheduler
+        assert ddpm.config.predict_epsilon is False
+        assert ddpm.config.beta_end == 8
+        assert ddpm_2.config.beta_start == 88
+
+        # no warning should be thrown
+        assert cap_logger.out == ""
+        assert cap_logger_2.out == ""
+
    def test_load_dpmsolver(self):
        logger = logging.get_logger("diffusers.configuration_utils")
```
```diff
@@ -107,6 +107,7 @@ class CustomPipelineTests(unittest.TestCase):
        images, output_str = pipeline(num_inference_steps=2, output_type="np")

        assert images[0].shape == (1, 32, 32, 3)

        # compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102
        assert output_str == "This is a test"
```
```diff
@@ -33,7 +33,7 @@ from diffusers import (
    ScoreSdeVeScheduler,
    VQDiffusionScheduler,
 )
-from diffusers.utils import torch_device
+from diffusers.utils import deprecate, torch_device


 torch.backends.cuda.matmul.allow_tf32 = False
```
```diff
@@ -393,6 +393,34 @@ class DDPMSchedulerTest(SchedulerCommonTest):
        for clip_sample in [True, False]:
            self.check_over_configs(clip_sample=clip_sample)

+    def test_predict_epsilon(self):
+        for predict_epsilon in [True, False]:
+            self.check_over_configs(predict_epsilon=predict_epsilon)
+
+    def test_deprecated_epsilon(self):
+        deprecate("remove this test", "0.10.0", "remove")
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+
+        sample = self.dummy_sample_deter
+        residual = 0.1 * self.dummy_sample_deter
+        time_step = 4
+
+        scheduler = scheduler_class(**scheduler_config)
+        scheduler_eps = scheduler_class(predict_epsilon=False, **scheduler_config)
+
+        kwargs = {}
+        if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
+            kwargs["generator"] = torch.Generator().manual_seed(0)
+        output = scheduler.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample
+
+        kwargs = {}
+        if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
+            kwargs["generator"] = torch.Generator().manual_seed(0)
+        output_eps = scheduler_eps.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample
+
+        assert (output - output_eps).abs().sum() < 1e-5
+
    def test_time_indices(self):
        for t in [0, 500, 999]:
            self.check_over_forward(time_step=t)
```