update config dict logic
This commit is contained in:
parent
a61a961345
commit
fe99460b5f
|
@ -89,6 +89,7 @@ class ConfigMixin:
|
||||||
|
|
||||||
self.to_json_file(output_config_file)
|
self.to_json_file(output_config_file)
|
||||||
logger.info(f"ConfigMixinuration saved in {output_config_file}")
|
logger.info(f"ConfigMixinuration saved in {output_config_file}")
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config_dict(
|
def get_config_dict(
|
||||||
|
@ -182,35 +183,43 @@ class ConfigMixin:
|
||||||
logger.info(f"loading configuration file {config_file}")
|
logger.info(f"loading configuration file {config_file}")
|
||||||
else:
|
else:
|
||||||
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
|
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
|
||||||
|
|
||||||
|
return config_dict
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def extract_init_dict(cls, config_dict, **kwargs):
|
||||||
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
|
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
|
||||||
expected_keys.remove("self")
|
expected_keys.remove("self")
|
||||||
|
import ipdb; ipdb.set_trace()
|
||||||
|
init_dict = {}
|
||||||
for key in expected_keys:
|
for key in expected_keys:
|
||||||
if key in kwargs:
|
if key in kwargs:
|
||||||
# overwrite key
|
# overwrite key
|
||||||
config_dict[key] = kwargs.pop(key)
|
init_dict[key] = kwargs.pop(key)
|
||||||
|
elif key in config_dict:
|
||||||
|
# use value from config dict
|
||||||
|
init_dict[key] = config_dict.pop(key)
|
||||||
|
|
||||||
passed_keys = set(config_dict.keys())
|
|
||||||
|
|
||||||
unused_kwargs = kwargs
|
|
||||||
for key in passed_keys - expected_keys:
|
|
||||||
unused_kwargs[key] = config_dict.pop(key)
|
|
||||||
|
|
||||||
|
unused_kwargs = config_dict.update(kwargs)
|
||||||
|
|
||||||
|
passed_keys = set(init_dict.keys())
|
||||||
if len(expected_keys - passed_keys) > 0:
|
if len(expected_keys - passed_keys) > 0:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
|
f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
|
||||||
)
|
)
|
||||||
|
|
||||||
return config_dict, unused_kwargs
|
return init_dict, unused_kwargs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
|
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
|
||||||
config_dict, unused_kwargs = cls.get_config_dict(
|
config_dict = cls.get_config_dict(
|
||||||
pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
|
pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
model = cls(**config_dict)
|
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
|
||||||
|
|
||||||
|
model = cls(**init_dict)
|
||||||
|
|
||||||
if return_unused_kwargs:
|
if return_unused_kwargs:
|
||||||
return model, unused_kwargs
|
return model, unused_kwargs
|
||||||
|
|
|
@ -97,16 +97,17 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
else:
|
else:
|
||||||
cached_folder = pretrained_model_name_or_path
|
cached_folder = pretrained_model_name_or_path
|
||||||
|
|
||||||
config_dict, pipeline_kwargs = cls.get_config_dict(cached_folder)
|
config_dict = cls.get_config_dict(cached_folder)
|
||||||
|
module = config_dict["_module"]
|
||||||
|
class_name_ = config_dict["_class_name"]
|
||||||
|
class_obj = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
|
||||||
|
|
||||||
module = pipeline_kwargs.pop("_module", None)
|
init_dict, unused = class_obj.extract_init_dict(config_dict, **kwargs)
|
||||||
# TODO(Suraj) - make from hub import work
|
import ipdb; ipdb.set_trace()
|
||||||
# Make `ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom-pipe")` work
|
|
||||||
# Add Sylvains code from transformers
|
|
||||||
|
|
||||||
init_kwargs = {}
|
init_kwargs = {}
|
||||||
|
|
||||||
for name, (library_name, class_name) in config_dict.items():
|
for name, (library_name, class_name) in init_dict.items():
|
||||||
importable_classes = LOADABLE_CLASSES[library_name]
|
importable_classes = LOADABLE_CLASSES[library_name]
|
||||||
|
|
||||||
if library_name == module:
|
if library_name == module:
|
||||||
|
@ -131,6 +132,6 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
|
|
||||||
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
|
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
|
||||||
|
|
||||||
class_obj = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
|
|
||||||
model = class_obj(**init_kwargs)
|
model = class_obj(**init_kwargs)
|
||||||
return model
|
return model
|
||||||
|
|
Loading…
Reference in New Issue