improve loading a bit

Patrick von Platen 2022-07-19 22:02:54 +00:00
parent 3a32b8c916
commit 936cd08488
4 changed files with 22 additions and 0 deletions

@@ -208,6 +208,7 @@ class ConfigMixin:
     def extract_init_dict(cls, config_dict, **kwargs):
         expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
         expected_keys.remove("self")
+        expected_keys.remove("kwargs")
         init_dict = {}
         for key in expected_keys:
             if key in kwargs:

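The removed key only makes sense together with the `**kwargs` additions further down: once `**kwargs` is part of a model's `__init__`, `inspect.signature` reports a parameter literally named `kwargs`, which should not be treated as a real config key. A minimal sketch with a toy class (not the diffusers code):

    import inspect

    class ToyModel:
        # hypothetical class, only used to inspect the signature
        def __init__(self, sample_size, layers_per_block=2, **kwargs):
            pass

    expected_keys = set(dict(inspect.signature(ToyModel.__init__).parameters).keys())
    # expected_keys == {"self", "sample_size", "layers_per_block", "kwargs"}
    expected_keys.remove("self")
    expected_keys.remove("kwargs")
    # only the real configuration arguments remain
    print(sorted(expected_keys))  # ['layers_per_block', 'sample_size']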

@@ -147,6 +147,7 @@ class ModelMixin(torch.nn.Module):
         models, `pixel_values` for vision models and `input_values` for speech models).
     """
     config_name = CONFIG_NAME
+    _automatically_saved_args = ["_diffusers_version", "_class_name", "name_or_path"]
 
     def __init__(self):
         super().__init__()

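`_automatically_saved_args` names the bookkeeping entries that the library writes into a saved config but that no model `__init__` accepts as a real argument. A hedged sketch of how such a stored config might look and how the list filters it (only the three bookkeeping keys come from the diff above; the other key and all values are illustrative):

    stored_config = {
        "_class_name": "UNetUnconditionalModel",   # added automatically on save
        "_diffusers_version": "0.0.4",             # added automatically (version illustrative)
        "name_or_path": "./my-unet",               # added automatically (path illustrative)
        "image_size": 32,                          # regular __init__ argument (illustrative)
    }

    _automatically_saved_args = ["_diffusers_version", "_class_name", "name_or_path"]
    init_kwargs = {k: v for k, v in stored_config.items() if k not in _automatically_saved_args}
    print(init_kwargs)  # {'image_size': 32}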

@@ -63,8 +63,18 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
         mid_block_scale_factor=1,
         center_input_sample=False,
         resnet_num_groups=30,
+        **kwargs,
     ):
         super().__init__()
+        # remove automatically added kwargs
+        for arg in self._automatically_saved_args:
+            kwargs.pop(arg, None)
+
+        if len(kwargs) > 0:
+            raise ValueError(
+                f"The following keyword arguments do not exist for {self.__class__}: {','.join(kwargs.keys())}"
+            )
+
         # register all __init__ params to be accessible via `self.config.<...>`
         # should probably be automated down the road as this is pure boiler plate code
         self.register_to_config(

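The new block has two effects: bookkeeping kwargs coming from a saved config are silently dropped, and any other unexpected keyword argument is surfaced as a `ValueError` instead of being ignored. A standalone sketch of the pattern (the function name and the example kwargs are illustrative):

    _automatically_saved_args = ["_diffusers_version", "_class_name", "name_or_path"]

    def validate_kwargs(**kwargs):
        # same pattern as the constructor above, extracted for illustration
        for arg in _automatically_saved_args:
            kwargs.pop(arg, None)
        if len(kwargs) > 0:
            raise ValueError(
                f"The following keyword arguments do not exist: {','.join(kwargs.keys())}"
            )

    validate_kwargs(_diffusers_version="0.0.4", _class_name="UNetConditionalModel")  # accepted

    try:
        validate_kwargs(resnet_numgroups=30)  # hypothetical typo
    except ValueError as err:
        print(err)  # The following keyword arguments do not exist: resnet_numgroups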

@@ -59,8 +59,18 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
         mid_block_scale_factor=1,
         center_input_sample=False,
         resnet_num_groups=32,
+        **kwargs,
     ):
         super().__init__()
+        # remove automatically added kwargs
+        for arg in self._automatically_saved_args:
+            kwargs.pop(arg, None)
+
+        if len(kwargs) > 0:
+            raise ValueError(
+                f"The following keyword arguments do not exist for {self.__class__}: {','.join(kwargs.keys())}"
+            )
+
         # register all __init__ params to be accessible via `self.config.<...>`
         # should probably be automated down the road as this is pure boiler plate code
         self.register_to_config(
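Together with `**kwargs` in the signature, this is what actually improves loading: a config that still carries the automatically saved keys can be splatted into the constructor again, where it previously failed with a `TypeError`. A toy before/after comparison (both classes are illustrative stand-ins, not the diffusers models):

    class OldStyle:
        def __init__(self, image_size=32):
            self.image_size = image_size

    class NewStyle:
        _automatically_saved_args = ["_diffusers_version", "_class_name", "name_or_path"]

        def __init__(self, image_size=32, **kwargs):
            for arg in self._automatically_saved_args:
                kwargs.pop(arg, None)
            if len(kwargs) > 0:
                raise ValueError(
                    f"The following keyword arguments do not exist for {self.__class__}: {','.join(kwargs.keys())}"
                )
            self.image_size = image_size

    stored = {"image_size": 64, "_diffusers_version": "0.0.4", "_class_name": "NewStyle"}

    try:
        OldStyle(**stored)
    except TypeError as err:
        print(err)  # ... got an unexpected keyword argument '_diffusers_version'

    model = NewStyle(**stored)  # bookkeeping keys are dropped silently
    print(model.image_size)     # 64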