[Bump version] 0.13.0dev0 & Deprecate `predict_epsilon` (#2109)
* [Bump version] 0.13 * Bump model up * up
This commit is contained in:
parent
b0cc7c202b
commit
09779cbb40
|
@ -185,7 +185,7 @@ class ImagicStableDiffusionPipeline(DiffusionPipeline):
|
|||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
message = "Please use `image` instead of `init_image`."
|
||||
init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
|
||||
init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
|
||||
image = init_image or image
|
||||
|
||||
accelerator = Accelerator(
|
||||
|
|
|
@ -759,7 +759,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
message = "Please use `image` instead of `init_image`."
|
||||
init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
|
||||
init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
|
||||
image = init_image or image
|
||||
|
||||
# 0. Default height and width to unet
|
||||
|
|
|
@ -745,7 +745,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
|
|||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
message = "Please use `image` instead of `init_image`."
|
||||
init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
|
||||
init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
|
||||
image = init_image or image
|
||||
|
||||
# 0. Default height and width to unet
|
||||
|
|
|
@ -46,7 +46,7 @@ from transformers import AutoTokenizer, PretrainedConfig
|
|||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.12.0")
|
||||
check_min_version("0.13.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel,
|
|||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.12.0")
|
||||
check_min_version("0.13.0.dev0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
|
|
@ -54,7 +54,7 @@ from transformers import AutoTokenizer, PretrainedConfig
|
|||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.12.0")
|
||||
check_min_version("0.13.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
|
|
@ -45,7 +45,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
|
|||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.12.0")
|
||||
check_min_version("0.13.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel,
|
|||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.12.0")
|
||||
check_min_version("0.13.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -47,7 +47,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
|
|||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.12.0")
|
||||
check_min_version("0.13.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
|
|
@ -68,7 +68,7 @@ else:
|
|||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.12.0")
|
||||
check_min_version("0.13.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ else:
|
|||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.12.0")
|
||||
check_min_version("0.13.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ from tqdm.auto import tqdm
|
|||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.12.0")
|
||||
check_min_version("0.13.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ from tqdm.auto import tqdm
|
|||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.12.0")
|
||||
check_min_version("0.13.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
|
2
setup.py
2
setup.py
|
@ -219,7 +219,7 @@ install_requires = [
|
|||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.12.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.13.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="Diffusers",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
__version__ = "0.12.0"
|
||||
__version__ = "0.13.0.dev0"
|
||||
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .utils import (
|
||||
|
|
|
@ -606,7 +606,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
message = "Please use `image` instead of `init_image`."
|
||||
init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
|
||||
init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
|
||||
image = init_image or image
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
|
|
|
@ -17,8 +17,7 @@ from typing import List, Optional, Tuple, Union
|
|||
|
||||
import torch
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...utils import deprecate, randn_tensor
|
||||
from ...utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
|
@ -46,7 +45,6 @@ class DDPMPipeline(DiffusionPipeline):
|
|||
num_inference_steps: int = 1000,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
) -> Union[ImagePipelineOutput, Tuple]:
|
||||
r"""
|
||||
Args:
|
||||
|
@ -68,30 +66,6 @@ class DDPMPipeline(DiffusionPipeline):
|
|||
[`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
|
||||
True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
|
||||
|
||||
if predict_epsilon is not None:
|
||||
new_config = dict(self.scheduler.config)
|
||||
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
|
||||
self.scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
|
||||
message = (
|
||||
f"The `generator` device is `{generator.device}` and does not match the pipeline "
|
||||
f"device `{self.device}`, so the `generator` will be ignored. "
|
||||
f'Please use `torch.Generator(device="{self.device}")` instead.'
|
||||
)
|
||||
deprecate(
|
||||
"generator.device == 'cpu'",
|
||||
"0.13.0",
|
||||
message,
|
||||
)
|
||||
generator = None
|
||||
|
||||
# Sample gaussian noise to begin loop
|
||||
if isinstance(self.unet.sample_size, int):
|
||||
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
|
||||
|
|
|
@ -102,7 +102,7 @@ class LDMSuperResolutionPipeline(DiffusionPipeline):
|
|||
True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
message = "Please use `image` instead of `init_image`."
|
||||
init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
|
||||
init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
|
||||
image = init_image or image
|
||||
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
|
|
|
@ -623,7 +623,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
|||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
message = "Please use `image` instead of `init_image`."
|
||||
init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
|
||||
init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
|
||||
image = init_image or image
|
||||
|
||||
# 1. Check inputs
|
||||
|
|
|
@ -311,7 +311,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
message = "Please use `image` instead of `init_image`."
|
||||
init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
|
||||
init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
|
||||
image = init_image or image
|
||||
|
||||
if isinstance(prompt, str):
|
||||
|
|
|
@ -303,7 +303,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
|||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
message = "Please use `image` instead of `init_image`."
|
||||
init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
|
||||
init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
|
||||
image = init_image or image
|
||||
|
||||
if isinstance(prompt, str):
|
||||
|
|
|
@ -616,7 +616,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
message = "Please use `image` instead of `init_image`."
|
||||
init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
|
||||
init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
|
||||
image = init_image or image
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
|
|
|
@ -556,7 +556,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
|||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
message = "Please use `image` instead of `init_image`."
|
||||
init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
|
||||
init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
|
||||
image = init_image or image
|
||||
|
||||
# 1. Check inputs
|
||||
|
|
|
@ -23,7 +23,7 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, deprecate, randn_tensor
|
||||
from ..utils import BaseOutput, randn_tensor
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
|
||||
|
||||
|
||||
|
@ -113,7 +113,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||
"""
|
||||
|
||||
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
||||
_deprecated_kwargs = ["predict_epsilon"]
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
|
@ -128,16 +127,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||
set_alpha_to_one: bool = True,
|
||||
steps_offset: int = 0,
|
||||
prediction_type: str = "epsilon",
|
||||
**kwargs,
|
||||
):
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
" DDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None:
|
||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
|
|
|
@ -22,7 +22,6 @@ import flax
|
|||
import jax.numpy as jnp
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate
|
||||
from .scheduling_utils_flax import (
|
||||
CommonSchedulerState,
|
||||
FlaxKarrasDiffusionSchedulers,
|
||||
|
@ -103,7 +102,6 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|||
"""
|
||||
|
||||
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
|
||||
_deprecated_kwargs = ["predict_epsilon"]
|
||||
|
||||
dtype: jnp.dtype
|
||||
|
||||
|
@ -123,16 +121,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|||
steps_offset: int = 0,
|
||||
prediction_type: str = "epsilon",
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
**kwargs,
|
||||
):
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
f" {self.__class__.__name__}.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None:
|
||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||
|
||||
self.dtype = dtype
|
||||
|
||||
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState:
|
||||
|
|
|
@ -21,8 +21,8 @@ from typing import List, Optional, Tuple, Union
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
|
||||
from ..utils import BaseOutput, deprecate, randn_tensor
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, randn_tensor
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
|
||||
|
||||
|
||||
|
@ -106,7 +106,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||
"""
|
||||
|
||||
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
||||
_deprecated_kwargs = ["predict_epsilon"]
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
|
@ -120,16 +119,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||
variance_type: str = "fixed_small",
|
||||
clip_sample: bool = True,
|
||||
prediction_type: str = "epsilon",
|
||||
**kwargs,
|
||||
):
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None:
|
||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
|
@ -239,7 +229,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||
sample: torch.FloatTensor,
|
||||
generator=None,
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
) -> Union[DDPMSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
||||
|
@ -259,16 +248,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||
returning a tuple, the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None:
|
||||
new_config = dict(self.config)
|
||||
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
|
||||
self._internal_dict = FrozenDict(new_config)
|
||||
|
||||
t = timestep
|
||||
|
||||
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
|
||||
|
|
|
@ -22,7 +22,6 @@ import jax
|
|||
import jax.numpy as jnp
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate
|
||||
from .scheduling_utils_flax import (
|
||||
CommonSchedulerState,
|
||||
FlaxKarrasDiffusionSchedulers,
|
||||
|
@ -86,7 +85,6 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|||
"""
|
||||
|
||||
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
|
||||
_deprecated_kwargs = ["predict_epsilon"]
|
||||
|
||||
dtype: jnp.dtype
|
||||
|
||||
|
@ -106,16 +104,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|||
clip_sample: bool = True,
|
||||
prediction_type: str = "epsilon",
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
**kwargs,
|
||||
):
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
f" {self.__class__.__name__}.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None:
|
||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||
|
||||
self.dtype = dtype
|
||||
|
||||
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState:
|
||||
|
|
|
@ -21,7 +21,6 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
|
@ -118,7 +117,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||
"""
|
||||
|
||||
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
||||
_deprecated_kwargs = ["predict_epsilon"]
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
|
@ -137,16 +135,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||
algorithm_type: str = "dpmsolver++",
|
||||
solver_type: str = "midpoint",
|
||||
lower_order_final: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
" DPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None:
|
||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
|
|
|
@ -22,7 +22,6 @@ import jax
|
|||
import jax.numpy as jnp
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate
|
||||
from .scheduling_utils_flax import (
|
||||
CommonSchedulerState,
|
||||
FlaxKarrasDiffusionSchedulers,
|
||||
|
@ -141,7 +140,6 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|||
"""
|
||||
|
||||
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
|
||||
_deprecated_kwargs = ["predict_epsilon"]
|
||||
|
||||
dtype: jnp.dtype
|
||||
|
||||
|
@ -166,16 +164,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|||
solver_type: str = "midpoint",
|
||||
lower_order_final: bool = True,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
**kwargs,
|
||||
):
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
f" {self.__class__.__name__}.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None:
|
||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||
|
||||
self.dtype = dtype
|
||||
|
||||
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState:
|
||||
|
|
|
@ -19,7 +19,6 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
|
||||
from diffusers.utils import deprecate
|
||||
from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device
|
||||
|
||||
|
||||
|
@ -67,32 +66,6 @@ class DDPMPipelineFastTests(unittest.TestCase):
|
|||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_inference_deprecated_predict_epsilon(self):
|
||||
deprecate("remove this test", "0.13.0", "remove")
|
||||
unet = self.dummy_uncond_unet
|
||||
scheduler = DDPMScheduler(predict_epsilon=False)
|
||||
|
||||
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
|
||||
ddpm.to(torch_device)
|
||||
ddpm.set_progress_bar_config(disable=None)
|
||||
|
||||
# Warmup pass when using mps (see #372)
|
||||
if torch_device == "mps":
|
||||
_ = ddpm(num_inference_steps=1)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_eps_slice = image_eps[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
tolerance = 1e-2 if torch_device != "mps" else 3e-2
|
||||
assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance
|
||||
|
||||
def test_inference_predict_sample(self):
|
||||
unet = self.dummy_uncond_unet
|
||||
scheduler = DDPMScheduler(prediction_type="sample")
|
||||
|
|
|
@ -26,7 +26,6 @@ from diffusers import (
|
|||
logging,
|
||||
)
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.utils import deprecate
|
||||
from diffusers.utils.testing_utils import CaptureLogger
|
||||
|
||||
|
||||
|
@ -202,20 +201,10 @@ class ConfigTester(unittest.TestCase):
|
|||
with CaptureLogger(logger) as cap_logger_2:
|
||||
ddpm_2 = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
deprecate("remove this case", "0.13.0", "remove")
|
||||
ddpm_3 = DDPMScheduler.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch",
|
||||
subfolder="scheduler",
|
||||
predict_epsilon=False,
|
||||
beta_end=8,
|
||||
)
|
||||
|
||||
assert ddpm.__class__ == DDPMScheduler
|
||||
assert ddpm.config.prediction_type == "sample"
|
||||
assert ddpm.config.beta_end == 8
|
||||
assert ddpm_2.config.beta_start == 88
|
||||
assert ddpm_3.config.prediction_type == "sample"
|
||||
|
||||
# no warning should be thrown
|
||||
assert cap_logger.out == ""
|
||||
|
|
|
@ -45,7 +45,7 @@ from diffusers import (
|
|||
)
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from diffusers.utils import deprecate, torch_device
|
||||
from diffusers.utils import torch_device
|
||||
from diffusers.utils.testing_utils import CaptureLogger
|
||||
|
||||
|
||||
|
@ -645,35 +645,6 @@ class DDPMSchedulerTest(SchedulerCommonTest):
|
|||
for prediction_type in ["epsilon", "sample", "v_prediction"]:
|
||||
self.check_over_configs(prediction_type=prediction_type)
|
||||
|
||||
def test_deprecated_predict_epsilon(self):
|
||||
deprecate("remove this test", "0.13.0", "remove")
|
||||
for predict_epsilon in [True, False]:
|
||||
self.check_over_configs(predict_epsilon=predict_epsilon)
|
||||
|
||||
def test_deprecated_epsilon(self):
|
||||
deprecate("remove this test", "0.13.0", "remove")
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
|
||||
sample = self.dummy_sample_deter
|
||||
residual = 0.1 * self.dummy_sample_deter
|
||||
time_step = 4
|
||||
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
scheduler_eps = scheduler_class(predict_epsilon=False, **scheduler_config)
|
||||
|
||||
kwargs = {}
|
||||
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||
kwargs["generator"] = torch.manual_seed(0)
|
||||
output = scheduler.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample
|
||||
|
||||
kwargs = {}
|
||||
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||
kwargs["generator"] = torch.manual_seed(0)
|
||||
output_eps = scheduler_eps.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample
|
||||
|
||||
assert (output - output_eps).abs().sum() < 1e-5
|
||||
|
||||
def test_time_indices(self):
|
||||
for t in [0, 500, 999]:
|
||||
self.check_over_forward(time_step=t)
|
||||
|
|
|
@ -18,7 +18,7 @@ import unittest
|
|||
from typing import Dict, List, Tuple
|
||||
|
||||
from diffusers import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler
|
||||
from diffusers.utils import deprecate, is_flax_available
|
||||
from diffusers.utils import is_flax_available
|
||||
from diffusers.utils.testing_utils import require_flax
|
||||
|
||||
|
||||
|
@ -626,22 +626,6 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
|
|||
for prediction_type in ["epsilon", "sample", "v_prediction"]:
|
||||
self.check_over_configs(prediction_type=prediction_type)
|
||||
|
||||
def test_deprecated_predict_epsilon(self):
|
||||
deprecate("remove this test", "0.13.0", "remove")
|
||||
for predict_epsilon in [True, False]:
|
||||
self.check_over_configs(predict_epsilon=predict_epsilon)
|
||||
|
||||
def test_deprecated_predict_epsilon_to_prediction_type(self):
|
||||
deprecate("remove this test", "0.13.0", "remove")
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config(predict_epsilon=True)
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
assert scheduler.prediction_type == "epsilon"
|
||||
|
||||
scheduler_config = self.get_scheduler_config(predict_epsilon=False)
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
assert scheduler.prediction_type == "sample"
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
|
||||
|
|
Loading…
Reference in New Issue