improve loading a bit
This commit is contained in:
parent
3a32b8c916
commit
936cd08488
|
@ -208,6 +208,7 @@ class ConfigMixin:
|
|||
def extract_init_dict(cls, config_dict, **kwargs):
|
||||
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
|
||||
expected_keys.remove("self")
|
||||
expected_keys.remove("kwargs")
|
||||
init_dict = {}
|
||||
for key in expected_keys:
|
||||
if key in kwargs:
|
||||
|
|
|
@ -147,6 +147,7 @@ class ModelMixin(torch.nn.Module):
|
|||
models, `pixel_values` for vision models and `input_values` for speech models).
|
||||
"""
|
||||
config_name = CONFIG_NAME
|
||||
_automatically_saved_args = ["_diffusers_version", "_class_name", "name_or_path"]
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -63,8 +63,18 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
|
|||
mid_block_scale_factor=1,
|
||||
center_input_sample=False,
|
||||
resnet_num_groups=30,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
# remove automatically added kwargs
|
||||
for arg in self._automatically_saved_args:
|
||||
kwargs.pop(arg, None)
|
||||
|
||||
if len(kwargs) > 0:
|
||||
raise ValueError(
|
||||
f"The following keyword arguments do not exist for {self.__class__}: {','.join(kwargs.keys())}"
|
||||
)
|
||||
|
||||
# register all __init__ params to be accessible via `self.config.<...>`
|
||||
# should probably be automated down the road as this is pure boiler plate code
|
||||
self.register_to_config(
|
||||
|
|
|
@ -59,8 +59,18 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
|
|||
mid_block_scale_factor=1,
|
||||
center_input_sample=False,
|
||||
resnet_num_groups=32,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
# remove automatically added kwargs
|
||||
for arg in self._automatically_saved_args:
|
||||
kwargs.pop(arg, None)
|
||||
|
||||
if len(kwargs) > 0:
|
||||
raise ValueError(
|
||||
f"The following keyword arguments do not exist for {self.__class__}: {','.join(kwargs.keys())}"
|
||||
)
|
||||
|
||||
# register all __init__ params to be accessible via `self.config.<...>`
|
||||
# should probably be automated down the road as this is pure boiler plate code
|
||||
self.register_to_config(
|
||||
|
|
Loading…
Reference in New Issue