make ALL_IMPORTABLE_CLASSES static
This commit is contained in:
parent
02cdd68331
commit
6b66999e75
|
@ -46,6 +46,10 @@ LOADABLE_CLASSES = {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ALL_IMPORTABLE_CLASSES = {}
|
||||||
|
for library in LOADABLE_CLASSES:
|
||||||
|
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
|
||||||
|
|
||||||
|
|
||||||
class DiffusionPipeline(ConfigMixin):
|
class DiffusionPipeline(ConfigMixin):
|
||||||
|
|
||||||
|
@ -125,12 +129,6 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
|
|
||||||
init_kwargs = {}
|
init_kwargs = {}
|
||||||
|
|
||||||
# get all importable classes to get the load method name for custom models/components
|
|
||||||
# here we enforce that custom models/components should always subclass from base classes in tansformers and diffusers
|
|
||||||
all_importable_classes = {}
|
|
||||||
for library in LOADABLE_CLASSES:
|
|
||||||
all_importable_classes.update(LOADABLE_CLASSES[library])
|
|
||||||
|
|
||||||
for name, (library_name, class_name) in init_dict.items():
|
for name, (library_name, class_name) in init_dict.items():
|
||||||
|
|
||||||
# if the model is not in diffusers or transformers, we need to load it from the hub
|
# if the model is not in diffusers or transformers, we need to load it from the hub
|
||||||
|
@ -138,8 +136,8 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
if library_name == module_candidate_name:
|
if library_name == module_candidate_name:
|
||||||
class_obj = get_class_from_dynamic_module(cached_folder, module, class_name, cached_folder)
|
class_obj = get_class_from_dynamic_module(cached_folder, module, class_name, cached_folder)
|
||||||
# since it's not from a library, we need to check class candidates for all importable classes
|
# since it's not from a library, we need to check class candidates for all importable classes
|
||||||
importable_classes = all_importable_classes
|
importable_classes = ALL_IMPORTABLE_CLASSES
|
||||||
class_candidates = {c: class_obj for c in all_importable_classes}
|
class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()}
|
||||||
else:
|
else:
|
||||||
library = importlib.import_module(library_name)
|
library = importlib.import_module(library_name)
|
||||||
class_obj = getattr(library, class_name)
|
class_obj = getattr(library, class_name)
|
||||||
|
|
Loading…
Reference in New Issue