Allow to set config params directly in init (#1419)
* fix * fix deprecated kwargs logic * add tests * finish
This commit is contained in:
parent
86aa747da9
commit
8faa822ddc
|
@ -80,14 +80,18 @@ class ConfigMixin:
|
||||||
- **config_name** (`str`) -- A filename under which the config should be stored when calling
|
- **config_name** (`str`) -- A filename under which the config should be stored when calling
|
||||||
[`~ConfigMixin.save_config`] (should be overridden by subclass).
|
[`~ConfigMixin.save_config`] (should be overridden by subclass).
|
||||||
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
|
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
|
||||||
overridden by parent class).
|
overridden by subclass).
|
||||||
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by parent
|
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
|
||||||
class).
|
- **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
|
||||||
|
should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
|
||||||
|
subclass).
|
||||||
"""
|
"""
|
||||||
config_name = None
|
config_name = None
|
||||||
ignore_for_config = []
|
ignore_for_config = []
|
||||||
has_compatibles = False
|
has_compatibles = False
|
||||||
|
|
||||||
|
_deprecated_kwargs = []
|
||||||
|
|
||||||
def register_to_config(self, **kwargs):
|
def register_to_config(self, **kwargs):
|
||||||
if self.config_name is None:
|
if self.config_name is None:
|
||||||
raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
|
raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
|
||||||
|
@ -195,10 +199,10 @@ class ConfigMixin:
|
||||||
if "dtype" in unused_kwargs:
|
if "dtype" in unused_kwargs:
|
||||||
init_dict["dtype"] = unused_kwargs.pop("dtype")
|
init_dict["dtype"] = unused_kwargs.pop("dtype")
|
||||||
|
|
||||||
if "predict_epsilon" in unused_kwargs and "prediction_type" not in init_dict:
|
# add possible deprecated kwargs
|
||||||
deprecate("remove this", "0.10.0", "remove")
|
for deprecated_kwarg in cls._deprecated_kwargs:
|
||||||
predict_epsilon = unused_kwargs.pop("predict_epsilon")
|
if deprecated_kwarg in unused_kwargs:
|
||||||
init_dict["prediction_type"] = "epsilon" if predict_epsilon else "sample"
|
init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
|
||||||
|
|
||||||
# Return model and optionally state and/or unused_kwargs
|
# Return model and optionally state and/or unused_kwargs
|
||||||
model = cls(**init_dict)
|
model = cls(**init_dict)
|
||||||
|
@ -526,7 +530,6 @@ def register_to_config(init):
|
||||||
# Ignore private kwargs in the init.
|
# Ignore private kwargs in the init.
|
||||||
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
|
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
|
||||||
config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
|
config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
|
||||||
init(self, *args, **init_kwargs)
|
|
||||||
if not isinstance(self, ConfigMixin):
|
if not isinstance(self, ConfigMixin):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
|
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
|
||||||
|
@ -553,6 +556,7 @@ def register_to_config(init):
|
||||||
)
|
)
|
||||||
new_kwargs = {**config_init_kwargs, **new_kwargs}
|
new_kwargs = {**config_init_kwargs, **new_kwargs}
|
||||||
getattr(self, "register_to_config")(**new_kwargs)
|
getattr(self, "register_to_config")(**new_kwargs)
|
||||||
|
init(self, *args, **init_kwargs)
|
||||||
|
|
||||||
return inner_init
|
return inner_init
|
||||||
|
|
||||||
|
|
|
@ -254,7 +254,6 @@ class UNetMidBlock2D(nn.Module):
|
||||||
attn_num_head_channels=1,
|
attn_num_head_channels=1,
|
||||||
attention_type="default",
|
attention_type="default",
|
||||||
output_scale_factor=1.0,
|
output_scale_factor=1.0,
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -336,7 +335,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||||
cross_attention_dim=1280,
|
cross_attention_dim=1280,
|
||||||
dual_cross_attention=False,
|
dual_cross_attention=False,
|
||||||
use_linear_projection=False,
|
use_linear_projection=False,
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|
|
@ -1039,7 +1039,6 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||||
cross_attention_dim=1280,
|
cross_attention_dim=1280,
|
||||||
dual_cross_attention=False,
|
dual_cross_attention=False,
|
||||||
use_linear_projection=False,
|
use_linear_projection=False,
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|
|
@ -113,6 +113,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||||
|
_deprecated_kwargs = ["predict_epsilon"]
|
||||||
|
|
||||||
@register_to_config
|
@register_to_config
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@ -116,6 +116,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||||
|
_deprecated_kwargs = ["predict_epsilon"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def has_state(self):
|
def has_state(self):
|
||||||
|
|
|
@ -105,6 +105,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||||
|
_deprecated_kwargs = ["predict_epsilon"]
|
||||||
|
|
||||||
@register_to_config
|
@register_to_config
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@ -109,6 +109,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||||
|
_deprecated_kwargs = ["predict_epsilon"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def has_state(self):
|
def has_state(self):
|
||||||
|
|
|
@ -117,6 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||||
|
_deprecated_kwargs = ["predict_epsilon"]
|
||||||
|
|
||||||
@register_to_config
|
@register_to_config
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@ -149,6 +149,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||||
|
_deprecated_kwargs = ["predict_epsilon"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def has_state(self):
|
def has_state(self):
|
||||||
|
|
|
@ -265,3 +265,23 @@ class ModelTesterMixin:
|
||||||
# check disable works
|
# check disable works
|
||||||
model.disable_gradient_checkpointing()
|
model.disable_gradient_checkpointing()
|
||||||
self.assertFalse(model.is_gradient_checkpointing)
|
self.assertFalse(model.is_gradient_checkpointing)
|
||||||
|
|
||||||
|
def test_deprecated_kwargs(self):
|
||||||
|
has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
|
||||||
|
has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0
|
||||||
|
|
||||||
|
if has_kwarg_in_model_class and not has_deprecated_kwarg:
|
||||||
|
raise ValueError(
|
||||||
|
f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
|
||||||
|
" under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
|
||||||
|
" no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
|
||||||
|
" [<deprecated_argument>]`"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not has_kwarg_in_model_class and has_deprecated_kwarg:
|
||||||
|
raise ValueError(
|
||||||
|
f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
|
||||||
|
" under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
|
||||||
|
f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
|
||||||
|
" from `_deprecated_kwargs = [<deprecated_argument>]`"
|
||||||
|
)
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
import inspect
|
||||||
|
|
||||||
from diffusers.utils import is_flax_available
|
from diffusers.utils import is_flax_available
|
||||||
from diffusers.utils.testing_utils import require_flax
|
from diffusers.utils.testing_utils import require_flax
|
||||||
|
|
||||||
|
@ -42,3 +44,23 @@ class FlaxModelTesterMixin:
|
||||||
self.assertIsNotNone(output)
|
self.assertIsNotNone(output)
|
||||||
expected_shape = inputs_dict["sample"].shape
|
expected_shape = inputs_dict["sample"].shape
|
||||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||||
|
|
||||||
|
def test_deprecated_kwargs(self):
|
||||||
|
has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
|
||||||
|
has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0
|
||||||
|
|
||||||
|
if has_kwarg_in_model_class and not has_deprecated_kwarg:
|
||||||
|
raise ValueError(
|
||||||
|
f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
|
||||||
|
" under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
|
||||||
|
" no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
|
||||||
|
" [<deprecated_argument>]`"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not has_kwarg_in_model_class and has_deprecated_kwarg:
|
||||||
|
raise ValueError(
|
||||||
|
f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
|
||||||
|
" under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
|
||||||
|
f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
|
||||||
|
" from `_deprecated_kwargs = [<deprecated_argument>]`"
|
||||||
|
)
|
||||||
|
|
|
@ -562,6 +562,27 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||||
noised = scheduler.add_noise(scaled_sample, noise, t)
|
noised = scheduler.add_noise(scaled_sample, noise, t)
|
||||||
self.assertEqual(noised.shape, scaled_sample.shape)
|
self.assertEqual(noised.shape, scaled_sample.shape)
|
||||||
|
|
||||||
|
def test_deprecated_kwargs(self):
|
||||||
|
for scheduler_class in self.scheduler_classes:
|
||||||
|
has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters
|
||||||
|
has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0
|
||||||
|
|
||||||
|
if has_kwarg_in_model_class and not has_deprecated_kwarg:
|
||||||
|
raise ValueError(
|
||||||
|
f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated"
|
||||||
|
" kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if"
|
||||||
|
" there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
|
||||||
|
" [<deprecated_argument>]`"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not has_kwarg_in_model_class and has_deprecated_kwarg:
|
||||||
|
raise ValueError(
|
||||||
|
f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated"
|
||||||
|
" kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`"
|
||||||
|
f" argument to {self.model_class}.__init__ if there are deprecated arguments or remove the"
|
||||||
|
" deprecated argument from `_deprecated_kwargs = [<deprecated_argument>]`"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DDPMSchedulerTest(SchedulerCommonTest):
|
class DDPMSchedulerTest(SchedulerCommonTest):
|
||||||
scheduler_classes = (DDPMScheduler,)
|
scheduler_classes = (DDPMScheduler,)
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import inspect
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
@ -228,6 +229,27 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
|
||||||
|
|
||||||
recursive_check(outputs_tuple[0], outputs_dict.prev_sample)
|
recursive_check(outputs_tuple[0], outputs_dict.prev_sample)
|
||||||
|
|
||||||
|
def test_deprecated_kwargs(self):
|
||||||
|
for scheduler_class in self.scheduler_classes:
|
||||||
|
has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters
|
||||||
|
has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0
|
||||||
|
|
||||||
|
if has_kwarg_in_model_class and not has_deprecated_kwarg:
|
||||||
|
raise ValueError(
|
||||||
|
f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated"
|
||||||
|
" kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if"
|
||||||
|
" there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
|
||||||
|
" [<deprecated_argument>]`"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not has_kwarg_in_model_class and has_deprecated_kwarg:
|
||||||
|
raise ValueError(
|
||||||
|
f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated"
|
||||||
|
" kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`"
|
||||||
|
f" argument to {self.model_class}.__init__ if there are deprecated arguments or remove the"
|
||||||
|
" deprecated argument from `_deprecated_kwargs = [<deprecated_argument>]`"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@require_flax
|
@require_flax
|
||||||
class FlaxDDPMSchedulerTest(FlaxSchedulerCommonTest):
|
class FlaxDDPMSchedulerTest(FlaxSchedulerCommonTest):
|
||||||
|
|
Loading…
Reference in New Issue