allow loading model from pipeline module
This commit is contained in:
parent
ca72c1f81d
commit
d81b56ba5c
|
@ -55,11 +55,20 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
config_name = "model_index.json"
|
config_name = "model_index.json"
|
||||||
|
|
||||||
def register_modules(self, **kwargs):
|
def register_modules(self, **kwargs):
|
||||||
|
# import it here to avoid circular import
|
||||||
|
from diffusers import pipelines
|
||||||
|
|
||||||
for name, module in kwargs.items():
|
for name, module in kwargs.items():
|
||||||
|
# check if the module is a pipeline module
|
||||||
|
is_pipeline_module = hasattr(pipelines, module.__module__.split(".")[-1])
|
||||||
|
|
||||||
# retrive library
|
# retrive library
|
||||||
library = module.__module__.split(".")[0]
|
library = module.__module__.split(".")[0]
|
||||||
# if library is not in LOADABLE_CLASSES, then it is a custom module
|
|
||||||
if library not in LOADABLE_CLASSES:
|
# if library is not in LOADABLE_CLASSES, then it is a custom module.
|
||||||
|
# Or if it's a pipeline module, then the module is inside the pipeline
|
||||||
|
# so we set the library to module name.
|
||||||
|
if library not in LOADABLE_CLASSES or is_pipeline_module:
|
||||||
library = module.__module__.split(".")[-1]
|
library = module.__module__.split(".")[-1]
|
||||||
|
|
||||||
# retrive class_name
|
# retrive class_name
|
||||||
|
@ -151,12 +160,22 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||||
|
|
||||||
init_kwargs = {}
|
init_kwargs = {}
|
||||||
|
|
||||||
|
# import it here to avoid circular import
|
||||||
|
from diffusers import pipelines
|
||||||
|
|
||||||
# 4. Load each module in the pipeline
|
# 4. Load each module in the pipeline
|
||||||
for name, (library_name, class_name) in init_dict.items():
|
for name, (library_name, class_name) in init_dict.items():
|
||||||
# if the model is not in diffusers or transformers, we need to load it from the hub
|
is_pipeline_module = hasattr(pipelines, library_name)
|
||||||
# assumes that it's a subclass of ModelMixin
|
# if the model is in a pipeline module, then we load it from the pipeline
|
||||||
if library_name == module_candidate_name:
|
if is_pipeline_module:
|
||||||
|
pipeline_module = getattr(pipelines, library_name)
|
||||||
|
class_obj = getattr(pipeline_module, class_name)
|
||||||
|
importable_classes = ALL_IMPORTABLE_CLASSES
|
||||||
|
class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()}
|
||||||
|
elif library_name == module_candidate_name:
|
||||||
|
# if the model is not in diffusers or transformers, we need to load it from the hub
|
||||||
|
# assumes that it's a subclass of ModelMixin
|
||||||
class_obj = get_class_from_dynamic_module(cached_folder, module_candidate, class_name, cached_folder)
|
class_obj = get_class_from_dynamic_module(cached_folder, module_candidate, class_name, cached_folder)
|
||||||
# since it's not from a library, we need to check class candidates for all importable classes
|
# since it's not from a library, we need to check class candidates for all importable classes
|
||||||
importable_classes = ALL_IMPORTABLE_CLASSES
|
importable_classes = ALL_IMPORTABLE_CLASSES
|
||||||
|
|
Loading…
Reference in New Issue