allow loading model from pipeline module

This commit is contained in:
patil-suraj 2022-06-14 12:50:27 +02:00
parent ca72c1f81d
commit d81b56ba5c
1 changed files with 25 additions and 6 deletions

View File

@ -55,11 +55,20 @@ class DiffusionPipeline(ConfigMixin):
config_name = "model_index.json"
def register_modules(self, **kwargs):
# import it here to avoid circular import
from diffusers import pipelines
for name, module in kwargs.items():
# check if the module is a pipeline module
is_pipeline_module = hasattr(pipelines, module.__module__.split(".")[-1])
# retrive library
library = module.__module__.split(".")[0]
# if library is not in LOADABLE_CLASSES, then it is a custom module
if library not in LOADABLE_CLASSES:
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# so we set the library to module name.
if library not in LOADABLE_CLASSES or is_pipeline_module:
library = module.__module__.split(".")[-1]
# retrive class_name
@ -152,11 +161,21 @@ class DiffusionPipeline(ConfigMixin):
init_kwargs = {}
# import it here to avoid circular import
from diffusers import pipelines
# 4. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
is_pipeline_module = hasattr(pipelines, library_name)
# if the model is in a pipeline module, then we load it from the pipeline
if is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name)
importable_classes = ALL_IMPORTABLE_CLASSES
class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()}
elif library_name == module_candidate_name:
# if the model is not in diffusers or transformers, we need to load it from the hub
# assumes that it's a subclass of ModelMixin
if library_name == module_candidate_name:
class_obj = get_class_from_dynamic_module(cached_folder, module_candidate, class_name, cached_folder)
# since it's not from a library, we need to check class candidates for all importable classes
importable_classes = ALL_IMPORTABLE_CLASSES