allow loading model from pipeline module
This commit is contained in:
parent
ca72c1f81d
commit
d81b56ba5c
|
@ -55,11 +55,20 @@ class DiffusionPipeline(ConfigMixin):
|
|||
config_name = "model_index.json"
|
||||
|
||||
def register_modules(self, **kwargs):
|
||||
# import it here to avoid circular import
|
||||
from diffusers import pipelines
|
||||
|
||||
for name, module in kwargs.items():
|
||||
# check if the module is a pipeline module
|
||||
is_pipeline_module = hasattr(pipelines, module.__module__.split(".")[-1])
|
||||
|
||||
# retrive library
|
||||
library = module.__module__.split(".")[0]
|
||||
# if library is not in LOADABLE_CLASSES, then it is a custom module
|
||||
if library not in LOADABLE_CLASSES:
|
||||
|
||||
# if library is not in LOADABLE_CLASSES, then it is a custom module.
|
||||
# Or if it's a pipeline module, then the module is inside the pipeline
|
||||
# so we set the library to module name.
|
||||
if library not in LOADABLE_CLASSES or is_pipeline_module:
|
||||
library = module.__module__.split(".")[-1]
|
||||
|
||||
# retrive class_name
|
||||
|
@ -152,11 +161,21 @@ class DiffusionPipeline(ConfigMixin):
|
|||
|
||||
init_kwargs = {}
|
||||
|
||||
# import it here to avoid circular import
|
||||
from diffusers import pipelines
|
||||
|
||||
# 4. Load each module in the pipeline
|
||||
for name, (library_name, class_name) in init_dict.items():
|
||||
is_pipeline_module = hasattr(pipelines, library_name)
|
||||
# if the model is in a pipeline module, then we load it from the pipeline
|
||||
if is_pipeline_module:
|
||||
pipeline_module = getattr(pipelines, library_name)
|
||||
class_obj = getattr(pipeline_module, class_name)
|
||||
importable_classes = ALL_IMPORTABLE_CLASSES
|
||||
class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()}
|
||||
elif library_name == module_candidate_name:
|
||||
# if the model is not in diffusers or transformers, we need to load it from the hub
|
||||
# assumes that it's a subclass of ModelMixin
|
||||
if library_name == module_candidate_name:
|
||||
class_obj = get_class_from_dynamic_module(cached_folder, module_candidate, class_name, cached_folder)
|
||||
# since it's not from a library, we need to check class candidates for all importable classes
|
||||
importable_classes = ALL_IMPORTABLE_CLASSES
|
||||
|
|
Loading…
Reference in New Issue