Deprecate `predict_epsilon` (#1393)
* Adapt ddpm, ddpmsolver to prediction_type. * Deprecate predict_epsilon in __init__. * Bring FlaxDDIMScheduler up to date with DDIMScheduler. * Set prediction_type as an ivar for consistency. * Convert pipeline_ddpm * Adapt tests. * Adapt unconditional training script. * Adapt BitDiffusion example. * Add missing kwargs in dpmsolver_multistep * Ugly workaround to accept deprecated predict_epsilon when loading schedulers using from_pretrained. * make style * Remove import no longer in use. * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Use config.prediction_type everywhere * Add a couple of Flax prediction type tests. * make style * fix register deprecated arg Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
babfb8a020
commit
d52388f486
|
@ -138,7 +138,7 @@ def ddpm_bit_scheduler_step(
|
|||
model_output: torch.FloatTensor,
|
||||
timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
predict_epsilon=True,
|
||||
prediction_type="epsilon",
|
||||
generator=None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[DDPMSchedulerOutput, Tuple]:
|
||||
|
@ -150,8 +150,8 @@ def ddpm_bit_scheduler_step(
|
|||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
predict_epsilon (`bool`):
|
||||
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
|
||||
prediction_type (`str`, default `epsilon`):
|
||||
indicates whether the model predicts the noise (epsilon), or the samples (`sample`).
|
||||
generator: random number generator.
|
||||
return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
|
||||
Returns:
|
||||
|
@ -174,10 +174,12 @@ def ddpm_bit_scheduler_step(
|
|||
|
||||
# 2. compute predicted original sample from predicted noise also called
|
||||
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
|
||||
if predict_epsilon:
|
||||
if prediction_type == "epsilon":
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||
else:
|
||||
elif prediction_type == "sample":
|
||||
pred_original_sample = model_output
|
||||
else:
|
||||
raise ValueError(f"Unsupported prediction_type {prediction_type}.")
|
||||
|
||||
# 3. Clip "predicted x_0"
|
||||
scale = self.bit_scale
|
||||
|
|
|
@ -194,9 +194,10 @@ def parse_args():
|
|||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--predict_epsilon",
|
||||
action="store_true",
|
||||
default=True,
|
||||
"--prediction_type",
|
||||
type=str,
|
||||
default="epsilon",
|
||||
choices=["epsilon", "sample"],
|
||||
help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
|
||||
)
|
||||
|
||||
|
@ -256,13 +257,13 @@ def main(args):
|
|||
"UpBlock2D",
|
||||
),
|
||||
)
|
||||
accepts_predict_epsilon = "predict_epsilon" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
|
||||
accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
|
||||
|
||||
if accepts_predict_epsilon:
|
||||
if accepts_prediction_type:
|
||||
noise_scheduler = DDPMScheduler(
|
||||
num_train_timesteps=args.ddpm_num_steps,
|
||||
beta_schedule=args.ddpm_beta_schedule,
|
||||
predict_epsilon=args.predict_epsilon,
|
||||
prediction_type=args.prediction_type,
|
||||
)
|
||||
else:
|
||||
noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
|
||||
|
@ -365,9 +366,9 @@ def main(args):
|
|||
# Predict the noise residual
|
||||
model_output = model(noisy_images, timesteps).sample
|
||||
|
||||
if args.predict_epsilon:
|
||||
if args.prediction_type == "epsilon":
|
||||
loss = F.mse_loss(model_output, noise) # this could have different weights!
|
||||
else:
|
||||
elif args.prediction_type == "sample":
|
||||
alpha_t = _extract_into_tensor(
|
||||
noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
|
||||
)
|
||||
|
@ -376,6 +377,8 @@ def main(args):
|
|||
model_output, clean_images, reduction="none"
|
||||
) # use SNR weighting from distillation paper
|
||||
loss = loss.mean()
|
||||
else:
|
||||
raise ValueError(f"Unsupported prediction type: {args.prediction_type}")
|
||||
|
||||
accelerator.backward(loss)
|
||||
|
||||
|
|
|
@ -195,6 +195,11 @@ class ConfigMixin:
|
|||
if "dtype" in unused_kwargs:
|
||||
init_dict["dtype"] = unused_kwargs.pop("dtype")
|
||||
|
||||
if "predict_epsilon" in unused_kwargs and "prediction_type" not in init_dict:
|
||||
deprecate("remove this", "0.10.0", "remove")
|
||||
predict_epsilon = unused_kwargs.pop("predict_epsilon")
|
||||
init_dict["prediction_type"] = "epsilon" if predict_epsilon else "sample"
|
||||
|
||||
# Return model and optionally state and/or unused_kwargs
|
||||
model = cls(**init_dict)
|
||||
|
||||
|
|
|
@ -89,6 +89,7 @@ class ValueGuidedRLPipeline(DiffusionPipeline):
|
|||
x = x + scale * grad
|
||||
x = self.reset_x0(x, conditions, self.action_dim)
|
||||
prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
|
||||
# TODO: set prediction_type when instantiating the model
|
||||
x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
|
||||
|
||||
# apply conditions to the trajectory
|
||||
|
|
|
@ -70,14 +70,14 @@ class DDPMPipeline(DiffusionPipeline):
|
|||
generated images.
|
||||
"""
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
|
||||
" DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
||||
|
||||
if predict_epsilon is not None:
|
||||
new_config = dict(self.scheduler.config)
|
||||
new_config["predict_epsilon"] = predict_epsilon
|
||||
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
|
||||
self.scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
|
||||
|
@ -114,9 +114,7 @@ class DDPMPipeline(DiffusionPipeline):
|
|||
model_output = self.unet(image, t).sample
|
||||
|
||||
# 2. compute previous image: x_t -> x_t-1
|
||||
image = self.scheduler.step(
|
||||
model_output, t, image, generator=generator, predict_epsilon=predict_epsilon
|
||||
).prev_sample
|
||||
image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
|
|
@ -23,7 +23,7 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
|
||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
|
@ -106,6 +106,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||
an offset added to the inference steps. You can use a combination of `offset=1` and
|
||||
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
|
||||
stable diffusion.
|
||||
prediction_type (`str`, default `epsilon`):
|
||||
indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
|
||||
`v-prediction` is not supported for this scheduler.
|
||||
|
||||
"""
|
||||
|
||||
|
@ -123,7 +126,16 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||
set_alpha_to_one: bool = True,
|
||||
steps_offset: int = 0,
|
||||
prediction_type: str = "epsilon",
|
||||
**kwargs,
|
||||
):
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
" DDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None:
|
||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
elif beta_schedule == "linear":
|
||||
|
@ -139,8 +151,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
||||
|
||||
self.prediction_type = prediction_type
|
||||
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
||||
|
||||
|
@ -261,17 +271,17 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||
|
||||
# 3. compute predicted original sample from predicted noise also called
|
||||
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
if self.prediction_type == "epsilon":
|
||||
if self.config.prediction_type == "epsilon":
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||
elif self.prediction_type == "sample":
|
||||
elif self.config.prediction_type == "sample":
|
||||
pred_original_sample = model_output
|
||||
elif self.prediction_type == "v_prediction":
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
|
||||
# predict V
|
||||
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or"
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
||||
" `v_prediction`"
|
||||
)
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ import flax
|
|||
import jax.numpy as jnp
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate
|
||||
from .scheduling_utils_flax import (
|
||||
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
|
||||
FlaxSchedulerMixin,
|
||||
|
@ -108,6 +109,10 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|||
an offset added to the inference steps. You can use a combination of `offset=1` and
|
||||
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
|
||||
stable diffusion.
|
||||
prediction_type (`str`, default `epsilon`):
|
||||
indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
|
||||
`v-prediction` is not supported for this scheduler.
|
||||
|
||||
"""
|
||||
|
||||
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
|
@ -125,7 +130,17 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|||
beta_schedule: str = "linear",
|
||||
set_alpha_to_one: bool = True,
|
||||
steps_offset: int = 0,
|
||||
prediction_type: str = "epsilon",
|
||||
**kwargs,
|
||||
):
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
" FlaxDDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None:
|
||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||
|
||||
if beta_schedule == "linear":
|
||||
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
|
@ -259,7 +274,19 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|||
|
||||
# 3. compute predicted original sample from predicted noise also called
|
||||
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||
if self.config.prediction_type == "epsilon":
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||
elif self.config.prediction_type == "sample":
|
||||
pred_original_sample = model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
|
||||
# predict V
|
||||
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
||||
" `v_prediction`"
|
||||
)
|
||||
|
||||
# 4. compute variance: "sigma_t(η)" -> see formula (16)
|
||||
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
||||
|
|
|
@ -99,9 +99,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
|
||||
clip_sample (`bool`, default `True`):
|
||||
option to clip predicted sample between -1 and 1 for numerical stability.
|
||||
predict_epsilon (`bool`):
|
||||
optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise.
|
||||
|
||||
prediction_type (`str`, default `epsilon`):
|
||||
indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
|
||||
`v-prediction` is not supported for this scheduler.
|
||||
"""
|
||||
|
||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
|
@ -116,8 +116,17 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||
trained_betas: Optional[np.ndarray] = None,
|
||||
variance_type: str = "fixed_small",
|
||||
clip_sample: bool = True,
|
||||
predict_epsilon: bool = True,
|
||||
prediction_type: str = "epsilon",
|
||||
**kwargs,
|
||||
):
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None:
|
||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
elif beta_schedule == "linear":
|
||||
|
@ -241,13 +250,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||
|
||||
"""
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
|
||||
" DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
|
||||
if predict_epsilon is not None:
|
||||
new_config = dict(self.config)
|
||||
new_config["predict_epsilon"] = predict_epsilon
|
||||
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
|
||||
self._internal_dict = FrozenDict(new_config)
|
||||
|
||||
t = timestep
|
||||
|
@ -265,10 +274,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||
|
||||
# 2. compute predicted original sample from predicted noise also called
|
||||
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
|
||||
if self.config.predict_epsilon:
|
||||
if self.config.prediction_type == "epsilon":
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||
else:
|
||||
elif self.config.prediction_type == "sample":
|
||||
pred_original_sample = model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
|
||||
" for the DDPMScheduler."
|
||||
)
|
||||
|
||||
# 3. Clip "predicted x_0"
|
||||
if self.config.clip_sample:
|
||||
|
|
|
@ -103,9 +103,9 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|||
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
|
||||
clip_sample (`bool`, default `True`):
|
||||
option to clip predicted sample between -1 and 1 for numerical stability.
|
||||
predict_epsilon (`bool`):
|
||||
optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise.
|
||||
|
||||
prediction_type (`str`, default `epsilon`):
|
||||
indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
|
||||
`v-prediction` is not supported for this scheduler.
|
||||
"""
|
||||
|
||||
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
|
@ -124,8 +124,17 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|||
trained_betas: Optional[jnp.ndarray] = None,
|
||||
variance_type: str = "fixed_small",
|
||||
clip_sample: bool = True,
|
||||
predict_epsilon: bool = True,
|
||||
prediction_type: str = "epsilon",
|
||||
**kwargs,
|
||||
):
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None:
|
||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = jnp.asarray(trained_betas)
|
||||
elif beta_schedule == "linear":
|
||||
|
@ -204,7 +213,6 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|||
timestep: int,
|
||||
sample: jnp.ndarray,
|
||||
key: random.KeyArray,
|
||||
predict_epsilon: bool = True,
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
|
||||
|
@ -227,13 +235,13 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|||
|
||||
"""
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
|
||||
" DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
|
||||
if predict_epsilon is not None:
|
||||
new_config = dict(self.config)
|
||||
new_config["predict_epsilon"] = predict_epsilon
|
||||
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
|
||||
self._internal_dict = FrozenDict(new_config)
|
||||
|
||||
t = timestep
|
||||
|
@ -251,10 +259,15 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|||
|
||||
# 2. compute predicted original sample from predicted noise also called
|
||||
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
|
||||
if self.config.predict_epsilon:
|
||||
if self.config.prediction_type == "epsilon":
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||
else:
|
||||
elif self.config.prediction_type == "sample":
|
||||
pred_original_sample = model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
|
||||
" for the FlaxDDPMScheduler."
|
||||
)
|
||||
|
||||
# 3. Clip "predicted x_0"
|
||||
if self.config.clip_sample:
|
||||
|
|
|
@ -21,7 +21,7 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
|
||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, deprecate
|
||||
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
|
@ -87,10 +87,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||
solver_order (`int`, default `2`):
|
||||
the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
|
||||
sampling, and `solver_order=3` for unconditional sampling.
|
||||
predict_epsilon (`bool`, default `True`):
|
||||
we currently support both the noise prediction model and the data prediction model. If the model predicts
|
||||
the noise / epsilon, set `predict_epsilon` to `True`. If the model predicts the data / x0 directly, set
|
||||
`predict_epsilon` to `False`.
|
||||
prediction_type (`str`, default `epsilon`):
|
||||
indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`.
|
||||
`v-prediction` is not supported for this scheduler.
|
||||
thresholding (`bool`, default `False`):
|
||||
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
|
||||
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
|
||||
|
@ -128,14 +127,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[np.ndarray] = None,
|
||||
solver_order: int = 2,
|
||||
predict_epsilon: bool = True,
|
||||
prediction_type: str = "epsilon",
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
sample_max_value: float = 1.0,
|
||||
algorithm_type: str = "dpmsolver++",
|
||||
solver_type: str = "midpoint",
|
||||
lower_order_final: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
" DPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None:
|
||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
elif beta_schedule == "linear":
|
||||
|
@ -221,11 +229,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||
"""
|
||||
# DPM-Solver++ needs to solve an integral of the data prediction model.
|
||||
if self.config.algorithm_type == "dpmsolver++":
|
||||
if self.config.predict_epsilon:
|
||||
if self.config.prediction_type == "epsilon":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
||||
else:
|
||||
elif self.config.prediction_type == "sample":
|
||||
x0_pred = model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
|
||||
" for the DPMSolverMultistepScheduler."
|
||||
)
|
||||
|
||||
if self.config.thresholding:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dynamic_max_val = torch.quantile(
|
||||
|
@ -239,12 +253,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||
return x0_pred
|
||||
# DPM-Solver needs to solve an integral of the noise prediction model.
|
||||
elif self.config.algorithm_type == "dpmsolver":
|
||||
if self.config.predict_epsilon:
|
||||
if self.config.prediction_type == "epsilon":
|
||||
return model_output
|
||||
else:
|
||||
elif self.config.prediction_type == "sample":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = (sample - alpha_t * model_output) / sigma_t
|
||||
return epsilon
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
|
||||
" for the DPMSolverMultistepScheduler."
|
||||
)
|
||||
|
||||
def dpm_solver_first_order_update(
|
||||
self,
|
||||
|
|
|
@ -23,6 +23,7 @@ import jax
|
|||
import jax.numpy as jnp
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate
|
||||
from .scheduling_utils_flax import (
|
||||
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
|
||||
FlaxSchedulerMixin,
|
||||
|
@ -118,10 +119,9 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|||
solver_order (`int`, default `2`):
|
||||
the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
|
||||
sampling, and `solver_order=3` for unconditional sampling.
|
||||
predict_epsilon (`bool`, default `True`):
|
||||
we currently support both the noise prediction model and the data prediction model. If the model predicts
|
||||
the noise / epsilon, set `predict_epsilon` to `True`. If the model predicts the data / x0 directly, set
|
||||
`predict_epsilon` to `False`.
|
||||
prediction_type (`str`, default `epsilon`):
|
||||
indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`.
|
||||
`v-prediction` is not supported for this scheduler.
|
||||
thresholding (`bool`, default `False`):
|
||||
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
|
||||
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
|
||||
|
@ -163,14 +163,23 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[jnp.ndarray] = None,
|
||||
solver_order: int = 2,
|
||||
predict_epsilon: bool = True,
|
||||
prediction_type: str = "epsilon",
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
sample_max_value: float = 1.0,
|
||||
algorithm_type: str = "dpmsolver++",
|
||||
solver_type: str = "midpoint",
|
||||
lower_order_final: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
" FlaxDPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None:
|
||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = jnp.asarray(trained_betas)
|
||||
elif beta_schedule == "linear":
|
||||
|
@ -260,11 +269,17 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|||
"""
|
||||
# DPM-Solver++ needs to solve an integral of the data prediction model.
|
||||
if self.config.algorithm_type == "dpmsolver++":
|
||||
if self.config.predict_epsilon:
|
||||
if self.config.prediction_type == "epsilon":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
||||
else:
|
||||
elif self.config.prediction_type == "sample":
|
||||
x0_pred = model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
|
||||
" for the FlaxDPMSolverMultistepScheduler."
|
||||
)
|
||||
|
||||
if self.config.thresholding:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dynamic_max_val = jnp.percentile(
|
||||
|
@ -277,12 +292,17 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|||
return x0_pred
|
||||
# DPM-Solver needs to solve an integral of the noise prediction model.
|
||||
elif self.config.algorithm_type == "dpmsolver":
|
||||
if self.config.predict_epsilon:
|
||||
if self.config.prediction_type == "epsilon":
|
||||
return model_output
|
||||
else:
|
||||
elif self.config.prediction_type == "sample":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = (sample - alpha_t * model_output) / sigma_t
|
||||
return epsilon
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
|
||||
" for the FlaxDPMSolverMultistepScheduler."
|
||||
)
|
||||
|
||||
def dpm_solver_first_order_update(
|
||||
self, model_output: jnp.ndarray, timestep: int, prev_timestep: int, sample: jnp.ndarray
|
||||
|
|
|
@ -92,8 +92,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
||||
|
||||
self.prediction_type = prediction_type
|
||||
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
||||
|
||||
|
@ -232,14 +230,14 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
||||
|
||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||
if self.prediction_type == "epsilon":
|
||||
if self.config.prediction_type == "epsilon":
|
||||
pred_original_sample = sample - sigma_hat * model_output
|
||||
elif self.prediction_type == "v_prediction":
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
# * c_out + input * c_skip
|
||||
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.prediction_type} must be one of `epsilon`, or `v_prediction`"
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
|
||||
)
|
||||
|
||||
# 2. Convert to an ODE derivative
|
||||
|
|
|
@ -68,7 +68,7 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_inference_predict_epsilon(self):
|
||||
def test_inference_deprecated_predict_epsilon(self):
|
||||
deprecate("remove this test", "0.10.0", "remove")
|
||||
unet = self.dummy_uncond_unet
|
||||
scheduler = DDPMScheduler(predict_epsilon=False)
|
||||
|
@ -98,6 +98,35 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||
tolerance = 1e-2 if torch_device != "mps" else 3e-2
|
||||
assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance
|
||||
|
||||
def test_inference_predict_sample(self):
|
||||
unet = self.dummy_uncond_unet
|
||||
scheduler = DDPMScheduler(prediction_type="sample")
|
||||
|
||||
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
|
||||
ddpm.to(torch_device)
|
||||
ddpm.set_progress_bar_config(disable=None)
|
||||
|
||||
# Warmup pass when using mps (see #372)
|
||||
if torch_device == "mps":
|
||||
_ = ddpm(num_inference_steps=1)
|
||||
|
||||
if torch_device == "mps":
|
||||
# device type MPS is not supported for torch.Generator() api.
|
||||
generator = torch.manual_seed(0)
|
||||
else:
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
|
||||
|
||||
generator = generator.manual_seed(0)
|
||||
image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy")[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_eps_slice = image_eps[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
tolerance = 1e-2 if torch_device != "mps" else 3e-2
|
||||
assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
|
|
@ -26,6 +26,7 @@ from diffusers import (
|
|||
logging,
|
||||
)
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.utils import deprecate
|
||||
from diffusers.utils.testing_utils import CaptureLogger
|
||||
|
||||
|
||||
|
@ -194,17 +195,27 @@ class ConfigTester(unittest.TestCase):
|
|||
ddpm = DDPMScheduler.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch",
|
||||
subfolder="scheduler",
|
||||
predict_epsilon=False,
|
||||
prediction_type="sample",
|
||||
beta_end=8,
|
||||
)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger_2:
|
||||
ddpm_2 = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
deprecate("remove this case", "0.10.0", "remove")
|
||||
ddpm_3 = DDPMScheduler.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch",
|
||||
subfolder="scheduler",
|
||||
predict_epsilon=False,
|
||||
beta_end=8,
|
||||
)
|
||||
|
||||
assert ddpm.__class__ == DDPMScheduler
|
||||
assert ddpm.config.predict_epsilon is False
|
||||
assert ddpm.config.prediction_type == "sample"
|
||||
assert ddpm.config.beta_end == 8
|
||||
assert ddpm_2.config.beta_start == 88
|
||||
assert ddpm_3.config.prediction_type == "sample"
|
||||
|
||||
# no warning should be thrown
|
||||
assert cap_logger.out == ""
|
||||
|
|
|
@ -20,7 +20,6 @@ import random
|
|||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -332,14 +331,13 @@ class PipelineFastTests(unittest.TestCase):
|
|||
@parameterized.expand(
|
||||
[
|
||||
[DDIMScheduler, DDIMPipeline, 32],
|
||||
[partial(DDPMScheduler, predict_epsilon=True), DDPMPipeline, 32],
|
||||
[DDPMScheduler, DDPMPipeline, 32],
|
||||
[DDIMScheduler, DDIMPipeline, (32, 64)],
|
||||
[partial(DDPMScheduler, predict_epsilon=True), DDPMPipeline, (64, 32)],
|
||||
[DDPMScheduler, DDPMPipeline, (64, 32)],
|
||||
]
|
||||
)
|
||||
def test_uncond_unet_components(self, scheduler_fn=DDPMScheduler, pipeline_fn=DDPMPipeline, sample_size=32):
|
||||
unet = self.dummy_uncond_unet(sample_size)
|
||||
# DDIM doesn't take `predict_epsilon`, and DDPM requires it -- so using partial in parameterized decorator
|
||||
scheduler = scheduler_fn()
|
||||
pipeline = pipeline_fn(unet, scheduler).to(torch_device)
|
||||
|
||||
|
|
|
@ -599,7 +599,12 @@ class DDPMSchedulerTest(SchedulerCommonTest):
|
|||
for clip_sample in [True, False]:
|
||||
self.check_over_configs(clip_sample=clip_sample)
|
||||
|
||||
def test_predict_epsilon(self):
|
||||
def test_prediction_type(self):
|
||||
for prediction_type in ["epsilon", "sample"]:
|
||||
self.check_over_configs(prediction_type=prediction_type)
|
||||
|
||||
def test_deprecated_predict_epsilon(self):
|
||||
deprecate("remove this test", "0.10.0", "remove")
|
||||
for predict_epsilon in [True, False]:
|
||||
self.check_over_configs(predict_epsilon=predict_epsilon)
|
||||
|
||||
|
@ -795,7 +800,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
|
|||
"beta_end": 0.02,
|
||||
"beta_schedule": "linear",
|
||||
"solver_order": 2,
|
||||
"predict_epsilon": True,
|
||||
"prediction_type": "epsilon",
|
||||
"thresholding": False,
|
||||
"sample_max_value": 1.0,
|
||||
"algorithm_type": "dpmsolver++",
|
||||
|
@ -921,10 +926,10 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
|
|||
for order in [1, 2, 3]:
|
||||
for solver_type in ["midpoint", "heun"]:
|
||||
for threshold in [0.5, 1.0, 2.0]:
|
||||
for predict_epsilon in [True, False]:
|
||||
for prediction_type in ["epsilon", "sample"]:
|
||||
self.check_over_configs(
|
||||
thresholding=True,
|
||||
predict_epsilon=predict_epsilon,
|
||||
prediction_type=prediction_type,
|
||||
sample_max_value=threshold,
|
||||
algorithm_type="dpmsolver++",
|
||||
solver_order=order,
|
||||
|
@ -935,17 +940,17 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
|
|||
for algorithm_type in ["dpmsolver", "dpmsolver++"]:
|
||||
for solver_type in ["midpoint", "heun"]:
|
||||
for order in [1, 2, 3]:
|
||||
for predict_epsilon in [True, False]:
|
||||
for prediction_type in ["epsilon", "sample"]:
|
||||
self.check_over_configs(
|
||||
solver_order=order,
|
||||
solver_type=solver_type,
|
||||
predict_epsilon=predict_epsilon,
|
||||
prediction_type=prediction_type,
|
||||
algorithm_type=algorithm_type,
|
||||
)
|
||||
sample = self.full_loop(
|
||||
solver_order=order,
|
||||
solver_type=solver_type,
|
||||
predict_epsilon=predict_epsilon,
|
||||
prediction_type=prediction_type,
|
||||
algorithm_type=algorithm_type,
|
||||
)
|
||||
assert not torch.isnan(sample).any(), "Samples have nan numbers"
|
||||
|
|
|
@ -17,7 +17,7 @@ import unittest
|
|||
from typing import Dict, List, Tuple
|
||||
|
||||
from diffusers import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler
|
||||
from diffusers.utils import is_flax_available
|
||||
from diffusers.utils import deprecate, is_flax_available
|
||||
from diffusers.utils.testing_utils import require_flax
|
||||
|
||||
|
||||
|
@ -599,6 +599,26 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
|
|||
assert abs(result_sum - 149.0784) < 1e-2
|
||||
assert abs(result_mean - 0.1941) < 1e-3
|
||||
|
||||
def test_prediction_type(self):
|
||||
for prediction_type in ["epsilon", "sample", "v_prediction"]:
|
||||
self.check_over_configs(prediction_type=prediction_type)
|
||||
|
||||
def test_deprecated_predict_epsilon(self):
|
||||
deprecate("remove this test", "0.10.0", "remove")
|
||||
for predict_epsilon in [True, False]:
|
||||
self.check_over_configs(predict_epsilon=predict_epsilon)
|
||||
|
||||
def test_deprecated_predict_epsilon_to_prediction_type(self):
|
||||
deprecate("remove this test", "0.10.0", "remove")
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config(predict_epsilon=True)
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
assert scheduler.prediction_type == "epsilon"
|
||||
|
||||
scheduler_config = self.get_scheduler_config(predict_epsilon=False)
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
assert scheduler.prediction_type == "sample"
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
|
||||
|
|
Loading…
Reference in New Issue