diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 1a7499c6..f06586b2 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -80,14 +80,18 @@ class ConfigMixin: - **config_name** (`str`) -- A filename under which the config should stored when calling [`~ConfigMixin.save_config`] (should be overridden by parent class). - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be - overridden by parent class). - - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by parent - class). + overridden by subclass). + - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass). + - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function + should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by + subclass). """ config_name = None ignore_for_config = [] has_compatibles = False + _deprecated_kwargs = [] + def register_to_config(self, **kwargs): if self.config_name is None: raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`") @@ -195,10 +199,10 @@ class ConfigMixin: if "dtype" in unused_kwargs: init_dict["dtype"] = unused_kwargs.pop("dtype") - if "predict_epsilon" in unused_kwargs and "prediction_type" not in init_dict: - deprecate("remove this", "0.10.0", "remove") - predict_epsilon = unused_kwargs.pop("predict_epsilon") - init_dict["prediction_type"] = "epsilon" if predict_epsilon else "sample" + # add possible deprecated kwargs + for deprecated_kwarg in cls._deprecated_kwargs: + if deprecated_kwarg in unused_kwargs: + init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg) # Return model and optionally state and/or unused_kwargs model = cls(**init_dict) @@ -526,7 +530,6 @@ def register_to_config(init): # Ignore private kwargs in the init. init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")} - init(self, *args, **init_kwargs) if not isinstance(self, ConfigMixin): raise RuntimeError( f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " @@ -553,6 +556,7 @@ def register_to_config(init): ) new_kwargs = {**config_init_kwargs, **new_kwargs} getattr(self, "register_to_config")(**new_kwargs) + init(self, *args, **init_kwargs) return inner_init diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 6b4a88c0..cce7e7fd 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -254,7 +254,6 @@ class UNetMidBlock2D(nn.Module): attn_num_head_channels=1, attention_type="default", output_scale_factor=1.0, - **kwargs, ): super().__init__() @@ -336,7 +335,6 @@ class UNetMidBlock2DCrossAttn(nn.Module): cross_attention_dim=1280, dual_cross_attention=False, use_linear_projection=False, - **kwargs, ): super().__init__() diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index fb8855b9..37a79b5c 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1039,7 +1039,6 @@ class UNetMidBlockFlatCrossAttn(nn.Module): cross_attention_dim=1280, dual_cross_attention=False, use_linear_projection=False, - **kwargs, ): super().__init__() diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index b16716f0..3640b375 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -113,6 +113,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): """ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _deprecated_kwargs = ["predict_epsilon"] @register_to_config def __init__( diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index 122c36f2..f98d9770 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -116,6 +116,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): """ _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _deprecated_kwargs = ["predict_epsilon"] @property def has_state(self): diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index c691630a..6f131659 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -105,6 +105,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): """ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _deprecated_kwargs = ["predict_epsilon"] @register_to_config def __init__( diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py index 946665a0..97b38fd3 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_flax.py +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -109,6 +109,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): """ _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _deprecated_kwargs = ["predict_epsilon"] @property def has_state(self): diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 2999ff7f..d38ceed2 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -117,6 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): """ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _deprecated_kwargs = ["predict_epsilon"] @register_to_config def __init__( diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py index 8bb0672f..4d56d99a 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py @@ -149,6 +149,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): """ _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _deprecated_kwargs = ["predict_epsilon"] @property def has_state(self): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 49bb4f6d..cad1887f 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -265,3 +265,23 @@ class ModelTesterMixin: # check disable works model.disable_gradient_checkpointing() self.assertFalse(model.is_gradient_checkpointing) + + def test_deprecated_kwargs(self): + has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters + has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0 + + if has_kwarg_in_model_class and not has_deprecated_kwarg: + raise ValueError( + f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs" + " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are" + " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" + " []`" + ) + + if not has_kwarg_in_model_class and has_deprecated_kwarg: + raise ValueError( + f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs" + " under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to" + f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument" + " from `_deprecated_kwargs = []`" + ) diff --git a/tests/test_modeling_common_flax.py b/tests/test_modeling_common_flax.py index 61849b22..8945aed7 100644 --- a/tests/test_modeling_common_flax.py +++ b/tests/test_modeling_common_flax.py @@ -1,3 +1,5 @@ +import inspect + from diffusers.utils import is_flax_available from diffusers.utils.testing_utils import require_flax @@ -42,3 +44,23 @@ class FlaxModelTesterMixin: self.assertIsNotNone(output) expected_shape = inputs_dict["sample"].shape self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_deprecated_kwargs(self): + has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters + has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0 + + if has_kwarg_in_model_class and not has_deprecated_kwarg: + raise ValueError( + f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs" + " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are" + " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" + " []`" + ) + + if not has_kwarg_in_model_class and has_deprecated_kwarg: + raise ValueError( + f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs" + " under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to" + f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument" + " from `_deprecated_kwargs = []`" + ) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 4406149d..6a765816 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -562,6 +562,27 @@ class SchedulerCommonTest(unittest.TestCase): noised = scheduler.add_noise(scaled_sample, noise, t) self.assertEqual(noised.shape, scaled_sample.shape) + def test_deprecated_kwargs(self): + for scheduler_class in self.scheduler_classes: + has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters + has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0 + + if has_kwarg_in_model_class and not has_deprecated_kwarg: + raise ValueError( + f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated" + " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if" + " there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" + " []`" + ) + + if not has_kwarg_in_model_class and has_deprecated_kwarg: + raise ValueError( + f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated" + " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`" + f" argument to {self.model_class}.__init__ if there are deprecated arguments or remove the" + " deprecated argument from `_deprecated_kwargs = []`" + ) + class DDPMSchedulerTest(SchedulerCommonTest): scheduler_classes = (DDPMScheduler,) diff --git a/tests/test_scheduler_flax.py b/tests/test_scheduler_flax.py index 6524e18d..5ada689b 100644 --- a/tests/test_scheduler_flax.py +++ b/tests/test_scheduler_flax.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect import tempfile import unittest from typing import Dict, List, Tuple @@ -228,6 +229,27 @@ class FlaxSchedulerCommonTest(unittest.TestCase): recursive_check(outputs_tuple[0], outputs_dict.prev_sample) + def test_deprecated_kwargs(self): + for scheduler_class in self.scheduler_classes: + has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters + has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0 + + if has_kwarg_in_model_class and not has_deprecated_kwarg: + raise ValueError( + f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated" + " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if" + " there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" + " []`" + ) + + if not has_kwarg_in_model_class and has_deprecated_kwarg: + raise ValueError( + f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated" + " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`" + f" argument to {self.model_class}.__init__ if there are deprecated arguments or remove the" + " deprecated argument from `_deprecated_kwargs = []`" + ) + @require_flax class FlaxDDPMSchedulerTest(FlaxSchedulerCommonTest):