diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 34bba89e..ca61120f 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -89,6 +89,7 @@ class ConfigMixin: self.to_json_file(output_config_file) logger.info(f"ConfigMixinuration saved in {output_config_file}") + @classmethod def get_config_dict( @@ -182,35 +183,43 @@ class ConfigMixin: logger.info(f"loading configuration file {config_file}") else: logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}") + + return config_dict + @classmethod + def extract_init_dict(cls, config_dict, **kwargs): expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys()) expected_keys.remove("self") - + import ipdb; ipdb.set_trace() + init_dict = {} for key in expected_keys: if key in kwargs: # overwrite key - config_dict[key] = kwargs.pop(key) + init_dict[key] = kwargs.pop(key) + elif key in config_dict: + # use value from config dict + init_dict[key] = config_dict.pop(key) - passed_keys = set(config_dict.keys()) - - unused_kwargs = kwargs - for key in passed_keys - expected_keys: - unused_kwargs[key] = config_dict.pop(key) + unused_kwargs = config_dict.update(kwargs) + + passed_keys = set(init_dict.keys()) if len(expected_keys - passed_keys) > 0: logger.warn( f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values." ) - return config_dict, unused_kwargs + return init_dict, unused_kwargs @classmethod def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs): - config_dict, unused_kwargs = cls.get_config_dict( + config_dict = cls.get_config_dict( pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs ) - model = cls(**config_dict) + init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs) + + model = cls(**init_dict) if return_unused_kwargs: return model, unused_kwargs diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 8037d121..ba3a823e 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -97,16 +97,17 @@ class DiffusionPipeline(ConfigMixin): else: cached_folder = pretrained_model_name_or_path - config_dict, pipeline_kwargs = cls.get_config_dict(cached_folder) + config_dict = cls.get_config_dict(cached_folder) + module = config_dict["_module"] + class_name_ = config_dict["_class_name"] + class_obj = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder) - module = pipeline_kwargs.pop("_module", None) - # TODO(Suraj) - make from hub import work - # Make `ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom-pipe")` work - # Add Sylvains code from transformers + init_dict, unused = class_obj.extract_init_dict(config_dict, **kwargs) + import ipdb; ipdb.set_trace() init_kwargs = {} - for name, (library_name, class_name) in config_dict.items(): + for name, (library_name, class_name) in init_dict.items(): importable_classes = LOADABLE_CLASSES[library_name] if library_name == module: @@ -131,6 +132,6 @@ class DiffusionPipeline(ConfigMixin): init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) - class_obj = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder) + model = class_obj(**init_kwargs) return model