diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index ee593b46..94a6c67b 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -15,6 +15,7 @@ # limitations under the License. import importlib +import inspect import os from typing import Optional, Union @@ -148,6 +149,12 @@ class DiffusionPipeline(ConfigMixin): diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) + # some modules can be passed directly to the init + # in this case they are already instantiated in `kwargs` + # extract them here + expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) init_kwargs = {} @@ -158,8 +165,36 @@ class DiffusionPipeline(ConfigMixin): # 3. Load each module in the pipeline for name, (library_name, class_name) in init_dict.items(): is_pipeline_module = hasattr(pipelines, library_name) + loaded_sub_model = None + # if the model is in a pipeline module, then we load it from the pipeline - if is_pipeline_module: + if name in passed_class_obj: + # 1. check that passed_class_obj has correct parent class + if not is_pipeline_module: + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + importable_classes = LOADABLE_CLASSES[library_name] + class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} + + expected_class_obj = None + for class_name, class_candidate in class_candidates.items(): + if issubclass(class_obj, class_candidate): + expected_class_obj = class_candidate + + if not issubclass(passed_class_obj[name].__class__, expected_class_obj): + raise ValueError( + f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" + f" {expected_class_obj}" + ) + else: + logger.warn( + f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" + " has the correct type" + ) + + # set passed class object + loaded_sub_model = passed_class_obj[name] + elif is_pipeline_module: pipeline_module = getattr(pipelines, library_name) class_obj = getattr(pipeline_module, class_name) importable_classes = ALL_IMPORTABLE_CLASSES @@ -171,23 +206,24 @@ class DiffusionPipeline(ConfigMixin): importable_classes = LOADABLE_CLASSES[library_name] class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} - load_method_name = None - for class_name, class_candidate in class_candidates.items(): - if issubclass(class_obj, class_candidate): - load_method_name = importable_classes[class_name][1] + if loaded_sub_model is None: + load_method_name = None + for class_name, class_candidate in class_candidates.items(): + if issubclass(class_obj, class_candidate): + load_method_name = importable_classes[class_name][1] - load_method = getattr(class_obj, load_method_name) + load_method = getattr(class_obj, load_method_name) - # check if the module is in a subdirectory - if os.path.isdir(os.path.join(cached_folder, name)): - loaded_sub_model = load_method(os.path.join(cached_folder, name)) - else: - # else load from the root directory - loaded_sub_model = load_method(cached_folder) + # check if the module is in a subdirectory + if os.path.isdir(os.path.join(cached_folder, name)): + loaded_sub_model = load_method(os.path.join(cached_folder, name)) + else: + # else load from the root directory + loaded_sub_model = load_method(cached_folder) init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) - # 5. Instantiate the pipeline + # 4. Instantiate the pipeline model = pipeline_class(**init_kwargs) return model diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index b24da077..628cb074 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -718,6 +718,28 @@ class PipelineTesterMixin(unittest.TestCase): assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" + @slow + def test_from_pretrained_hub_pass_model(self): + model_path = "google/ddpm-cifar10-32" + + # pass unet into DiffusionPipeline + unet = UNet2DModel.from_pretrained(model_path) + ddpm_from_hub_custom_model = DDPMPipeline.from_pretrained(model_path, unet=unet) + ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet) + + ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path) + + ddpm_from_hub_custom_model.scheduler.num_timesteps = 10 + ddpm_from_hub.scheduler.num_timesteps = 10 + + generator = torch.manual_seed(0) + + image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy")["sample"] + generator = generator.manual_seed(0) + new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"] + + assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" + @slow def test_output_format(self): model_path = "google/ddpm-cifar10-32"