diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 3d1bd492..390b8ff0 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -12,6 +12,8 @@
title: "Loading Pipelines, Models, and Schedulers"
- local: using-diffusers/configuration
title: "Configuring Pipelines, Models, and Schedulers"
+ - local: using-diffusers/custom_pipelines
+ title: "Loading and Creating Custom Pipelines"
title: "Loading"
- sections:
- local: using-diffusers/unconditional_image_generation
diff --git a/docs/source/using-diffusers/custom_pipelines.mdx b/docs/source/using-diffusers/custom_pipelines.mdx
new file mode 100644
index 00000000..9466f1dc
--- /dev/null
+++ b/docs/source/using-diffusers/custom_pipelines.mdx
@@ -0,0 +1,121 @@
+
+
+# Custom Pipelines
+
+Diffusers allows you to conveniently load any custom pipeline from the Hugging Face Hub as well as any [official community pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community)
+via the [`DiffusionPipeline`] class.
+
+## Loading custom pipelines from the Hub
+
+Custom pipelines can be easily loaded from any model repository on the Hub that defines a diffusion pipeline in a `pipeline.py` file.
+Let's load a dummy pipeline from [hf-internal-testing/diffusers-dummy-pipeline](https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline).
+
+All you need to do is pass the custom pipeline repo id with the `custom_pipeline` argument alongside the repo from where you wish to load the pipeline modules.
+
+```python
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
+)
+```
+
+This will load the custom pipeline as defined in the [model repository](https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py).
+
+
+
+By loading a custom pipeline from the Hugging Face Hub, you are trusting that the code you are loading
+is safe 🔒. Make sure to check out the code online before loading & running it automatically.
+
+
+
+## Loading official community pipelines
+
+Community pipelines are summarized in the [community examples folder](https://github.com/huggingface/diffusers/tree/main/examples/community)
+
+Similarly, you need to pass both the *repo id* from where you wish to load the weights as well as the `custom_pipeline` argument. Here the `custom_pipeline` argument should consist simply of the filename of the community pipeline excluding the `.py` suffix, *e.g.* `clip_guided_stable_diffusion`.
+
+Since community pipelines are often more complex, one can mix loading weights from an official *repo id*
+and passing pipeline modules directly.
+
+```python
+from diffusers import DiffusionPipeline
+from transformers import CLIPFeatureExtractor, CLIPModel
+
+clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
+
+feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
+clip_model = CLIPModel.from_pretrained(clip_model_id)
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="clip_guided_stable_diffusion",
+ clip_model=clip_model,
+ feature_extractor=feature_extractor,
+)
+```
+
+## Adding custom pipelines to the Hub
+
+To add a custom pipeline to the Hub, all you need to do is to define a pipeline class that inherits
+from [`DiffusionPipeline`] in a `pipeline.py` file.
+Make sure that the whole pipeline is encapsulated within a single class and that the `pipeline.py` file
+has only one such class.
+
+Let's quickly define an example pipeline.
+
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+
+class MyPipeline(DiffusionPipeline):
+ def __init__(self, unet, scheduler):
+ super().__init__()
+
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(self, batch_size: int = 1, num_inference_steps: int = 50, eta: float = 0.0):
+ # Sample gaussian noise to begin loop
+ image = torch.randn((batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size))
+
+ image = image.to(self.device)
+
+ # set step values
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # 1. predict noise model_output
+ model_output = self.unet(image, t).sample
+
+ # 2. predict previous mean of image x_t-1 and add variance depending on eta
+ # eta corresponds to η in paper and should be between [0, 1]
+ # do x_t -> x_t-1
+ image = self.scheduler.step(model_output, t, image, eta).prev_sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ return image
+```
+
+Now you can upload this short file under the name `pipeline.py` in your preferred [model repository](https://huggingface.co/docs/hub/models-uploading). For Stable Diffusion pipelines, you may also [join the community organisation for shared pipelines](https://huggingface.co/organizations/sd-diffusers-pipelines-library/share/BUPyDUuHcciGTOKaExlqtfFcyCZsVFdrjr) to upload yours.
+Finally, we can load the custom pipeline by passing the model repository name, *e.g.* `sd-diffusers-pipelines-library/my_custom_pipeline` alongside the model repository from where we want to load the `unet` and `scheduler` components.
+
+```python
+my_pipeline = DiffusionPipeline.from_pretrained(
+ "google/ddpm-cifar10-32", custom_pipeline="sd-diffusers-pipelines-library/my_custom_pipeline"
+)
+```
diff --git a/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/blobs/bbbcb9f65616524d6199fa3bc16dc0500fb2cbbb b/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/blobs/bbbcb9f65616524d6199fa3bc16dc0500fb2cbbb
new file mode 100644
index 00000000..bbbcb9f6
--- /dev/null
+++ b/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/blobs/bbbcb9f65616524d6199fa3bc16dc0500fb2cbbb
@@ -0,0 +1,102 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+# limitations under the License.
+
+
+from typing import Optional, Tuple, Union
+
+import torch
+
+from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+class CustomPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
+ [`DDPMScheduler`], or [`DDIMScheduler`].
+ """
+
+ def __init__(self, unet, scheduler):
+ super().__init__()
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ generator: Optional[torch.Generator] = None,
+ eta: float = 0.0,
+ num_inference_steps: int = 50,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ eta (`float`, *optional*, defaults to 0.0):
+ The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM).
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipeline_utils.ImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ # Sample gaussian noise to begin loop
+ image = torch.randn(
+ (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+ generator=generator,
+ )
+ image = image.to(self.device)
+
+ # set step values
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # 1. predict noise model_output
+ model_output = self.unet(image, t).sample
+
+ # 2. predict previous mean of image x_t-1 and add variance depending on eta
+ # eta corresponds to η in paper and should be between [0, 1]
+ # do x_t -> x_t-1
+ image = self.scheduler.step(model_output, t, image, eta).prev_sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image), "This is a test"
\ No newline at end of file
diff --git a/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/refs/main b/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/refs/main
new file mode 100644
index 00000000..152c8af6
--- /dev/null
+++ b/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/refs/main
@@ -0,0 +1 @@
+b8fa12635e53eebebc22f95ee863e7af4fc2fb07
\ No newline at end of file
diff --git a/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/snapshots/b8fa12635e53eebebc22f95ee863e7af4fc2fb07/pipeline.py b/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/snapshots/b8fa12635e53eebebc22f95ee863e7af4fc2fb07/pipeline.py
new file mode 120000
index 00000000..47bb9680
--- /dev/null
+++ b/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/snapshots/b8fa12635e53eebebc22f95ee863e7af4fc2fb07/pipeline.py
@@ -0,0 +1 @@
+../../blobs/bbbcb9f65616524d6199fa3bc16dc0500fb2cbbb
\ No newline at end of file
diff --git a/src/diffusers/dynamic_modules_utils.py b/src/diffusers/dynamic_modules_utils.py
index 1177c4ef..ad3d1a6f 100644
--- a/src/diffusers/dynamic_modules_utils.py
+++ b/src/diffusers/dynamic_modules_utils.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2021 The HuggingFace Inc. team.
+# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,6 +15,7 @@
"""Utilities to dynamically load objects from the Hub."""
import importlib
+import inspect
import os
import re
import shutil
@@ -22,11 +23,16 @@ import sys
from pathlib import Path
from typing import Dict, Optional, Union
-from huggingface_hub import cached_download
+from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info
from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
+COMMUNITY_PIPELINES_URL = (
+ "https://raw.githubusercontent.com/huggingface/diffusers/main/examples/community/{pipeline}.py"
+)
+
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -145,9 +151,35 @@ def get_class_in_module(class_name, module_path):
"""
module_path = module_path.replace(os.path.sep, ".")
module = importlib.import_module(module_path)
+
+ if class_name is None:
+ return find_pipeline_class(module)
return getattr(module, class_name)
+def find_pipeline_class(loaded_module):
+ """
+ Retrieve pipeline class that inherits from `DiffusionPipeline`. Note that there has to be exactly one class
+ inheriting from `DiffusionPipeline`.
+ """
+ from .pipeline_utils import DiffusionPipeline
+
+ cls_members = dict(inspect.getmembers(loaded_module, inspect.isclass))
+
+ pipeline_class = None
+ for cls_name, cls in cls_members.items():
+ if cls_name != DiffusionPipeline.__name__ and issubclass(cls, DiffusionPipeline):
+ if pipeline_class is not None:
+ raise ValueError(
+ f"Multiple classes that inherit from {DiffusionPipeline.__name__} have been found:"
+ f" {pipeline_class.__name__}, and {cls_name}. Please make sure to define only one in"
+ f" {loaded_module}."
+ )
+ pipeline_class = cls
+
+ return pipeline_class
+
+
def get_cached_module_file(
pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str,
@@ -208,16 +240,35 @@ def get_cached_module_file(
"""
# Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
- submodule = "local"
if os.path.isfile(module_file_or_url):
resolved_module_file = module_file_or_url
+ submodule = "local"
+ elif pretrained_model_name_or_path.count("/") == 0:
+ # community pipeline on GitHub
+ github_url = COMMUNITY_PIPELINES_URL.format(pipeline=pretrained_model_name_or_path)
+ try:
+ resolved_module_file = cached_download(
+ github_url,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=False,
+ )
+ submodule = "local"
+ except EnvironmentError:
+ logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
+ raise
else:
try:
# Load from URL or cache if already cached
- resolved_module_file = cached_download(
- module_file_or_url,
+ resolved_module_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ module_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
@@ -225,7 +276,7 @@ def get_cached_module_file(
local_files_only=local_files_only,
use_auth_token=use_auth_token,
)
-
+ submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
except EnvironmentError:
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
raise
@@ -237,20 +288,55 @@ def get_cached_module_file(
full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
create_dynamic_module(full_submodule)
submodule_path = Path(HF_MODULES_CACHE) / full_submodule
- # We always copy local files (we could hash the file to see if there was a change, and give them the name of
- # that hash, to only copy when there is a modification but it seems overkill for now).
- # The only reason we do the copy is to avoid putting too many folders in sys.path.
- shutil.copy(resolved_module_file, submodule_path / module_file)
- for module_needed in modules_needed:
- module_needed = f"{module_needed}.py"
- shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
+ if submodule == "local":
+ # We always copy local files (we could hash the file to see if there was a change, and give them the name of
+ # that hash, to only copy when there is a modification but it seems overkill for now).
+ # The only reason we do the copy is to avoid putting too many folders in sys.path.
+ shutil.copy(resolved_module_file, submodule_path / module_file)
+ for module_needed in modules_needed:
+ module_needed = f"{module_needed}.py"
+ shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
+ else:
+ # Get the commit hash
+ # TODO: we will get this info in the etag soon, so retrieve it from there and not here.
+ if isinstance(use_auth_token, str):
+ token = use_auth_token
+ elif use_auth_token is True:
+ token = HfFolder.get_token()
+ else:
+ token = None
+
+ commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha
+
+ # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
+ # benefit of versioning.
+ submodule_path = submodule_path / commit_hash
+ full_submodule = full_submodule + os.path.sep + commit_hash
+ create_dynamic_module(full_submodule)
+
+ if not (submodule_path / module_file).exists():
+ shutil.copy(resolved_module_file, submodule_path / module_file)
+ # Make sure we also have every file with relative
+ for module_needed in modules_needed:
+ if not (submodule_path / module_needed).exists():
+ get_cached_module_file(
+ pretrained_model_name_or_path,
+ f"{module_needed}.py",
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ local_files_only=local_files_only,
+ )
return os.path.join(full_submodule, module_file)
def get_class_from_dynamic_module(
pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str,
- class_name: str,
+ class_name: Optional[str] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
resume_download: bool = False,
diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index 01ba4eef..8991bfce 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -30,11 +30,13 @@ from PIL import Image
from tqdm.auto import tqdm
from .configuration_utils import ConfigMixin
+from .dynamic_modules_utils import get_class_from_dynamic_module
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
INDEX_FILE = "diffusion_pytorch_model.bin"
+CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
logger = logging.get_logger(__name__)
@@ -214,6 +216,52 @@ class DiffusionPipeline(ConfigMixin):
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
will be automatically derived from the model's weights.
+ custom_pipeline (`str`, *optional*):
+
+
+
+ This is an experimental feature and is likely to change in the future.
+
+
+
+ Can be either:
+
+ - A string, the *repo id* of a custom pipeline hosted inside a model repo on
+ https://huggingface.co/. Valid repo ids have to be located under a user or organization name,
+ like `hf-internal-testing/diffusers-dummy-pipeline`.
+
+
+
+ It is required that the model repo has a file, called `pipeline.py` that defines the custom
+ pipeline.
+
+
+
+ - A string, the *file name* of a community pipeline hosted on GitHub under
+ https://github.com/huggingface/diffusers/tree/main/examples/community. Valid file names have to
+ match exactly the file name without `.py` located under the above link, *e.g.*
+ `clip_guided_stable_diffusion`.
+
+
+
+ Community pipelines are always loaded from the current `main` branch of GitHub.
+
+
+
+ - A path to a *directory* containing a custom pipeline, e.g., `./my_pipeline_directory/`.
+
+
+
+ It is required that the directory has a file, called `pipeline.py` that defines the custom
+ pipeline.
+
+
+
+ For more information on how to load and create custom pipelines, please have a look at [Loading and
+ Creating Custom
+ Pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/custom_pipelines)
+
+ torch_dtype (`str` or `torch.dtype`, *optional*):
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
@@ -285,6 +333,7 @@ class DiffusionPipeline(ConfigMixin):
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
+ custom_pipeline = kwargs.pop("custom_pipeline", None)
provider = kwargs.pop("provider", None)
sess_options = kwargs.pop("sess_options", None)
@@ -305,6 +354,9 @@ class DiffusionPipeline(ConfigMixin):
allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
+ if custom_pipeline is not None:
+ allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
+
# download all allow_patterns
cached_folder = snapshot_download(
pretrained_model_name_or_path,
@@ -323,7 +375,11 @@ class DiffusionPipeline(ConfigMixin):
# 2. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
- if cls != DiffusionPipeline:
+ if custom_pipeline is not None:
+ pipeline_class = get_class_from_dynamic_module(
+ custom_pipeline, module_file=CUSTOM_PIPELINE_FILE_NAME, cache_dir=cache_dir
+ )
+ elif cls != DiffusionPipeline:
pipeline_class = cls
else:
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
@@ -332,7 +388,7 @@ class DiffusionPipeline(ConfigMixin):
# some modules can be passed directly to the init
# in this case they are already instantiated in `kwargs`
# extract them here
- expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
+ expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
@@ -414,7 +470,18 @@ class DiffusionPipeline(ConfigMixin):
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
- # 4. Instantiate the pipeline
+ # 4. Potentially add passed objects if expected
+ missing_modules = set(expected_modules) - set(init_kwargs.keys())
+ if len(missing_modules) > 0 and missing_modules <= set(passed_class_obj.keys()):
+ for module in missing_modules:
+ init_kwargs[module] = passed_class_obj[module]
+ elif len(missing_modules) > 0:
+ passed_modules = set(init_kwargs.keys()) | set(passed_class_obj.keys())
+ raise ValueError(
+ f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
+ )
+
+ # 5. Instantiate the pipeline
model = pipeline_class(**init_kwargs)
return model
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index d3f6fa62..0177d30a 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -1,3 +1,4 @@
+import inspect
import os
import random
import re
@@ -22,6 +23,27 @@ if is_torch_higher_equal_than_1_12:
torch_device = "mps" if torch.backends.mps.is_available() else torch_device
+def get_tests_dir(append_path=None):
+ """
+ Args:
+ append_path: optional path to append to the tests dir path
+ Return:
+ The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is
+ joined after the `tests` dir the former is provided.
+ """
+ # this function caller's __file__
+ caller__file__ = inspect.stack()[1][1]
+ tests_dir = os.path.abspath(os.path.dirname(caller__file__))
+
+ while not tests_dir.endswith("tests"):
+ tests_dir = os.path.dirname(tests_dir)
+
+ if append_path:
+ return os.path.join(tests_dir, append_path)
+ else:
+ return tests_dir
+
+
def parse_flag_from_env(key, default=False):
try:
value = os.environ[key]
diff --git a/tests/fixtures/custom_pipeline/pipeline.py b/tests/fixtures/custom_pipeline/pipeline.py
new file mode 100644
index 00000000..10a22eda
--- /dev/null
+++ b/tests/fixtures/custom_pipeline/pipeline.py
@@ -0,0 +1,102 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+# limitations under the License.
+
+
+from typing import Optional, Tuple, Union
+
+import torch
+
+from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+class CustomLocalPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
+ [`DDPMScheduler`], or [`DDIMScheduler`].
+ """
+
+ def __init__(self, unet, scheduler):
+ super().__init__()
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ generator: Optional[torch.Generator] = None,
+ eta: float = 0.0,
+ num_inference_steps: int = 50,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ eta (`float`, *optional*, defaults to 0.0):
+ The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM).
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipeline_utils.ImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ # Sample gaussian noise to begin loop
+ image = torch.randn(
+ (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+ generator=generator,
+ )
+ image = image.to(self.device)
+
+ # set step values
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # 1. predict noise model_output
+ model_output = self.unet(image, t).sample
+
+ # 2. predict previous mean of image x_t-1 and add variance depending on eta
+ # eta corresponds to η in paper and should be between [0, 1]
+ # do x_t -> x_t-1
+ image = self.scheduler.step(model_output, t, image, eta).prev_sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,), "This is a local test"
+
+ return ImagePipelineOutput(images=image), "This is a local test"
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 69301e97..4a0839ad 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -49,8 +49,9 @@ from diffusers import (
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device
+from diffusers.utils.testing_utils import get_tests_dir
from PIL import Image
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
torch.backends.cuda.matmul.allow_tf32 = False
@@ -79,6 +80,60 @@ def test_progress_bar(capsys):
assert captured.err == "", "Progress bar should be disabled"
+class CustomPipelineTests(unittest.TestCase):
+ def test_load_custom_pipeline(self):
+ pipeline = DiffusionPipeline.from_pretrained(
+ "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
+ )
+ # NOTE that `"CustomPipeline"` is not a class that is defined in this library, but solely on the Hub
+ # under https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L24
+ assert pipeline.__class__.__name__ == "CustomPipeline"
+
+ def test_run_custom_pipeline(self):
+ pipeline = DiffusionPipeline.from_pretrained(
+ "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
+ )
+ images, output_str = pipeline(num_inference_steps=2, output_type="np")
+
+ assert images[0].shape == (1, 32, 32, 3)
+ # compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102
+ assert output_str == "This is a test"
+
+ def test_local_custom_pipeline(self):
+ local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline")
+ pipeline = DiffusionPipeline.from_pretrained(
+ "google/ddpm-cifar10-32", custom_pipeline=local_custom_pipeline_path
+ )
+ images, output_str = pipeline(num_inference_steps=2, output_type="np")
+
+ assert pipeline.__class__.__name__ == "CustomLocalPipeline"
+ assert images[0].shape == (1, 32, 32, 3)
+ # compare to https://github.com/huggingface/diffusers/blob/main/tests/fixtures/custom_pipeline/pipeline.py#L102
+ assert output_str == "This is a local test"
+
+ @slow
+ def test_load_pipeline_from_git(self):
+ clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
+
+ feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
+ clip_model = CLIPModel.from_pretrained(clip_model_id)
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="clip_guided_stable_diffusion",
+ clip_model=clip_model,
+ feature_extractor=feature_extractor,
+ )
+ pipeline = pipeline.to(torch_device)
+
+ # NOTE that `"CLIPGuidedStableDiffusion"` is not a class that is defined in the pypi package of the library, but solely on the community examples folder of GitHub under:
+ # https://github.com/huggingface/diffusers/blob/main/examples/community/clip_guided_stable_diffusion.py
+ assert pipeline.__class__.__name__ == "CLIPGuidedStableDiffusion"
+
+ image = pipeline("a prompt", num_inference_steps=2, output_type="np").images[0]
+ assert image.shape == (512, 512, 3)
+
+
class PipelineFastTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test