Deprecate `predict_epsilon` (#1393)

* Adapt ddpm, ddpmsolver to prediction_type.

* Deprecate predict_epsilon in __init__.

* Bring FlaxDDIMScheduler up to date with DDIMScheduler.

* Set prediction_type as an ivar for consistency.

* Convert pipeline_ddpm

* Adapt tests.

* Adapt unconditional training script.

* Adapt BitDiffusion example.

* Add missing kwargs in dpmsolver_multistep

* Ugly workaround to accept deprecated predict_epsilon when loading
schedulers using from_pretrained.

* make style

* Remove import no longer in use.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Use config.prediction_type everywhere

* Add a couple of Flax prediction type tests.

* make style

* fix register deprecated arg

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Pedro Cuenca 2022-11-25 14:02:15 +01:00 committed by GitHub
parent babfb8a020
commit d52388f486
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 260 additions and 87 deletions

View File

@ -138,7 +138,7 @@ def ddpm_bit_scheduler_step(
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
predict_epsilon=True,
prediction_type="epsilon",
generator=None,
return_dict: bool = True,
) -> Union[DDPMSchedulerOutput, Tuple]:
@ -150,8 +150,8 @@ def ddpm_bit_scheduler_step(
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
predict_epsilon (`bool`):
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
prediction_type (`str`, default `epsilon`):
indicates whether the model predicts the noise (epsilon), or the samples (`sample`).
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
Returns:
@ -174,10 +174,12 @@ def ddpm_bit_scheduler_step(
# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if predict_epsilon:
if prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
else:
elif prediction_type == "sample":
pred_original_sample = model_output
else:
raise ValueError(f"Unsupported prediction_type {prediction_type}.")
# 3. Clip "predicted x_0"
scale = self.bit_scale

View File

@ -194,9 +194,10 @@ def parse_args():
)
parser.add_argument(
"--predict_epsilon",
action="store_true",
default=True,
"--prediction_type",
type=str,
default="epsilon",
choices=["epsilon", "sample"],
help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
)
@ -256,13 +257,13 @@ def main(args):
"UpBlock2D",
),
)
accepts_predict_epsilon = "predict_epsilon" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
if accepts_predict_epsilon:
if accepts_prediction_type:
noise_scheduler = DDPMScheduler(
num_train_timesteps=args.ddpm_num_steps,
beta_schedule=args.ddpm_beta_schedule,
predict_epsilon=args.predict_epsilon,
prediction_type=args.prediction_type,
)
else:
noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
@ -365,9 +366,9 @@ def main(args):
# Predict the noise residual
model_output = model(noisy_images, timesteps).sample
if args.predict_epsilon:
if args.prediction_type == "epsilon":
loss = F.mse_loss(model_output, noise) # this could have different weights!
else:
elif args.prediction_type == "sample":
alpha_t = _extract_into_tensor(
noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
)
@ -376,6 +377,8 @@ def main(args):
model_output, clean_images, reduction="none"
) # use SNR weighting from distillation paper
loss = loss.mean()
else:
raise ValueError(f"Unsupported prediction type: {args.prediction_type}")
accelerator.backward(loss)

View File

@ -195,6 +195,11 @@ class ConfigMixin:
if "dtype" in unused_kwargs:
init_dict["dtype"] = unused_kwargs.pop("dtype")
if "predict_epsilon" in unused_kwargs and "prediction_type" not in init_dict:
deprecate("remove this", "0.10.0", "remove")
predict_epsilon = unused_kwargs.pop("predict_epsilon")
init_dict["prediction_type"] = "epsilon" if predict_epsilon else "sample"
# Return model and optionally state and/or unused_kwargs
model = cls(**init_dict)

View File

@ -89,6 +89,7 @@ class ValueGuidedRLPipeline(DiffusionPipeline):
x = x + scale * grad
x = self.reset_x0(x, conditions, self.action_dim)
prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
# TODO: set prediction_type when instantiating the model
x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
# apply conditions to the trajectory

View File

