fix
This commit is contained in:
parent
2665677b0a
commit
86064df7b5
|
@ -101,13 +101,15 @@ class DiffusionPipeline(ConfigMixin):
|
|||
|
||||
config_dict = cls.get_config_dict(cached_folder)
|
||||
|
||||
module_candidate = config_dict["_module"]
|
||||
|
||||
# if we load from explicit class, let's use it
|
||||
if cls != DiffusionPipeline:
|
||||
pipeline_class = cls
|
||||
else:
|
||||
# else we need to load the correct module from the Hub
|
||||
class_name_ = config_dict["_class_name"]
|
||||
module = config_dict["_module"]
|
||||
module = module_candidate
|
||||
pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
|
||||
|
||||
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
|
@ -117,7 +119,7 @@ class DiffusionPipeline(ConfigMixin):
|
|||
for name, (library_name, class_name) in init_dict.items():
|
||||
importable_classes = LOADABLE_CLASSES[library_name]
|
||||
|
||||
if library_name == module:
|
||||
if library_name == module_candidate:
|
||||
# TODO(Suraj)
|
||||
# for vq
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue