[Bump version] 0.13.0dev0 & Deprecate `predict_epsilon` (#2109)

* [Bump version] 0.13

* Bump model up

* up
Patrick von Platen 2023-01-25 18:59:02 +02:00 committed by GitHub
parent b0cc7c202b
commit 09779cbb40
33 changed files with 28 additions and 212 deletions
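The user-facing change behind most of the 212 deleted lines: schedulers and `DDPMPipeline` stop accepting the boolean `predict_epsilon`, and callers must name the prediction target via the `prediction_type` config argument instead, while the `init_image` deprecation deadline moves from 0.13.0 to 0.14.0. A minimal sketch of the scheduler migration (the model id is only illustrative):

```python
from diffusers import DDPMScheduler

# Before (removed by this commit): DDPMScheduler(predict_epsilon=False)
# After: name the prediction target explicitly.
scheduler = DDPMScheduler(prediction_type="sample")

# The same override works when loading a pretrained config.
scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32", prediction_type="epsilon")
```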


@@ -185,7 +185,7 @@ class ImagicStableDiffusionPipeline(DiffusionPipeline):
                 (nsfw) content, according to the `safety_checker`.
         """
         message = "Please use `image` instead of `init_image`."
-        init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
+        init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
         image = init_image or image
         accelerator = Accelerator(

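For readers unfamiliar with the pattern edited above: `deprecate(..., take_from=kwargs)` pops the legacy keyword out of `**kwargs`, warns that it will be removed in the named version, and returns the value so the old argument keeps working in the meantime. A hypothetical, simplified stand-in (not the real `diffusers.utils.deprecate`) showing the contract these call sites rely on:

```python
import warnings


def deprecate_kwarg(name, removed_in, message, take_from):
    """Hypothetical sketch of the deprecation helper used at these call sites."""
    if name not in take_from:
        return None  # the caller did not pass the legacy kwarg
    warnings.warn(
        f"`{name}` is deprecated and will be removed in version {removed_in}. {message}",
        FutureWarning,
    )
    # Hand the value back so `image = init_image or image` keeps working.
    return take_from.pop(name)
```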

@@ -759,7 +759,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
                 (nsfw) content, according to the `safety_checker`.
         """
         message = "Please use `image` instead of `init_image`."
-        init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
+        init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
         image = init_image or image
         # 0. Default height and width to unet


@@ -745,7 +745,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline):
                 (nsfw) content, according to the `safety_checker`.
         """
         message = "Please use `image` instead of `init_image`."
-        init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
+        init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
         image = init_image or image
         # 0. Default height and width to unet


@@ -46,7 +46,7 @@ from transformers import AutoTokenizer, PretrainedConfig
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.12.0")
+check_min_version("0.13.0.dev0")
 logger = get_logger(__name__)

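Every example script pins a minimum library version; bumping the string to `0.13.0.dev0` means the examples on `main` now require a source install. Roughly what such a guard does, as a sketch (the real `check_min_version` also special-cases dev releases):

```python
from packaging import version

import diffusers


def check_min_version_sketch(min_version: str) -> None:
    # Refuse to run the example against an older installed diffusers.
    if version.parse(diffusers.__version__) < version.parse(min_version):
        raise ImportError(
            f"This example requires diffusers>={min_version}, "
            f"found {diffusers.__version__}; install from source for a dev version."
        )
```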

@@ -36,7 +36,7 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel,
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.12.0")
+check_min_version("0.13.0.dev0")
 # Cache compiled models across invocations of this script.
 cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))


@@ -54,7 +54,7 @@ from transformers import AutoTokenizer, PretrainedConfig
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.12.0")
+check_min_version("0.13.0.dev0")
 logger = get_logger(__name__)


@@ -45,7 +45,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.12.0")
+check_min_version("0.13.0.dev0")
 logger = get_logger(__name__, log_level="INFO")


@@ -34,7 +34,7 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel,
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.12.0")
+check_min_version("0.13.0.dev0")
 logger = logging.getLogger(__name__)


@@ -47,7 +47,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.12.0")
+check_min_version("0.13.0.dev0")
 logger = get_logger(__name__, log_level="INFO")


@@ -68,7 +68,7 @@ else:
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.12.0")
+check_min_version("0.13.0.dev0")
 logger = get_logger(__name__)


@@ -57,7 +57,7 @@ else:
 # ------------------------------------------------------------------------------
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.12.0")
+check_min_version("0.13.0.dev0")
 logger = logging.getLogger(__name__)


@@ -33,7 +33,7 @@ from tqdm.auto import tqdm
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.12.0")
+check_min_version("0.13.0.dev0")
 logger = get_logger(__name__, log_level="INFO")


@@ -30,7 +30,7 @@ from tqdm.auto import tqdm
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.12.0")
+check_min_version("0.13.0.dev0")
 logger = get_logger(__name__)


@@ -219,7 +219,7 @@ install_requires = [
 setup(
     name="diffusers",
-    version="0.12.0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    version="0.13.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
     description="Diffusers",
     long_description=open("README.md", "r", encoding="utf-8").read(),
     long_description_content_type="text/markdown",


@@ -1,4 +1,4 @@
-__version__ = "0.12.0"
+__version__ = "0.13.0.dev0"
 from .configuration_utils import ConfigMixin
 from .utils import (


@@ -606,7 +606,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
                 (nsfw) content, according to the `safety_checker`.
         """
         message = "Please use `image` instead of `init_image`."
-        init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
+        init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
         image = init_image or image
         # 1. Check inputs. Raise error if not correct


@@ -17,8 +17,7 @@ from typing import List, Optional, Tuple, Union
 import torch
-from ...configuration_utils import FrozenDict
-from ...utils import deprecate, randn_tensor
+from ...utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -46,7 +45,6 @@ class DDPMPipeline(DiffusionPipeline):
         num_inference_steps: int = 1000,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
         r"""
         Args:
@@ -68,30 +66,6 @@ class DDPMPipeline(DiffusionPipeline):
             [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
             True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
         """
-        message = (
-            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
-            " DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
-        )
-        predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
-        if predict_epsilon is not None:
-            new_config = dict(self.scheduler.config)
-            new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
-            self.scheduler._internal_dict = FrozenDict(new_config)
-        if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
-            message = (
-                f"The `generator` device is `{generator.device}` and does not match the pipeline "
-                f"device `{self.device}`, so the `generator` will be ignored. "
-                f'Please use `torch.Generator(device="{self.device}")` instead.'
-            )
-            deprecate(
-                "generator.device == 'cpu'",
-                "0.13.0",
-                message,
-            )
-            generator = None
         # Sample gaussian noise to begin loop
         if isinstance(self.unet.sample_size, int):
             image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)

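Two shims leave `DDPMPipeline.__call__` at once: with `**kwargs` gone, a stray `predict_epsilon` now raises a plain `TypeError`, and a mismatched-device `generator` is no longer silently discarded, since `randn_tensor` (imported above) can draw noise on the generator's device and move it to the pipeline's. A sketch of the now-supported call, assuming a CUDA machine:

```python
import torch

from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32").to("cuda")

# A CPU generator stays reproducible on a CUDA pipeline instead of being
# ignored with a deprecation warning, as the deleted block above did.
generator = torch.Generator("cpu").manual_seed(0)
image = pipe(generator=generator, num_inference_steps=2).images[0]
```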

@@ -102,7 +102,7 @@ class LDMSuperResolutionPipeline(DiffusionPipeline):
             True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
         """
         message = "Please use `image` instead of `init_image`."
-        init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
+        init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
         image = init_image or image
         if isinstance(image, PIL.Image.Image):


@@ -623,7 +623,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
                 (nsfw) content, according to the `safety_checker`.
         """
         message = "Please use `image` instead of `init_image`."
-        init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
+        init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
         image = init_image or image
         # 1. Check inputs


@@ -311,7 +311,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
                 (nsfw) content, according to the `safety_checker`.
         """
         message = "Please use `image` instead of `init_image`."
-        init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
+        init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
         image = init_image or image
         if isinstance(prompt, str):


@@ -303,7 +303,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
                 (nsfw) content, according to the `safety_checker`.
         """
         message = "Please use `image` instead of `init_image`."
-        init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
+        init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
         image = init_image or image
         if isinstance(prompt, str):


@@ -616,7 +616,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
                 (nsfw) content, according to the `safety_checker`.
         """
         message = "Please use `image` instead of `init_image`."
-        init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
+        init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
         image = init_image or image
         # 1. Check inputs. Raise error if not correct


@@ -556,7 +556,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
                 (nsfw) content, according to the `safety_checker`.
         """
         message = "Please use `image` instead of `init_image`."
-        init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
+        init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
         image = init_image or image
         # 1. Check inputs


@@ -23,7 +23,7 @@ import numpy as np
 import torch
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, deprecate, randn_tensor
+from ..utils import BaseOutput, randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
@@ -113,7 +113,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
     """
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
-    _deprecated_kwargs = ["predict_epsilon"]
     order = 1
     @register_to_config
@@ -128,16 +127,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         set_alpha_to_one: bool = True,
         steps_offset: int = 0,
         prediction_type: str = "epsilon",
-        **kwargs,
     ):
-        message = (
-            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
-            " DDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
-        )
-        predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
-        if predict_epsilon is not None:
-            self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":


@@ -22,7 +22,6 @@ import flax
 import jax.numpy as jnp
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate
 from .scheduling_utils_flax import (
     CommonSchedulerState,
     FlaxKarrasDiffusionSchedulers,
@@ -103,7 +102,6 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
     """
     _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
-    _deprecated_kwargs = ["predict_epsilon"]
     dtype: jnp.dtype
@@ -123,16 +121,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
         steps_offset: int = 0,
         prediction_type: str = "epsilon",
         dtype: jnp.dtype = jnp.float32,
-        **kwargs,
     ):
-        message = (
-            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
-            f" {self.__class__.__name__}.from_pretrained(<model_id>, prediction_type='epsilon')`."
-        )
-        predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
-        if predict_epsilon is not None:
-            self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
         self.dtype = dtype
     def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState:


@@ -21,8 +21,8 @@ from typing import List, Optional, Tuple, Union
 import numpy as np
 import torch
-from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
-from ..utils import BaseOutput, deprecate, randn_tensor
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
@@ -106,7 +106,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
     """
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
-    _deprecated_kwargs = ["predict_epsilon"]
     order = 1
     @register_to_config
@@ -120,16 +119,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         variance_type: str = "fixed_small",
         clip_sample: bool = True,
         prediction_type: str = "epsilon",
-        **kwargs,
     ):
-        message = (
-            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
-            " DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
-        )
-        predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
-        if predict_epsilon is not None:
-            self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -239,7 +229,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         sample: torch.FloatTensor,
         generator=None,
         return_dict: bool = True,
-        **kwargs,
     ) -> Union[DDPMSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
@@ -259,16 +248,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             returning a tuple, the first element is the sample tensor.
         """
-        message = (
-            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
-            " DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
-        )
-        predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
-        if predict_epsilon is not None:
-            new_config = dict(self.config)
-            new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
-            self._internal_dict = FrozenDict(new_config)
         t = timestep
         if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:

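Since `DDPMScheduler.step()` also loses its `**kwargs` escape hatch, legacy per-step overrides now fail fast instead of being rewritten through `FrozenDict`. A sketch of the post-removal call pattern:

```python
import torch

from diffusers import DDPMScheduler

scheduler = DDPMScheduler(prediction_type="sample")
scheduler.set_timesteps(10)

sample = torch.randn(1, 3, 32, 32)
model_output = 0.1 * sample
t = scheduler.timesteps[0]

# Old: scheduler.step(model_output, t, sample, predict_epsilon=False) -> TypeError now.
prev_sample = scheduler.step(model_output, t, sample).prev_sample
```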

@@ -22,7 +22,6 @@ import jax
 import jax.numpy as jnp
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate
 from .scheduling_utils_flax import (
     CommonSchedulerState,
     FlaxKarrasDiffusionSchedulers,
@@ -86,7 +85,6 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
     """
     _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
-    _deprecated_kwargs = ["predict_epsilon"]
     dtype: jnp.dtype
@@ -106,16 +104,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
         clip_sample: bool = True,
         prediction_type: str = "epsilon",
         dtype: jnp.dtype = jnp.float32,
-        **kwargs,
     ):
-        message = (
-            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
-            f" {self.__class__.__name__}.from_pretrained(<model_id>, prediction_type='epsilon')`."
-        )
-        predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
-        if predict_epsilon is not None:
-            self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
         self.dtype = dtype
     def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState:


@@ -21,7 +21,6 @@ import numpy as np
 import torch
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
@@ -118,7 +117,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     """
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
-    _deprecated_kwargs = ["predict_epsilon"]
     order = 1
     @register_to_config
@@ -137,16 +135,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         algorithm_type: str = "dpmsolver++",
         solver_type: str = "midpoint",
         lower_order_final: bool = True,
-        **kwargs,
     ):
-        message = (
-            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
-            " DPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
-        )
-        predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
-        if predict_epsilon is not None:
-            self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":


@@ -22,7 +22,6 @@ import jax
 import jax.numpy as jnp
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate
 from .scheduling_utils_flax import (
     CommonSchedulerState,
     FlaxKarrasDiffusionSchedulers,
@@ -141,7 +140,6 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
     """
     _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
-    _deprecated_kwargs = ["predict_epsilon"]
     dtype: jnp.dtype
@@ -166,16 +164,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
         solver_type: str = "midpoint",
         lower_order_final: bool = True,
         dtype: jnp.dtype = jnp.float32,
-        **kwargs,
     ):
-        message = (
-            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
-            f" {self.__class__.__name__}.from_pretrained(<model_id>, prediction_type='epsilon')`."
-        )
-        predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
-        if predict_epsilon is not None:
-            self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
         self.dtype = dtype
     def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState:

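The Flax schedulers mirror the change; their stateless design only adds the `create_state` step. A sketch of post-deprecation construction, assuming `jax`/`flax` are installed:

```python
from diffusers import FlaxDDPMScheduler

# `prediction_type` replaces the old `predict_epsilon` flag on the Flax side too.
scheduler = FlaxDDPMScheduler(prediction_type="sample")
state = scheduler.create_state()
```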

@@ -19,7 +19,6 @@ import numpy as np
 import torch
 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
-from diffusers.utils import deprecate
 from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device
@@ -67,32 +66,6 @@ class DDPMPipelineFastTests(unittest.TestCase):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-    def test_inference_deprecated_predict_epsilon(self):
-        deprecate("remove this test", "0.13.0", "remove")
-        unet = self.dummy_uncond_unet
-        scheduler = DDPMScheduler(predict_epsilon=False)
-        ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
-        ddpm.to(torch_device)
-        ddpm.set_progress_bar_config(disable=None)
-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            _ = ddpm(num_inference_steps=1)
-        generator = torch.manual_seed(0)
-        image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
-        generator = torch.manual_seed(0)
-        image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
-        image_slice = image[0, -3:, -3:, -1]
-        image_eps_slice = image_eps[0, -3:, -3:, -1]
-        assert image.shape == (1, 32, 32, 3)
-        tolerance = 1e-2 if torch_device != "mps" else 3e-2
-        assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance
     def test_inference_predict_sample(self):
         unet = self.dummy_uncond_unet
         scheduler = DDPMScheduler(prediction_type="sample")


@@ -26,7 +26,6 @@ from diffusers import (
     logging,
 )
 from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.utils import deprecate
 from diffusers.utils.testing_utils import CaptureLogger
@@ -202,20 +201,10 @@ class ConfigTester(unittest.TestCase):
         with CaptureLogger(logger) as cap_logger_2:
             ddpm_2 = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88)
-        with CaptureLogger(logger) as cap_logger:
-            deprecate("remove this case", "0.13.0", "remove")
-            ddpm_3 = DDPMScheduler.from_pretrained(
-                "hf-internal-testing/tiny-stable-diffusion-torch",
-                subfolder="scheduler",
-                predict_epsilon=False,
-                beta_end=8,
-            )
         assert ddpm.__class__ == DDPMScheduler
         assert ddpm.config.prediction_type == "sample"
         assert ddpm.config.beta_end == 8
         assert ddpm_2.config.beta_start == 88
-        assert ddpm_3.config.prediction_type == "sample"
         # no warning should be thrown
         assert cap_logger.out == ""


@@ -45,7 +45,7 @@ from diffusers import (
 )
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.schedulers.scheduling_utils import SchedulerMixin
-from diffusers.utils import deprecate, torch_device
+from diffusers.utils import torch_device
 from diffusers.utils.testing_utils import CaptureLogger
@@ -645,35 +645,6 @@ class DDPMSchedulerTest(SchedulerCommonTest):
         for prediction_type in ["epsilon", "sample", "v_prediction"]:
             self.check_over_configs(prediction_type=prediction_type)
-    def test_deprecated_predict_epsilon(self):
-        deprecate("remove this test", "0.13.0", "remove")
-        for predict_epsilon in [True, False]:
-            self.check_over_configs(predict_epsilon=predict_epsilon)
-    def test_deprecated_epsilon(self):
-        deprecate("remove this test", "0.13.0", "remove")
-        scheduler_class = self.scheduler_classes[0]
-        scheduler_config = self.get_scheduler_config()
-        sample = self.dummy_sample_deter
-        residual = 0.1 * self.dummy_sample_deter
-        time_step = 4
-        scheduler = scheduler_class(**scheduler_config)
-        scheduler_eps = scheduler_class(predict_epsilon=False, **scheduler_config)
-        kwargs = {}
-        if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
-            kwargs["generator"] = torch.manual_seed(0)
-        output = scheduler.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample
-        kwargs = {}
-        if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
-            kwargs["generator"] = torch.manual_seed(0)
-        output_eps = scheduler_eps.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample
-        assert (output - output_eps).abs().sum() < 1e-5
     def test_time_indices(self):
         for t in [0, 500, 999]:
             self.check_over_forward(time_step=t)


@@ -18,7 +18,7 @@ import unittest
 from typing import Dict, List, Tuple
 from diffusers import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler
-from diffusers.utils import deprecate, is_flax_available
+from diffusers.utils import is_flax_available
 from diffusers.utils.testing_utils import require_flax
@@ -626,22 +626,6 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
         for prediction_type in ["epsilon", "sample", "v_prediction"]:
             self.check_over_configs(prediction_type=prediction_type)
-    def test_deprecated_predict_epsilon(self):
-        deprecate("remove this test", "0.13.0", "remove")
-        for predict_epsilon in [True, False]:
-            self.check_over_configs(predict_epsilon=predict_epsilon)
-    def test_deprecated_predict_epsilon_to_prediction_type(self):
-        deprecate("remove this test", "0.13.0", "remove")
-        for scheduler_class in self.scheduler_classes:
-            scheduler_config = self.get_scheduler_config(predict_epsilon=True)
-            scheduler = scheduler_class.from_config(scheduler_config)
-            assert scheduler.prediction_type == "epsilon"
-            scheduler_config = self.get_scheduler_config(predict_epsilon=False)
-            scheduler = scheduler_class.from_config(scheduler_config)
-            assert scheduler.prediction_type == "sample"
 @require_flax
 class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):