update config dict logic

This commit is contained in:
patil-suraj 2022-06-07 14:26:20 +02:00
parent a61a961345
commit fe99460b5f
2 changed files with 27 additions and 17 deletions

View File

@ -89,6 +89,7 @@ class ConfigMixin:
self.to_json_file(output_config_file)
logger.info(f"ConfigMixinuration saved in {output_config_file}")
@classmethod
def get_config_dict(
@ -182,35 +183,43 @@ class ConfigMixin:
logger.info(f"loading configuration file {config_file}")
else:
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
return config_dict
@classmethod
def extract_init_dict(cls, config_dict, **kwargs):
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
expected_keys.remove("self")
import ipdb; ipdb.set_trace()
init_dict = {}
for key in expected_keys:
if key in kwargs:
# overwrite key
config_dict[key] = kwargs.pop(key)
init_dict[key] = kwargs.pop(key)
elif key in config_dict:
# use value from config dict
init_dict[key] = config_dict.pop(key)
passed_keys = set(config_dict.keys())
unused_kwargs = kwargs
for key in passed_keys - expected_keys:
unused_kwargs[key] = config_dict.pop(key)
unused_kwargs = config_dict.update(kwargs)
passed_keys = set(init_dict.keys())
if len(expected_keys - passed_keys) > 0:
logger.warn(
f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
)
return config_dict, unused_kwargs
return init_dict, unused_kwargs
@classmethod
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
config_dict, unused_kwargs = cls.get_config_dict(
config_dict = cls.get_config_dict(
pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
)
model = cls(**config_dict)
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
model = cls(**init_dict)
if return_unused_kwargs:
return model, unused_kwargs

View File

@ -97,16 +97,17 @@ class DiffusionPipeline(ConfigMixin):
else:
cached_folder = pretrained_model_name_or_path
config_dict, pipeline_kwargs = cls.get_config_dict(cached_folder)
config_dict = cls.get_config_dict(cached_folder)
module = config_dict["_module"]
class_name_ = config_dict["_class_name"]
class_obj = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
module = pipeline_kwargs.pop("_module", None)
# TODO(Suraj) - make from hub import work
# Make `ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom-pipe")` work
# Add Sylvains code from transformers
init_dict, unused = class_obj.extract_init_dict(config_dict, **kwargs)
import ipdb; ipdb.set_trace()
init_kwargs = {}
for name, (library_name, class_name) in config_dict.items():
for name, (library_name, class_name) in init_dict.items():
importable_classes = LOADABLE_CLASSES[library_name]
if library_name == module:
@ -131,6 +132,6 @@ class DiffusionPipeline(ConfigMixin):
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
class_obj = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
model = class_obj(**init_kwargs)
return model