update config dict logic

This commit is contained in:
patil-suraj 2022-06-07 14:26:20 +02:00
parent a61a961345
commit fe99460b5f
2 changed files with 27 additions and 17 deletions

View File

@ -89,6 +89,7 @@ class ConfigMixin:
self.to_json_file(output_config_file) self.to_json_file(output_config_file)
logger.info(f"ConfigMixinuration saved in {output_config_file}") logger.info(f"ConfigMixinuration saved in {output_config_file}")
@classmethod @classmethod
def get_config_dict( def get_config_dict(
@ -182,35 +183,43 @@ class ConfigMixin:
logger.info(f"loading configuration file {config_file}") logger.info(f"loading configuration file {config_file}")
else: else:
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}") logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
return config_dict
@classmethod
def extract_init_dict(cls, config_dict, **kwargs):
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys()) expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
expected_keys.remove("self") expected_keys.remove("self")
import ipdb; ipdb.set_trace()
init_dict = {}
for key in expected_keys: for key in expected_keys:
if key in kwargs: if key in kwargs:
# overwrite key # overwrite key
config_dict[key] = kwargs.pop(key) init_dict[key] = kwargs.pop(key)
elif key in config_dict:
# use value from config dict
init_dict[key] = config_dict.pop(key)
passed_keys = set(config_dict.keys())
unused_kwargs = kwargs
for key in passed_keys - expected_keys:
unused_kwargs[key] = config_dict.pop(key)
unused_kwargs = config_dict.update(kwargs)
passed_keys = set(init_dict.keys())
if len(expected_keys - passed_keys) > 0: if len(expected_keys - passed_keys) > 0:
logger.warn( logger.warn(
f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values." f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
) )
return config_dict, unused_kwargs return init_dict, unused_kwargs
@classmethod @classmethod
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs): def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
config_dict, unused_kwargs = cls.get_config_dict( config_dict = cls.get_config_dict(
pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
) )
model = cls(**config_dict) init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
model = cls(**init_dict)
if return_unused_kwargs: if return_unused_kwargs:
return model, unused_kwargs return model, unused_kwargs

View File

@ -97,16 +97,17 @@ class DiffusionPipeline(ConfigMixin):
else: else:
cached_folder = pretrained_model_name_or_path cached_folder = pretrained_model_name_or_path
config_dict, pipeline_kwargs = cls.get_config_dict(cached_folder) config_dict = cls.get_config_dict(cached_folder)
module = config_dict["_module"]
class_name_ = config_dict["_class_name"]
class_obj = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
module = pipeline_kwargs.pop("_module", None) init_dict, unused = class_obj.extract_init_dict(config_dict, **kwargs)
# TODO(Suraj) - make from hub import work import ipdb; ipdb.set_trace()
# Make `ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom-pipe")` work
# Add Sylvains code from transformers
init_kwargs = {} init_kwargs = {}
for name, (library_name, class_name) in config_dict.items(): for name, (library_name, class_name) in init_dict.items():
importable_classes = LOADABLE_CLASSES[library_name] importable_classes = LOADABLE_CLASSES[library_name]
if library_name == module: if library_name == module:
@ -131,6 +132,6 @@ class DiffusionPipeline(ConfigMixin):
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
class_obj = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
model = class_obj(**init_kwargs) model = class_obj(**init_kwargs)
return model return model