Custom Pipelines (#744)
* [Custom Pipelines]
* uP
* make style
* finish
* finish
* remove ipdb
* upload
* fix
* finish docs
* Apply suggestions from code review
* finish
* final uploads
* remove unnecessary test

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: apolinario <joaopaulo.passos@gmail.com>
This commit is contained in:
parent f3128c8788
commit d9c449ea30
@@ -12,6 +12,8 @@
     title: "Loading Pipelines, Models, and Schedulers"
   - local: using-diffusers/configuration
     title: "Configuring Pipelines, Models, and Schedulers"
+  - local: using-diffusers/custom_pipelines
+    title: "Loading and Creating Custom Pipelines"
   title: "Loading"
 - sections:
   - local: using-diffusers/unconditional_image_generation
@@ -0,0 +1,121 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Custom Pipelines

Diffusers allows you to conveniently load any custom pipeline from the Hugging Face Hub, as well as any [official community pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community),
via the [`DiffusionPipeline`] class.

## Loading custom pipelines from the Hub

Custom pipelines can be easily loaded from any model repository on the Hub that defines a diffusion pipeline in a `pipeline.py` file.
Let's load a dummy pipeline from [hf-internal-testing/diffusers-dummy-pipeline](https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline).

All you need to do is pass the custom pipeline repo id with the `custom_pipeline` argument alongside the repo from which you wish to load the pipeline modules.

```python
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
)
```

This will load the custom pipeline as defined in the [model repository](https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py).

<Tip warning={true}>

By loading a custom pipeline from the Hugging Face Hub, you are trusting that the code you are loading
is safe 🔒. Make sure to check out the code online before loading & running it automatically.

</Tip>
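
One simple way to review the code before running anything is to download just the `pipeline.py` file and read it locally. The snippet below is a minimal sketch of such a check using `huggingface_hub`; the repo id is the dummy pipeline from above:

```python
from huggingface_hub import hf_hub_download

# Download only the pipeline code, without instantiating the pipeline itself
pipeline_file = hf_hub_download(
    repo_id="hf-internal-testing/diffusers-dummy-pipeline", filename="pipeline.py"
)

# Read the code before deciding to trust and run it
with open(pipeline_file) as f:
    print(f.read())
```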

## Loading official community pipelines

Community pipelines are summarized in the [community examples folder](https://github.com/huggingface/diffusers/tree/main/examples/community).

Similarly, you need to pass both the *repo id* from which you wish to load the weights as well as the `custom_pipeline` argument. Here, the `custom_pipeline` argument should simply be the filename of the community pipeline excluding the `.py` suffix, *e.g.* `clip_guided_stable_diffusion`.

Since community pipelines are often more complex, one can mix loading weights from an official *repo id*
and passing pipeline modules directly.

```python
from diffusers import DiffusionPipeline
from transformers import CLIPFeatureExtractor, CLIPModel

clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"

feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
clip_model = CLIPModel.from_pretrained(clip_model_id)

pipeline = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="clip_guided_stable_diffusion",
    clip_model=clip_model,
    feature_extractor=feature_extractor,
)
```
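
Once loaded, the community pipeline behaves like any other pipeline. The call below is a usage sketch modeled on the slow test added in this commit; the prompt, step count, and device are illustrative:

```python
# Assuming a CUDA device is available; use "cpu" otherwise
pipeline = pipeline.to("cuda")

# CLIP-guided generation returns its outputs under `.images`,
# as in the test added by this commit
image = pipeline("a prompt", num_inference_steps=50).images[0]
image.save("clip_guided_sample.png")
```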

## Adding custom pipelines to the Hub

To add a custom pipeline to the Hub, all you need to do is define a pipeline class that inherits
from [`DiffusionPipeline`] in a `pipeline.py` file.
Make sure that the whole pipeline is encapsulated within a single class and that the `pipeline.py` file
has only one such class.

Let's quickly define an example pipeline.

```python
import torch
from diffusers import DiffusionPipeline


class MyPipeline(DiffusionPipeline):
    def __init__(self, unet, scheduler):
        super().__init__()

        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(self, batch_size: int = 1, num_inference_steps: int = 50, eta: float = 0.0):
        # Sample gaussian noise to begin loop
        image = torch.randn((batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size))

        image = image.to(self.device)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # 1. predict noise model_output
            model_output = self.unet(image, t).sample

            # 2. predict previous mean of image x_t-1 and add variance depending on eta
            # eta corresponds to η in paper and should be between [0, 1]
            # do x_t -> x_t-1
            image = self.scheduler.step(model_output, t, image, eta).prev_sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        return image
```

Now you can upload this short file under the name `pipeline.py` to your preferred [model repository](https://huggingface.co/docs/hub/models-uploading). For Stable Diffusion pipelines, you may also [join the community organisation for shared pipelines](https://huggingface.co/organizations/sd-diffusers-pipelines-library/share/BUPyDUuHcciGTOKaExlqtfFcyCZsVFdrjr) to upload yours.
Finally, we can load the custom pipeline by passing its repository name, *e.g.* `sd-diffusers-pipelines-library/my_custom_pipeline`, alongside the model repository from which we want to load the `unet` and `scheduler` components.

```python
my_pipeline = DiffusionPipeline.from_pretrained(
    "google/ddpm-cifar10-32", custom_pipeline="patrickvonplaten/my_custom_pipeline"
)
```
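
Given the `__call__` method defined in `MyPipeline` above, the loaded pipeline returns the generated images as a NumPy array. A quick usage sketch (the step count is illustrative):

```python
# Runs the denoising loop defined in MyPipeline.__call__ above
images = my_pipeline(batch_size=1, num_inference_steps=50)
print(images.shape)  # (1, sample_size, sample_size, 3)
```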

@@ -0,0 +1,102 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Optional, Tuple, Union

import torch

from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput


class CustomPipeline(DiffusionPipeline):
    r"""
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
            [`DDPMScheduler`], or [`DDIMScheduler`].
    """

    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        generator: Optional[torch.Generator] = None,
        eta: float = 0.0,
        num_inference_steps: int = 50,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            eta (`float`, *optional*, defaults to 0.0):
                The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM).
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipeline_utils.ImagePipelineOutput`] if
            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
            generated images.
        """

        # Sample gaussian noise to begin loop
        image = torch.randn(
            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
            generator=generator,
        )
        image = image.to(self.device)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # 1. predict noise model_output
            model_output = self.unet(image, t).sample

            # 2. predict previous mean of image x_t-1 and add variance depending on eta
            # eta corresponds to η in paper and should be between [0, 1]
            # do x_t -> x_t-1
            image = self.scheduler.step(model_output, t, image, eta).prev_sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image), "This is a test"
@@ -0,0 +1 @@
b8fa12635e53eebebc22f95ee863e7af4fc2fb07

@@ -0,0 +1 @@
../../blobs/bbbcb9f65616524d6199fa3bc16dc0500fb2cbbb
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2021 The HuggingFace Inc. team.
+# Copyright 2022 The HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -15,6 +15,7 @@
 """Utilities to dynamically load objects from the Hub."""

 import importlib
+import inspect
 import os
 import re
 import shutil

@@ -22,11 +23,16 @@ import sys
 from pathlib import Path
 from typing import Dict, Optional, Union

-from huggingface_hub import cached_download
+from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info

 from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging


+COMMUNITY_PIPELINES_URL = (
+    "https://raw.githubusercontent.com/huggingface/diffusers/main/examples/community/{pipeline}.py"
+)
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -145,9 +151,35 @@ def get_class_in_module(class_name, module_path):
     """
     module_path = module_path.replace(os.path.sep, ".")
     module = importlib.import_module(module_path)
+
+    if class_name is None:
+        return find_pipeline_class(module)
     return getattr(module, class_name)


+def find_pipeline_class(loaded_module):
+    """
+    Retrieve pipeline class that inherits from `DiffusionPipeline`. Note that there has to be exactly one class
+    inheriting from `DiffusionPipeline`.
+    """
+    from .pipeline_utils import DiffusionPipeline
+
+    cls_members = dict(inspect.getmembers(loaded_module, inspect.isclass))
+
+    pipeline_class = None
+    for cls_name, cls in cls_members.items():
+        if cls_name != DiffusionPipeline.__name__ and issubclass(cls, DiffusionPipeline):
+            if pipeline_class is not None:
+                raise ValueError(
+                    f"Multiple classes that inherit from {DiffusionPipeline.__name__} have been found:"
+                    f" {pipeline_class.__name__}, and {cls_name}. Please make sure to define only one in"
+                    f" {loaded_module}."
+                )
+            pipeline_class = cls
+
+    return pipeline_class
+
+
 def get_cached_module_file(
     pretrained_model_name_or_path: Union[str, os.PathLike],
     module_file: str,
@@ -208,16 +240,35 @@ def get_cached_module_file(
     """
     # Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)

     module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
-    submodule = "local"

     if os.path.isfile(module_file_or_url):
         resolved_module_file = module_file_or_url
+        submodule = "local"
+    elif pretrained_model_name_or_path.count("/") == 0:
+        # community pipeline on GitHub
+        github_url = COMMUNITY_PIPELINES_URL.format(pipeline=pretrained_model_name_or_path)
+        try:
+            resolved_module_file = cached_download(
+                github_url,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                resume_download=resume_download,
+                local_files_only=local_files_only,
+                use_auth_token=False,
+            )
+            submodule = "local"
+        except EnvironmentError:
+            logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
+            raise
     else:
         try:
             # Load from URL or cache if already cached
-            resolved_module_file = cached_download(
-                module_file_or_url,
+            resolved_module_file = hf_hub_download(
+                pretrained_model_name_or_path,
+                module_file,
                 cache_dir=cache_dir,
                 force_download=force_download,
                 proxies=proxies,

@@ -225,7 +276,7 @@ def get_cached_module_file(
                 local_files_only=local_files_only,
                 use_auth_token=use_auth_token,
             )
+            submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
         except EnvironmentError:
             logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
             raise

@@ -237,20 +288,55 @@ def get_cached_module_file(
     full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
     create_dynamic_module(full_submodule)
     submodule_path = Path(HF_MODULES_CACHE) / full_submodule
-    # We always copy local files (we could hash the file to see if there was a change, and give them the name of
-    # that hash, to only copy when there is a modification but it seems overkill for now).
-    # The only reason we do the copy is to avoid putting too many folders in sys.path.
-    shutil.copy(resolved_module_file, submodule_path / module_file)
-    for module_needed in modules_needed:
-        module_needed = f"{module_needed}.py"
-        shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
+    if submodule == "local":
+        # We always copy local files (we could hash the file to see if there was a change, and give them the name of
+        # that hash, to only copy when there is a modification but it seems overkill for now).
+        # The only reason we do the copy is to avoid putting too many folders in sys.path.
+        shutil.copy(resolved_module_file, submodule_path / module_file)
+        for module_needed in modules_needed:
+            module_needed = f"{module_needed}.py"
+            shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
+    else:
+        # Get the commit hash
+        # TODO: we will get this info in the etag soon, so retrieve it from there and not here.
+        if isinstance(use_auth_token, str):
+            token = use_auth_token
+        elif use_auth_token is True:
+            token = HfFolder.get_token()
+        else:
+            token = None
+
+        commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha
+
+        # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
+        # benefit of versioning.
+        submodule_path = submodule_path / commit_hash
+        full_submodule = full_submodule + os.path.sep + commit_hash
+        create_dynamic_module(full_submodule)
+
+        if not (submodule_path / module_file).exists():
+            shutil.copy(resolved_module_file, submodule_path / module_file)
+        # Make sure we also have every file with relative
+        for module_needed in modules_needed:
+            if not (submodule_path / module_needed).exists():
+                get_cached_module_file(
+                    pretrained_model_name_or_path,
+                    f"{module_needed}.py",
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    local_files_only=local_files_only,
+                )
     return os.path.join(full_submodule, module_file)


 def get_class_from_dynamic_module(
     pretrained_model_name_or_path: Union[str, os.PathLike],
     module_file: str,
-    class_name: str,
+    class_name: Optional[str] = None,
     cache_dir: Optional[Union[str, os.PathLike]] = None,
     force_download: bool = False,
     resume_download: bool = False,
@@ -30,11 +30,13 @@ from PIL import Image
 from tqdm.auto import tqdm

 from .configuration_utils import ConfigMixin
+from .dynamic_modules_utils import get_class_from_dynamic_module
 from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from .utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging


 INDEX_FILE = "diffusion_pytorch_model.bin"
+CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"


 logger = logging.get_logger(__name__)
@@ -214,6 +216,52 @@ class DiffusionPipeline(ConfigMixin):
             torch_dtype (`str` or `torch.dtype`, *optional*):
                 Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
                 will be automatically derived from the model's weights.
+            custom_pipeline (`str`, *optional*):
+
+                <Tip warning={true}>
+
+                This is an experimental feature and is likely to change in the future.
+
+                </Tip>
+
+                Can be either:
+
+                    - A string, the *repo id* of a custom pipeline hosted inside a model repo on
+                      https://huggingface.co/. Valid repo ids have to be located under a user or organization name,
+                      like `hf-internal-testing/diffusers-dummy-pipeline`.
+
+                      <Tip>
+
+                      It is required that the model repo has a file called `pipeline.py` that defines the custom
+                      pipeline.
+
+                      </Tip>
+
+                    - A string, the *file name* of a community pipeline hosted on GitHub under
+                      https://github.com/huggingface/diffusers/tree/main/examples/community. Valid file names have to
+                      match exactly the file name without `.py` located under the above link, *e.g.*
+                      `clip_guided_stable_diffusion`.
+
+                      <Tip>
+
+                      Community pipelines are always loaded from the current `main` branch of GitHub.
+
+                      </Tip>
+
+                    - A path to a *directory* containing a custom pipeline, e.g., `./my_pipeline_directory/`.
+
+                      <Tip>
+
+                      It is required that the directory has a file called `pipeline.py` that defines the custom
+                      pipeline.
+
+                      </Tip>
+
+                For more information on how to load and create custom pipelines, please have a look at [Loading and
+                Creating Custom
+                Pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/custom_pipelines)
+
+            torch_dtype (`str` or `torch.dtype`, *optional*):
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
@@ -285,6 +333,7 @@ class DiffusionPipeline(ConfigMixin):
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         torch_dtype = kwargs.pop("torch_dtype", None)
+        custom_pipeline = kwargs.pop("custom_pipeline", None)
         provider = kwargs.pop("provider", None)
         sess_options = kwargs.pop("sess_options", None)
@@ -305,6 +354,9 @@ class DiffusionPipeline(ConfigMixin):
         allow_patterns = [os.path.join(k, "*") for k in folder_names]
         allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]

+        if custom_pipeline is not None:
+            allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
+
         # download all allow_patterns
         cached_folder = snapshot_download(
             pretrained_model_name_or_path,
@@ -323,7 +375,11 @@ class DiffusionPipeline(ConfigMixin):

         # 2. Load the pipeline class, if using custom module then load it from the hub
         # if we load from explicit class, let's use it
-        if cls != DiffusionPipeline:
+        if custom_pipeline is not None:
+            pipeline_class = get_class_from_dynamic_module(
+                custom_pipeline, module_file=CUSTOM_PIPELINE_FILE_NAME, cache_dir=custom_pipeline
+            )
+        elif cls != DiffusionPipeline:
             pipeline_class = cls
         else:
             diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
@@ -332,7 +388,7 @@ class DiffusionPipeline(ConfigMixin):
         # some modules can be passed directly to the init
         # in this case they are already instantiated in `kwargs`
         # extract them here
-        expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
+        expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
         passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}

         init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
@@ -414,7 +470,18 @@ class DiffusionPipeline(ConfigMixin):

             init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)

-        # 4. Instantiate the pipeline
+        # 4. Potentially add passed objects if expected
+        missing_modules = set(expected_modules) - set(init_kwargs.keys())
+        if len(missing_modules) > 0 and missing_modules <= set(passed_class_obj.keys()):
+            for module in missing_modules:
+                init_kwargs[module] = passed_class_obj[module]
+        elif len(missing_modules) > 0:
+            passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys()))
+            raise ValueError(
+                f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
+            )
+
+        # 5. Instantiate the pipeline
         model = pipeline_class(**init_kwargs)
         return model
@@ -1,3 +1,4 @@
+import inspect
 import os
 import random
 import re
@@ -22,6 +23,27 @@ if is_torch_higher_equal_than_1_12:
     torch_device = "mps" if torch.backends.mps.is_available() else torch_device


+def get_tests_dir(append_path=None):
+    """
+    Args:
+        append_path: optional path to append to the tests dir path
+
+    Return:
+        The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is
+        joined after the `tests` dir if the former is provided.
+    """
+    # this function caller's __file__
+    caller__file__ = inspect.stack()[1][1]
+    tests_dir = os.path.abspath(os.path.dirname(caller__file__))
+
+    while not tests_dir.endswith("tests"):
+        tests_dir = os.path.dirname(tests_dir)
+
+    if append_path:
+        return os.path.join(tests_dir, append_path)
+    else:
+        return tests_dir
+
+
 def parse_flag_from_env(key, default=False):
     try:
         value = os.environ[key]
@@ -0,0 +1,102 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Optional, Tuple, Union

import torch

from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput


class CustomLocalPipeline(DiffusionPipeline):
    r"""
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
            [`DDPMScheduler`], or [`DDIMScheduler`].
    """

    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        generator: Optional[torch.Generator] = None,
        eta: float = 0.0,
        num_inference_steps: int = 50,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            eta (`float`, *optional*, defaults to 0.0):
                The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM).
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipeline_utils.ImagePipelineOutput`] if
            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
            generated images.
        """

        # Sample gaussian noise to begin loop
        image = torch.randn(
            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
            generator=generator,
        )
        image = image.to(self.device)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # 1. predict noise model_output
            model_output = self.unet(image, t).sample

            # 2. predict previous mean of image x_t-1 and add variance depending on eta
            # eta corresponds to η in paper and should be between [0, 1]
            # do x_t -> x_t-1
            image = self.scheduler.step(model_output, t, image, eta).prev_sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,), "This is a local test"

        return ImagePipelineOutput(images=image), "This is a local test"
@@ -49,8 +49,9 @@ from diffusers import (
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device
+from diffusers.utils.testing_utils import get_tests_dir
 from PIL import Image
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer


 torch.backends.cuda.matmul.allow_tf32 = False

@@ -79,6 +80,60 @@ def test_progress_bar(capsys):
     assert captured.err == "", "Progress bar should be disabled"


+class CustomPipelineTests(unittest.TestCase):
+    def test_load_custom_pipeline(self):
+        pipeline = DiffusionPipeline.from_pretrained(
+            "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
+        )
+        # NOTE that `"CustomPipeline"` is not a class that is defined in this library, but solely on the Hub
+        # under https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L24
+        assert pipeline.__class__.__name__ == "CustomPipeline"
+
+    def test_run_custom_pipeline(self):
+        pipeline = DiffusionPipeline.from_pretrained(
+            "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
+        )
+        images, output_str = pipeline(num_inference_steps=2, output_type="np")
+
+        assert images[0].shape == (1, 32, 32, 3)
+        # compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102
+        assert output_str == "This is a test"
+
+    def test_local_custom_pipeline(self):
+        local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline")
+        pipeline = DiffusionPipeline.from_pretrained(
+            "google/ddpm-cifar10-32", custom_pipeline=local_custom_pipeline_path
+        )
+        images, output_str = pipeline(num_inference_steps=2, output_type="np")
+
+        assert pipeline.__class__.__name__ == "CustomLocalPipeline"
+        assert images[0].shape == (1, 32, 32, 3)
+        # compare to https://github.com/huggingface/diffusers/blob/main/tests/fixtures/custom_pipeline/pipeline.py#L102
+        assert output_str == "This is a local test"
+
+    @slow
+    def test_load_pipeline_from_git(self):
+        clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
+
+        feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
+        clip_model = CLIPModel.from_pretrained(clip_model_id)
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "CompVis/stable-diffusion-v1-4",
+            custom_pipeline="clip_guided_stable_diffusion",
+            clip_model=clip_model,
+            feature_extractor=feature_extractor,
+        )
+        pipeline = pipeline.to(torch_device)
+
+        # NOTE that `"CLIPGuidedStableDiffusion"` is not a class that is defined in the pypi package of the library,
+        # but solely in the community examples folder of GitHub under:
+        # https://github.com/huggingface/diffusers/blob/main/examples/community/clip_guided_stable_diffusion.py
+        assert pipeline.__class__.__name__ == "CLIPGuidedStableDiffusion"
+
+        image = pipeline("a prompt", num_inference_steps=2, output_type="np").images[0]
+        assert image.shape == (512, 512, 3)
+
+
 class PipelineFastTests(unittest.TestCase):
     def tearDown(self):
         # clean up the VRAM after each test