update config dict logic
This commit is contained in:
parent
a61a961345
commit
fe99460b5f
|
@ -90,6 +90,7 @@ class ConfigMixin:
|
|||
self.to_json_file(output_config_file)
|
||||
logger.info(f"ConfigMixinuration saved in {output_config_file}")
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_config_dict(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||
|
@ -183,34 +184,42 @@ class ConfigMixin:
|
|||
else:
|
||||
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
|
||||
|
||||
return config_dict
|
||||
|
||||
@classmethod
|
||||
def extract_init_dict(cls, config_dict, **kwargs):
|
||||
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
|
||||
expected_keys.remove("self")
|
||||
|
||||
import ipdb; ipdb.set_trace()
|
||||
init_dict = {}
|
||||
for key in expected_keys:
|
||||
if key in kwargs:
|
||||
# overwrite key
|
||||
config_dict[key] = kwargs.pop(key)
|
||||
init_dict[key] = kwargs.pop(key)
|
||||
elif key in config_dict:
|
||||
# use value from config dict
|
||||
init_dict[key] = config_dict.pop(key)
|
||||
|
||||
passed_keys = set(config_dict.keys())
|
||||
|
||||
unused_kwargs = kwargs
|
||||
for key in passed_keys - expected_keys:
|
||||
unused_kwargs[key] = config_dict.pop(key)
|
||||
unused_kwargs = config_dict.update(kwargs)
|
||||
|
||||
passed_keys = set(init_dict.keys())
|
||||
if len(expected_keys - passed_keys) > 0:
|
||||
logger.warn(
|
||||
f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
|
||||
)
|
||||
|
||||
return config_dict, unused_kwargs
|
||||
return init_dict, unused_kwargs
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
|
||||
config_dict, unused_kwargs = cls.get_config_dict(
|
||||
config_dict = cls.get_config_dict(
|
||||
pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
|
||||
)
|
||||
|
||||
model = cls(**config_dict)
|
||||
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
|
||||
|
||||
model = cls(**init_dict)
|
||||
|
||||
if return_unused_kwargs:
|
||||
return model, unused_kwargs
|
||||
|
|
|
@ -97,16 +97,17 @@ class DiffusionPipeline(ConfigMixin):
|
|||
else:
|
||||
cached_folder = pretrained_model_name_or_path
|
||||
|
||||
config_dict, pipeline_kwargs = cls.get_config_dict(cached_folder)
|
||||
config_dict = cls.get_config_dict(cached_folder)
|
||||
module = config_dict["_module"]
|
||||
class_name_ = config_dict["_class_name"]
|
||||
class_obj = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
|
||||
|
||||
module = pipeline_kwargs.pop("_module", None)
|
||||
# TODO(Suraj) - make from hub import work
|
||||
# Make `ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom-pipe")` work
|
||||
# Add Sylvains code from transformers
|
||||
init_dict, unused = class_obj.extract_init_dict(config_dict, **kwargs)
|
||||
import ipdb; ipdb.set_trace()
|
||||
|
||||
init_kwargs = {}
|
||||
|
||||
for name, (library_name, class_name) in config_dict.items():
|
||||
for name, (library_name, class_name) in init_dict.items():
|
||||
importable_classes = LOADABLE_CLASSES[library_name]
|
||||
|
||||
if library_name == module:
|
||||
|
@ -131,6 +132,6 @@ class DiffusionPipeline(ConfigMixin):
|
|||
|
||||
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
|
||||
|
||||
class_obj = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
|
||||
|
||||
model = class_obj(**init_kwargs)
|
||||
return model
|
||||
|
|
Loading…
Reference in New Issue