@ -70,14 +70,14 @@ class DDPMPipeline(DiffusionPipeline):
generated images.
"""
message = (
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None:
new_config = dict(self.scheduler.config)
new_config["predict_epsilon"] = predict_epsilon
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
self.scheduler._internal_dict = FrozenDict(new_config)
if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
@ -114,9 +114,7 @@ class DDPMPipeline(DiffusionPipeline):
model_output = self.unet(image, t).sample
# 2. compute previous image: x_t -> x_t-1
image = self.scheduler.step(
model_output, t, image, generator=generator, predict_epsilon=predict_epsilon
).prev_sample
image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()

View File

@ -23,7 +23,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin
@ -106,6 +106,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
prediction_type (`str`, default `epsilon`):
indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
`v-prediction` is not supported for this scheduler.
"""
@ -123,7 +126,16 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
**kwargs,
):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear":
@ -139,8 +151,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.prediction_type = prediction_type
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@ -261,17 +271,17 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
if self.prediction_type == "epsilon":
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.prediction_type == "sample":
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
elif self.prediction_type == "v_prediction":
elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
# predict V
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(
f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or"
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction`"
)

View File

@ -23,6 +23,7 @@ import flax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
FlaxSchedulerMixin,
@ -108,6 +109,10 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
prediction_type (`str`, default `epsilon`):
indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
`v-prediction` is not supported for this scheduler.
"""
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@ -125,7 +130,17 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
beta_schedule: str = "linear",
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
**kwargs,
):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
@ -259,7 +274,19 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
# predict V
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction`"
)
# 4. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 α_t1)/(1 α_t)) * sqrt(1 α_t/α_t1)

View File

@ -99,9 +99,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
predict_epsilon (`bool`):
optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise.
prediction_type (`str`, default `epsilon`):
indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
`v-prediction` is not supported for this scheduler.
"""
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@ -116,8 +116,17 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
trained_betas: Optional[np.ndarray] = None,
variance_type: str = "fixed_small",
clip_sample: bool = True,
predict_epsilon: bool = True,
prediction_type: str = "epsilon",
**kwargs,
):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear":
@ -241,13 +250,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"""
message = (
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
if predict_epsilon is not None:
new_config = dict(self.config)
new_config["predict_epsilon"] = predict_epsilon
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
self._internal_dict = FrozenDict(new_config)
t = timestep
@ -265,10 +274,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if self.config.predict_epsilon:
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
else:
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
" for the DDPMScheduler."
)
# 3. Clip "predicted x_0"
if self.config.clip_sample:

View File

