Allow passing non-default modules to pipeline (#188)

* Allow passing non-default modules to pipeline.

Override modules are recognized and replaced in the pipeline. However,
no check is performed about mismatched classes yet. This is because the
override module is already instantiated and we have no library or class
name to compare against.

* up

* add test

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Pedro Cuenca 2022-08-16 17:25:25 +02:00 committed by GitHub
parent d7b692083c
commit 513f1fbfb0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 71 additions and 13 deletions

View File

@ -15,6 +15,7 @@
# limitations under the License.
import importlib
import inspect
import os
from typing import Optional, Union
@ -148,6 +149,12 @@ class DiffusionPipeline(ConfigMixin):
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
# some modules can be passed directly to the init
# in this case they are already instantiated in `kwargs`
# extract them here
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_kwargs = {}
@ -158,8 +165,36 @@ class DiffusionPipeline(ConfigMixin):
# 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None
# if the model is in a pipeline module, then we load it from the pipeline
if is_pipeline_module:
if name in passed_class_obj:
# 1. check that passed_class_obj has correct parent class
if not is_pipeline_module:
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
expected_class_obj = None
for class_name, class_candidate in class_candidates.items():
if issubclass(class_obj, class_candidate):
expected_class_obj = class_candidate
if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
raise ValueError(
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
f" {expected_class_obj}"
)
else:
logger.warn(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
" has the correct type"
)
# set passed class object
loaded_sub_model = passed_class_obj[name]
elif is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name)
importable_classes = ALL_IMPORTABLE_CLASSES
@ -171,6 +206,7 @@ class DiffusionPipeline(ConfigMixin):
importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
if loaded_sub_model is None:
load_method_name = None
for class_name, class_candidate in class_candidates.items():
if issubclass(class_obj, class_candidate):
@ -187,7 +223,7 @@ class DiffusionPipeline(ConfigMixin):
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
# 5. Instantiate the pipeline
# 4. Instantiate the pipeline
model = pipeline_class(**init_kwargs)
return model

View File

@ -718,6 +718,28 @@ class PipelineTesterMixin(unittest.TestCase):
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
@slow
def test_from_pretrained_hub_pass_model(self):
model_path = "google/ddpm-cifar10-32"
# pass unet into DiffusionPipeline
unet = UNet2DModel.from_pretrained(model_path)
ddpm_from_hub_custom_model = DDPMPipeline.from_pretrained(model_path, unet=unet)
ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet)
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
ddpm_from_hub_custom_model.scheduler.num_timesteps = 10
ddpm_from_hub.scheduler.num_timesteps = 10
generator = torch.manual_seed(0)
image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy")["sample"]
generator = generator.manual_seed(0)
new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"]
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
@slow
def test_output_format(self):
model_path = "google/ddpm-cifar10-32"