Allow to set config params directly in init (#1419)
* fix * fix deprecated kwargs logic * add tests * finish
This commit is contained in:
parent
86aa747da9
commit
8faa822ddc
|
@ -80,14 +80,18 @@ class ConfigMixin:
|
|||
- **config_name** (`str`) -- A filename under which the config should stored when calling
|
||||
[`~ConfigMixin.save_config`] (should be overridden by parent class).
|
||||
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
|
||||
overridden by parent class).
|
||||
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by parent
|
||||
class).
|
||||
overridden by subclass).
|
||||
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
|
||||
- **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
|
||||
should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
|
||||
subclass).
|
||||
"""
|
||||
config_name = None
|
||||
ignore_for_config = []
|
||||
has_compatibles = False
|
||||
|
||||
_deprecated_kwargs = []
|
||||
|
||||
def register_to_config(self, **kwargs):
|
||||
if self.config_name is None:
|
||||
raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
|
||||
|
@ -195,10 +199,10 @@ class ConfigMixin:
|
|||
if "dtype" in unused_kwargs:
|
||||
init_dict["dtype"] = unused_kwargs.pop("dtype")
|
||||
|
||||
if "predict_epsilon" in unused_kwargs and "prediction_type" not in init_dict:
|
||||
deprecate("remove this", "0.10.0", "remove")
|
||||
predict_epsilon = unused_kwargs.pop("predict_epsilon")
|
||||
init_dict["prediction_type"] = "epsilon" if predict_epsilon else "sample"
|
||||
# add possible deprecated kwargs
|
||||
for deprecated_kwarg in cls._deprecated_kwargs:
|
||||
if deprecated_kwarg in unused_kwargs:
|
||||
init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
|
||||
|
||||
# Return model and optionally state and/or unused_kwargs
|
||||
model = cls(**init_dict)
|
||||
|
@ -526,7 +530,6 @@ def register_to_config(init):
|
|||
# Ignore private kwargs in the init.
|
||||
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
|
||||
config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
|
||||
init(self, *args, **init_kwargs)
|
||||
if not isinstance(self, ConfigMixin):
|
||||
raise RuntimeError(
|
||||
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
|
||||
|
@ -553,6 +556,7 @@ def register_to_config(init):
|
|||
)
|
||||
new_kwargs = {**config_init_kwargs, **new_kwargs}
|
||||
getattr(self, "register_to_config")(**new_kwargs)
|
||||
init(self, *args, **init_kwargs)
|
||||
|
||||
return inner_init
|
||||
|
||||
|
|
|
@ -254,7 +254,6 @@ class UNetMidBlock2D(nn.Module):
|
|||
attn_num_head_channels=1,
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -336,7 +335,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
|||
cross_attention_dim=1280,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
|
|
@ -1039,7 +1039,6 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
|||
cross_attention_dim=1280,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
|
|
@ -113,6 +113,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||
"""
|
||||
|
||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
_deprecated_kwargs = ["predict_epsilon"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
|
|
@ -116,6 +116,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|||
"""
|
||||
|
||||
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
_deprecated_kwargs = ["predict_epsilon"]
|
||||
|
||||
@property
|
||||
def has_state(self):
|
||||
|
|
|
@ -105,6 +105,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||
"""
|
||||
|
||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
_deprecated_kwargs = ["predict_epsilon"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
|
|
@ -109,6 +109,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|||
"""
|
||||
|
||||
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
_deprecated_kwargs = ["predict_epsilon"]
|
||||
|
||||
@property
|
||||
def has_state(self):
|
||||
|
|
|
@ -117,6 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||
"""
|
||||
|
||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
_deprecated_kwargs = ["predict_epsilon"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
|
|
@ -149,6 +149,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|||
"""
|
||||
|
||||
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
_deprecated_kwargs = ["predict_epsilon"]
|
||||
|
||||
@property
|
||||
def has_state(self):
|
||||
|
|
|
@ -265,3 +265,23 @@ class ModelTesterMixin:
|
|||
# check disable works
|
||||
model.disable_gradient_checkpointing()
|
||||
self.assertFalse(model.is_gradient_checkpointing)
|
||||
|
||||
def test_deprecated_kwargs(self):
|
||||
has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
|
||||
has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0
|
||||
|
||||
if has_kwarg_in_model_class and not has_deprecated_kwarg:
|
||||
raise ValueError(
|
||||
f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
|
||||
" under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
|
||||
" no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
|
||||
" [<deprecated_argument>]`"
|
||||
)
|
||||
|
||||
if not has_kwarg_in_model_class and has_deprecated_kwarg:
|
||||
raise ValueError(
|
||||
f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
|
||||
" under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
|
||||
f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
|
||||
" from `_deprecated_kwargs = [<deprecated_argument>]`"
|
||||
)
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import inspect
|
||||
|
||||
from diffusers.utils import is_flax_available
|
||||
from diffusers.utils.testing_utils import require_flax
|
||||
|
||||
|
@ -42,3 +44,23 @@ class FlaxModelTesterMixin:
|
|||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_deprecated_kwargs(self):
|
||||
has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
|
||||
has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0
|
||||
|
||||
if has_kwarg_in_model_class and not has_deprecated_kwarg:
|
||||
raise ValueError(
|
||||
f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
|
||||
" under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
|
||||
" no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
|
||||
" [<deprecated_argument>]`"
|
||||
)
|
||||
|
||||
if not has_kwarg_in_model_class and has_deprecated_kwarg:
|
||||
raise ValueError(
|
||||
f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
|
||||
" under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
|
||||
f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
|
||||
" from `_deprecated_kwargs = [<deprecated_argument>]`"
|
||||
)
|
||||
|
|
|
@ -562,6 +562,27 @@ class SchedulerCommonTest(unittest.TestCase):
|
|||
noised = scheduler.add_noise(scaled_sample, noise, t)
|
||||
self.assertEqual(noised.shape, scaled_sample.shape)
|
||||
|
||||
def test_deprecated_kwargs(self):
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters
|
||||
has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0
|
||||
|
||||
if has_kwarg_in_model_class and not has_deprecated_kwarg:
|
||||
raise ValueError(
|
||||
f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated"
|
||||
" kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if"
|
||||
" there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
|
||||
" [<deprecated_argument>]`"
|
||||
)
|
||||
|
||||
if not has_kwarg_in_model_class and has_deprecated_kwarg:
|
||||
raise ValueError(
|
||||
f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated"
|
||||
" kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`"
|
||||
f" argument to {self.model_class}.__init__ if there are deprecated arguments or remove the"
|
||||
" deprecated argument from `_deprecated_kwargs = [<deprecated_argument>]`"
|
||||
)
|
||||
|
||||
|
||||
class DDPMSchedulerTest(SchedulerCommonTest):
|
||||
scheduler_classes = (DDPMScheduler,)
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Dict, List, Tuple
|
||||
|
@ -228,6 +229,27 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
|
|||
|
||||
recursive_check(outputs_tuple[0], outputs_dict.prev_sample)
|
||||
|
||||
def test_deprecated_kwargs(self):
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters
|
||||
has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0
|
||||
|
||||
if has_kwarg_in_model_class and not has_deprecated_kwarg:
|
||||
raise ValueError(
|
||||
f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated"
|
||||
" kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if"
|
||||
" there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
|
||||
" [<deprecated_argument>]`"
|
||||
)
|
||||
|
||||
if not has_kwarg_in_model_class and has_deprecated_kwarg:
|
||||
raise ValueError(
|
||||
f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated"
|
||||
" kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`"
|
||||
f" argument to {self.model_class}.__init__ if there are deprecated arguments or remove the"
|
||||
" deprecated argument from `_deprecated_kwargs = [<deprecated_argument>]`"
|
||||
)
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxDDPMSchedulerTest(FlaxSchedulerCommonTest):
|
||||
|
|
Loading…
Reference in New Issue