@ -103,9 +103,9 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
predict_epsilon (`bool`):
optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise.
prediction_type (`str`, default `epsilon`):
indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
`v-prediction` is not supported for this scheduler.
"""
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@ -124,8 +124,17 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
trained_betas: Optional[jnp.ndarray] = None,
variance_type: str = "fixed_small",
clip_sample: bool = True,
predict_epsilon: bool = True,
prediction_type: str = "epsilon",
**kwargs,
):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
elif beta_schedule == "linear":
@ -204,7 +213,6 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
timestep: int,
sample: jnp.ndarray,
key: random.KeyArray,
predict_epsilon: bool = True,
return_dict: bool = True,
**kwargs,
) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
@ -227,13 +235,13 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
message = (
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
if predict_epsilon is not None:
new_config = dict(self.config)
new_config["predict_epsilon"] = predict_epsilon
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
self._internal_dict = FrozenDict(new_config)
t = timestep
@ -251,10 +259,15 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if self.config.predict_epsilon:
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
else:
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
" for the FlaxDDPMScheduler."
)
# 3. Clip "predicted x_0"
if self.config.clip_sample:

View File

@ -21,7 +21,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, deprecate
from .scheduling_utils import SchedulerMixin, SchedulerOutput
@ -87,10 +87,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
solver_order (`int`, default `2`):
the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling.
predict_epsilon (`bool`, default `True`):
we currently support both the noise prediction model and the data prediction model. If the model predicts
the noise / epsilon, set `predict_epsilon` to `True`. If the model predicts the data / x0 directly, set
`predict_epsilon` to `False`.
prediction_type (`str`, default `epsilon`):
indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`.
`v-prediction` is not supported for this scheduler.
thresholding (`bool`, default `False`):
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
@ -128,14 +127,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
solver_order: int = 2,
predict_epsilon: bool = True,
prediction_type: str = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
**kwargs,
):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear":
@ -221,11 +229,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
# DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type == "dpmsolver++":
if self.config.predict_epsilon:
if self.config.prediction_type == "epsilon":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t
else:
elif self.config.prediction_type == "sample":
x0_pred = model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
" for the DPMSolverMultistepScheduler."
)
if self.config.thresholding:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = torch.quantile(
@ -239,12 +253,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return x0_pred
# DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type == "dpmsolver":
if self.config.predict_epsilon:
if self.config.prediction_type == "epsilon":
return model_output
else:
elif self.config.prediction_type == "sample":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = (sample - alpha_t * model_output) / sigma_t
return epsilon
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
" for the DPMSolverMultistepScheduler."
)
def dpm_solver_first_order_update(
self,

View File

@ -23,6 +23,7 @@ import jax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
FlaxSchedulerMixin,
@ -118,10 +119,9 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
solver_order (`int`, default `2`):
the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling.
predict_epsilon (`bool`, default `True`):
we currently support both the noise prediction model and the data prediction model. If the model predicts
the noise / epsilon, set `predict_epsilon` to `True`. If the model predicts the data / x0 directly, set
`predict_epsilon` to `False`.
prediction_type (`str`, default `epsilon`):
indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`.
`v-prediction` is not supported for this scheduler.
thresholding (`bool`, default `False`):
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
@ -163,14 +163,23 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
beta_schedule: str = "linear",
trained_betas: Optional[jnp.ndarray] = None,
solver_order: int = 2,
predict_epsilon: bool = True,
prediction_type: str = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
**kwargs,
):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
elif beta_schedule == "linear":
@ -260,11 +269,17 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
# DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type == "dpmsolver++":
if self.config.predict_epsilon:
if self.config.prediction_type == "epsilon":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t
else:
elif self.config.prediction_type == "sample":
x0_pred = model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
" for the FlaxDPMSolverMultistepScheduler."
)
if self.config.thresholding:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = jnp.percentile(
@ -277,12 +292,17 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
return x0_pred
# DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type == "dpmsolver":
if self.config.predict_epsilon:
if self.config.prediction_type == "epsilon":
return model_output
else:
elif self.config.prediction_type == "sample":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = (sample - alpha_t * model_output) / sigma_t
return epsilon
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
" for the FlaxDPMSolverMultistepScheduler."
)
def dpm_solver_first_order_update(
self, model_output: jnp.ndarray, timestep: int, prev_timestep: int, sample: jnp.ndarray

View File

@ -92,8 +92,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.prediction_type = prediction_type
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@ -232,14 +230,14 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.prediction_type == "epsilon":
if self.config.prediction_type == "epsilon":
pred_original_sample = sample - sigma_hat * model_output
elif self.prediction_type == "v_prediction":
elif self.config.prediction_type == "v_prediction":
# * c_out + input * c_skip
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
else:
raise ValueError(
f"prediction_type given as {self.prediction_type} must be one of `epsilon`, or `v_prediction`"
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
)
# 2. Convert to an ODE derivative

View File

@ -68,7 +68,7 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_inference_predict_epsilon(self):
def test_inference_deprecated_predict_epsilon(self):
deprecate("remove this test", "0.10.0", "remove")
unet = self.dummy_uncond_unet
scheduler = DDPMScheduler(predict_epsilon=False)
@ -98,6 +98,35 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
tolerance = 1e-2 if torch_device != "mps" else 3e-2
assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance
def test_inference_predict_sample(self):
unet = self.dummy_uncond_unet
scheduler = DDPMScheduler(prediction_type="sample")
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)
# Warmup pass when using mps (see #372)
if torch_device == "mps":
_ = ddpm(num_inference_steps=1)
if torch_device == "mps":
# device type MPS is not supported for torch.Generator() api.
generator = torch.manual_seed(0)
else:
generator = torch.Generator(device=torch_device).manual_seed(0)
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
generator = generator.manual_seed(0)
image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy")[0]
image_slice = image[0, -3:, -3:, -1]
image_eps_slice = image_eps[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
tolerance = 1e-2 if torch_device != "mps" else 3e-2
assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance
@slow
@require_torch_gpu

View File

@ -26,6 +26,7 @@ from diffusers import (
logging,
)
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import deprecate
from diffusers.utils.testing_utils import CaptureLogger
@ -194,17 +195,27 @@ class ConfigTester(unittest.TestCase):
ddpm = DDPMScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch",
subfolder="scheduler",
predict_epsilon=False,
prediction_type="sample",
beta_end=8,
)
with CaptureLogger(logger) as cap_logger_2:
ddpm_2 = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88)
with CaptureLogger(logger) as cap_logger:
deprecate("remove this case", "0.10.0", "remove")
ddpm_3 = DDPMScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch",
subfolder="scheduler",
predict_epsilon=False,
beta_end=8,
)
assert ddpm.__class__ == DDPMScheduler
assert ddpm.config.predict_epsilon is False
assert ddpm.config.prediction_type == "sample"
assert ddpm.config.beta_end == 8
assert ddpm_2.config.beta_start == 88
assert ddpm_3.config.prediction_type == "sample"
# no warning should be thrown
assert cap_logger.out == ""

View File

@ -20,7 +20,6 @@ import random
import shutil
import tempfile
import unittest
from functools import partial
import numpy as np
import torch
@ -332,14 +331,13 @@ class PipelineFastTests(unittest.TestCase):
@parameterized.expand(
[
[DDIMScheduler, DDIMPipeline, 32],
[partial(DDPMScheduler, predict_epsilon=True), DDPMPipeline, 32],
[DDPMScheduler, DDPMPipeline, 32],
[DDIMScheduler, DDIMPipeline, (32, 64)],
[partial(DDPMScheduler, predict_epsilon=True), DDPMPipeline, (64, 32)],
[DDPMScheduler, DDPMPipeline, (64, 32)],
]
)
def test_uncond_unet_components(self, scheduler_fn=DDPMScheduler, pipeline_fn=DDPMPipeline, sample_size=32):
unet = self.dummy_uncond_unet(sample_size)
# DDIM doesn't take `predict_epsilon`, and DDPM requires it -- so using partial in parameterized decorator
scheduler = scheduler_fn()
pipeline = pipeline_fn(unet, scheduler).to(torch_device)

View File

@ -599,7 +599,12 @@ class DDPMSchedulerTest(SchedulerCommonTest):
for clip_sample in [True, False]:
self.check_over_configs(clip_sample=clip_sample)
def test_predict_epsilon(self):
def test_prediction_type(self):
for prediction_type in ["epsilon", "sample"]:
self.check_over_configs(prediction_type=prediction_type)
def test_deprecated_predict_epsilon(self):
deprecate("remove this test", "0.10.0", "remove")
for predict_epsilon in [True, False]:
self.check_over_configs(predict_epsilon=predict_epsilon)
@ -795,7 +800,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
"beta_end": 0.02,
"beta_schedule": "linear",
"solver_order": 2,
"predict_epsilon": True,
"prediction_type": "epsilon",
"thresholding": False,
"sample_max_value": 1.0,
"algorithm_type": "dpmsolver++",
@ -921,10 +926,10 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
for order in [1, 2, 3]:
for solver_type in ["midpoint", "heun"]:
for threshold in [0.5, 1.0, 2.0]:
for predict_epsilon in [True, False]:
for prediction_type in ["epsilon", "sample"]:
self.check_over_configs(
thresholding=True,
predict_epsilon=predict_epsilon,
prediction_type=prediction_type,
sample_max_value=threshold,
algorithm_type="dpmsolver++",
solver_order=order,
@ -935,17 +940,17 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
for algorithm_type in ["dpmsolver", "dpmsolver++"]:
for solver_type in ["midpoint", "heun"]:
for order in [1, 2, 3]:
for predict_epsilon in [True, False]:
for prediction_type in ["epsilon", "sample"]:
self.check_over_configs(
solver_order=order,
solver_type=solver_type,
predict_epsilon=predict_epsilon,
prediction_type=prediction_type,
algorithm_type=algorithm_type,
)
sample = self.full_loop(
solver_order=order,
solver_type=solver_type,
predict_epsilon=predict_epsilon,
prediction_type=prediction_type,
algorithm_type=algorithm_type,
)
assert not torch.isnan(sample).any(), "Samples have nan numbers"

View File

@ -17,7 +17,7 @@ import unittest
from typing import Dict, List, Tuple
from diffusers import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler
from diffusers.utils import is_flax_available
from diffusers.utils import deprecate, is_flax_available
from diffusers.utils.testing_utils import require_flax
@ -599,6 +599,26 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
assert abs(result_sum - 149.0784) < 1e-2
assert abs(result_mean - 0.1941) < 1e-3
def test_prediction_type(self):
for prediction_type in ["epsilon", "sample", "v_prediction"]:
self.check_over_configs(prediction_type=prediction_type)
def test_deprecated_predict_epsilon(self):
deprecate("remove this test", "0.10.0", "remove")
for predict_epsilon in [True, False]:
self.check_over_configs(predict_epsilon=predict_epsilon)
def test_deprecated_predict_epsilon_to_prediction_type(self):
deprecate("remove this test", "0.10.0", "remove")
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config(predict_epsilon=True)
scheduler = scheduler_class.from_config(scheduler_config)
assert scheduler.prediction_type == "epsilon"
scheduler_config = self.get_scheduler_config(predict_epsilon=False)
scheduler = scheduler_class.from_config(scheduler_config)
assert scheduler.prediction_type == "sample"
@require_flax
class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):