Allow to set config params directly in init (#1419)

* fix

* fix deprecated kwargs logic

* add tests

* finish
This commit is contained in:
Patrick von Platen 2022-11-25 15:07:09 +01:00 committed by GitHub
parent 86aa747da9
commit 8faa822ddc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 103 additions and 11 deletions

View File

@ -80,14 +80,18 @@ class ConfigMixin:
- **config_name** (`str`) -- A filename under which the config should be stored when calling - **config_name** (`str`) -- A filename under which the config should be stored when calling
[`~ConfigMixin.save_config`] (should be overridden by parent class). [`~ConfigMixin.save_config`] (should be overridden by parent class).
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
overridden by parent class). overridden by subclass).
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by parent - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
class). - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
subclass).
""" """
config_name = None config_name = None
ignore_for_config = [] ignore_for_config = []
has_compatibles = False has_compatibles = False
_deprecated_kwargs = []
def register_to_config(self, **kwargs): def register_to_config(self, **kwargs):
if self.config_name is None: if self.config_name is None:
raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`") raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
@ -195,10 +199,10 @@ class ConfigMixin:
if "dtype" in unused_kwargs: if "dtype" in unused_kwargs:
init_dict["dtype"] = unused_kwargs.pop("dtype") init_dict["dtype"] = unused_kwargs.pop("dtype")
if "predict_epsilon" in unused_kwargs and "prediction_type" not in init_dict: # add possible deprecated kwargs
deprecate("remove this", "0.10.0", "remove") for deprecated_kwarg in cls._deprecated_kwargs:
predict_epsilon = unused_kwargs.pop("predict_epsilon") if deprecated_kwarg in unused_kwargs:
init_dict["prediction_type"] = "epsilon" if predict_epsilon else "sample" init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
# Return model and optionally state and/or unused_kwargs # Return model and optionally state and/or unused_kwargs
model = cls(**init_dict) model = cls(**init_dict)
@ -526,7 +530,6 @@ def register_to_config(init):
# Ignore private kwargs in the init. # Ignore private kwargs in the init.
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")} config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
init(self, *args, **init_kwargs)
if not isinstance(self, ConfigMixin): if not isinstance(self, ConfigMixin):
raise RuntimeError( raise RuntimeError(
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
@ -553,6 +556,7 @@ def register_to_config(init):
) )
new_kwargs = {**config_init_kwargs, **new_kwargs} new_kwargs = {**config_init_kwargs, **new_kwargs}
getattr(self, "register_to_config")(**new_kwargs) getattr(self, "register_to_config")(**new_kwargs)
init(self, *args, **init_kwargs)
return inner_init return inner_init

View File

@ -254,7 +254,6 @@ class UNetMidBlock2D(nn.Module):
attn_num_head_channels=1, attn_num_head_channels=1,
attention_type="default", attention_type="default",
output_scale_factor=1.0, output_scale_factor=1.0,
**kwargs,
): ):
super().__init__() super().__init__()
@ -336,7 +335,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
cross_attention_dim=1280, cross_attention_dim=1280,
dual_cross_attention=False, dual_cross_attention=False,
use_linear_projection=False, use_linear_projection=False,
**kwargs,
): ):
super().__init__() super().__init__()

View File

@ -1039,7 +1039,6 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
cross_attention_dim=1280, cross_attention_dim=1280,
dual_cross_attention=False, dual_cross_attention=False,
use_linear_projection=False, use_linear_projection=False,
**kwargs,
): ):
super().__init__() super().__init__()

View File

@ -113,6 +113,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
""" """
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_deprecated_kwargs = ["predict_epsilon"]
@register_to_config @register_to_config
def __init__( def __init__(

View File

@ -116,6 +116,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
""" """
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_deprecated_kwargs = ["predict_epsilon"]
@property @property
def has_state(self): def has_state(self):

View File

@ -105,6 +105,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
""" """
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_deprecated_kwargs = ["predict_epsilon"]
@register_to_config @register_to_config
def __init__( def __init__(

View File

@ -109,6 +109,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
""" """
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_deprecated_kwargs = ["predict_epsilon"]
@property @property
def has_state(self): def has_state(self):

View File

@ -117,6 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
""" """
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_deprecated_kwargs = ["predict_epsilon"]
@register_to_config @register_to_config
def __init__( def __init__(

View File

@ -149,6 +149,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
""" """
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_deprecated_kwargs = ["predict_epsilon"]
@property @property
def has_state(self): def has_state(self):

View File

@ -265,3 +265,23 @@ class ModelTesterMixin:
# check disable works # check disable works
model.disable_gradient_checkpointing() model.disable_gradient_checkpointing()
self.assertFalse(model.is_gradient_checkpointing) self.assertFalse(model.is_gradient_checkpointing)
def test_deprecated_kwargs(self):
has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0
if has_kwarg_in_model_class and not has_deprecated_kwarg:
raise ValueError(
f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
" under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
" no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
" [<deprecated_argument>]`"
)
if not has_kwarg_in_model_class and has_deprecated_kwarg:
raise ValueError(
f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
" under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
" from `_deprecated_kwargs = [<deprecated_argument>]`"
)

View File

@ -1,3 +1,5 @@
import inspect
from diffusers.utils import is_flax_available from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax from diffusers.utils.testing_utils import require_flax
@ -42,3 +44,23 @@ class FlaxModelTesterMixin:
self.assertIsNotNone(output) self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_deprecated_kwargs(self):
has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0
if has_kwarg_in_model_class and not has_deprecated_kwarg:
raise ValueError(
f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
" under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
" no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
" [<deprecated_argument>]`"
)
if not has_kwarg_in_model_class and has_deprecated_kwarg:
raise ValueError(
f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
" under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
" from `_deprecated_kwargs = [<deprecated_argument>]`"
)

View File

@ -562,6 +562,27 @@ class SchedulerCommonTest(unittest.TestCase):
noised = scheduler.add_noise(scaled_sample, noise, t) noised = scheduler.add_noise(scaled_sample, noise, t)
self.assertEqual(noised.shape, scaled_sample.shape) self.assertEqual(noised.shape, scaled_sample.shape)
def test_deprecated_kwargs(self):
for scheduler_class in self.scheduler_classes:
has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters
has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0
if has_kwarg_in_model_class and not has_deprecated_kwarg:
raise ValueError(
f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated"
" kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if"
" there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
" [<deprecated_argument>]`"
)
if not has_kwarg_in_model_class and has_deprecated_kwarg:
raise ValueError(
f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated"
" kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`"
f" argument to {self.model_class}.__init__ if there are deprecated arguments or remove the"
" deprecated argument from `_deprecated_kwargs = [<deprecated_argument>]`"
)
class DDPMSchedulerTest(SchedulerCommonTest): class DDPMSchedulerTest(SchedulerCommonTest):
scheduler_classes = (DDPMScheduler,) scheduler_classes = (DDPMScheduler,)

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect
import tempfile import tempfile
import unittest import unittest
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
@ -228,6 +229,27 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
recursive_check(outputs_tuple[0], outputs_dict.prev_sample) recursive_check(outputs_tuple[0], outputs_dict.prev_sample)
def test_deprecated_kwargs(self):
for scheduler_class in self.scheduler_classes:
has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters
has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0
if has_kwarg_in_model_class and not has_deprecated_kwarg:
raise ValueError(
f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated"
" kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if"
" there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
" [<deprecated_argument>]`"
)
if not has_kwarg_in_model_class and has_deprecated_kwarg:
raise ValueError(
f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated"
" kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`"
f" argument to {self.model_class}.__init__ if there are deprecated arguments or remove the"
" deprecated argument from `_deprecated_kwargs = [<deprecated_argument>]`"
)
@require_flax @require_flax
class FlaxDDPMSchedulerTest(FlaxSchedulerCommonTest): class FlaxDDPMSchedulerTest(FlaxSchedulerCommonTest):