Add CycleDiffusion pipeline using Stable Diffusion (#888)

* Add CycleDiffusion pipeline for Stable Diffusion

* Add the option of passing noise to DDIMScheduler

Add the option of providing the noise itself to DDIMScheduler, instead of the random seed generator.

* Update README.md

* Update README.md

* Update pipeline_stable_diffusion_cycle_diffusion.py

* Update pipeline_stable_diffusion_cycle_diffusion.py

* Update pipeline_stable_diffusion_cycle_diffusion.py

* Update pipeline_stable_diffusion_cycle_diffusion.py

* Update scheduling_ddim.py

* Update import format

* Update pipeline_stable_diffusion_cycle_diffusion.py

* Update scheduling_ddim.py

* Update src/diffusers/schedulers/scheduling_ddim.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_ddim.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_ddim.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_ddim.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_ddim.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update scheduling_ddim.py

* Update scheduling_ddim.py

* Update scheduling_ddim.py

* add two tests

* Update pipeline_stable_diffusion_cycle_diffusion.py

* Update pipeline_stable_diffusion_cycle_diffusion.py

* Update README.md

* Rename pipeline name as suggested in the latest reviewer comment

* Update test_pipelines.py

* Update test_pipelines.py

* Update test_pipelines.py

* Update pipeline_stable_diffusion_cycle_diffusion.py

* Remove the generator

This generator does not control all randomness during sampling, which can be misleading.

* Update optimal hyperparameters

* Update src/diffusers/pipelines/stable_diffusion/README.md

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Update src/diffusers/pipelines/stable_diffusion/README.md

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Update src/diffusers/pipelines/stable_diffusion/README.md

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Apply suggestions from code review

* uP

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_cycle_diffusion.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* up

* up

* Replace assert with ValueError

* finish docs

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
Chen Wu (吴尘) 2022-11-04 15:51:06 -04:00 committed by GitHub
parent 1172c9634b
commit 9d8943b7e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 1097 additions and 15 deletions

View File

@ -78,6 +78,8 @@
- sections:
- local: api/pipelines/overview
title: "Overview"
- local: api/pipelines/cycle_diffusion
title: "Cycle Diffusion"
- local: api/pipelines/ddim
title: "DDIM"
- local: api/pipelines/ddpm

View File

@ -0,0 +1,99 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Cycle Diffusion
## Overview
Cycle Diffusion is a Text-Guided Image-to-Image Generation model proposed in [Unifying Diffusion Models' Latent Space, with Applications to CycleDiffusion and Guidance](https://arxiv.org/abs/2210.05559) by Chen Henry Wu, Fernando De la Torre.
The abstract of the paper is the following:
*Diffusion models have achieved unprecedented performance in generative modeling. The commonly-adopted formulation of the latent code of diffusion models is a sequence of gradually denoised samples, as opposed to the simpler (e.g., Gaussian) latent space of GANs, VAEs, and normalizing flows. This paper provides an alternative, Gaussian formulation of the latent space of various diffusion models, as well as an invertible DPM-Encoder that maps images into the latent space. While our formulation is purely based on the definition of diffusion models, we demonstrate several intriguing consequences. (1) Empirically, we observe that a common latent space emerges from two diffusion models trained independently on related domains. In light of this finding, we propose CycleDiffusion, which uses DPM-Encoder for unpaired image-to-image translation. Furthermore, applying CycleDiffusion to text-to-image diffusion models, we show that large-scale text-to-image diffusion models can be used as zero-shot image-to-image editors. (2) One can guide pre-trained diffusion models and GANs by controlling the latent codes in a unified, plug-and-play formulation based on energy-based models. Using the CLIP model and a face recognition model as guidance, we demonstrate that diffusion models have better coverage of low-density sub-populations and individuals than GANs.*
*Tips*:
- The Cycle Diffusion pipeline is fully compatible with any [Stable Diffusion](./stable_diffusion) checkpoints
- Currently Cycle Diffusion only works with the [`DDIMScheduler`].
*Example*:
In the following we should how to best use the [`CycleDiffusionPipeline`]
```python
import requests
import torch
from PIL import Image
from io import BytesIO
from diffusers import CycleDiffusionPipeline, DDIMScheduler
# load the pipeline
# make sure you're logged in with `huggingface-cli login`
model_id_or_path = "CompVis/stable-diffusion-v1-4"
scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")
# let's download an initial image
url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/An%20astronaut%20riding%20a%20horse.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((512, 512))
init_image.save("horse.png")
# let's specify a prompt
source_prompt = "An astronaut riding a horse"
prompt = "An astronaut riding an elephant"
# call the pipeline
image = pipe(
prompt=prompt,
source_prompt=source_prompt,
init_image=init_image,
num_inference_steps=100,
eta=0.1,
strength=0.8,
guidance_scale=2,
source_guidance_scale=1,
).images[0]
image.save("horse_to_elephant.png")
# let's try another example
# See more samples at the original repo: https://github.com/ChenWu98/cycle-diffusion
url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/A%20black%20colored%20car.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((512, 512))
init_image.save("black.png")
source_prompt = "A black colored car"
prompt = "A blue colored car"
# call the pipeline
torch.manual_seed(0)
image = pipe(
prompt=prompt,
source_prompt=source_prompt,
init_image=init_image,
num_inference_steps=100,
eta=0.1,
strength=0.85,
guidance_scale=3,
source_guidance_scale=1,
).images[0]
image.save("black_to_blue.png")
```
## CycleDiffusionPipeline
[[autodoc]] CycleDiffusionPipeline
- __call__

View File

@ -41,21 +41,24 @@ If you are looking for *official* training examples, please have a look at [exam
The following table summarizes all officially supported pipelines, their corresponding paper, and if
available a colab notebook to directly try them out.
| Pipeline | Paper | Tasks | Colab
|---|---|:---:|:---:|
| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
| [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [latent_diffusion](./latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
| [latent_diffusion_uncond](./latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
| [pndm](./pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
| [score_sde_ve](./score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [score_sde_vp](./score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| [vq_diffusion](./vq_diffusion) | [**Vector Quantized Diffusion Model for Text-to-Image Synthesis**](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
| [repaint](./repaint) | [**RePaint: Inpainting using Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2201.09865) | Image Inpainting |
| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
| [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
| [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [score_sde_vp](./api/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.

View File

@ -34,6 +34,7 @@ available a colab notebook to directly try them out.
| Pipeline | Paper | Tasks | Colab
|---|---|:---:|:---:|
| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |

View File

@ -63,6 +63,7 @@ else:
if is_torch_available() and is_transformers_available():
from .pipelines import (
CycleDiffusionPipeline,
LDMTextToImagePipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,

View File

@ -16,6 +16,7 @@ else:
if is_torch_available() and is_transformers_available():
from .latent_diffusion import LDMTextToImagePipeline
from .stable_diffusion import (
CycleDiffusionPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,

View File

@ -103,3 +103,74 @@ image = pipe(prompt).sample[0]
image.save("astronaut_rides_horse.png")
```
### CycleDiffusion using Stable Diffusion and DDIM scheduler
```python
import requests
import torch
from PIL import Image
from io import BytesIO
from diffusers import CycleDiffusionPipeline, DDIMScheduler
# load the scheduler. CycleDiffusion only supports stochastic schedulers.
# load the pipeline
# make sure you're logged in with `huggingface-cli login`
model_id_or_path = "CompVis/stable-diffusion-v1-4"
scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")
# let's download an initial image
url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/An%20astronaut%20riding%20a%20horse.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((512, 512))
init_image.save("horse.png")
# let's specify a prompt
source_prompt = "An astronaut riding a horse"
prompt = "An astronaut riding an elephant"
# call the pipeline
image = pipe(
prompt=prompt,
source_prompt=source_prompt,
init_image=init_image,
num_inference_steps=100,
eta=0.1,
strength=0.8,
guidance_scale=2,
source_guidance_scale=1,
).images[0]
image.save("horse_to_elephant.png")
# let's try another example
# See more samples at the original repo: https://github.com/ChenWu98/cycle-diffusion
url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/A%20black%20colored%20car.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((512, 512))
init_image.save("black.png")
source_prompt = "A black colored car"
prompt = "A blue colored car"
# call the pipeline
torch.manual_seed(0)
image = pipe(
prompt=prompt,
source_prompt=source_prompt,
init_image=init_image,
num_inference_steps=100,
eta=0.1,
strength=0.85,
guidance_scale=3,
source_guidance_scale=1,
).images[0]
image.save("black_to_blue.png")
```

View File

@ -28,6 +28,7 @@ class StableDiffusionPipelineOutput(BaseOutput):
if is_transformers_available() and is_torch_available():
from .pipeline_cycle_diffusion import CycleDiffusionPipeline
from .pipeline_stable_diffusion import StableDiffusionPipeline
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline

View File

@ -0,0 +1,527 @@
import inspect
from typing import Callable, List, Optional, Union
import numpy as np
import torch
import PIL
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler
from ...utils import deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0
def posterior_sample(scheduler, latents, timestep, clean_latents, eta):
# 1. get previous step value (=t-1)
prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
if prev_timestep <= 0:
return clean_latents
# 2. compute alphas, betas
alpha_prod_t = scheduler.alphas_cumprod[timestep]
alpha_prod_t_prev = (
scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
)
variance = scheduler._get_variance(timestep, prev_timestep)
std_dev_t = eta * variance ** (0.5)
# direction pointing to x_t
e_t = (latents - alpha_prod_t ** (0.5) * clean_latents) / (1 - alpha_prod_t) ** (0.5)
dir_xt = (1.0 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * e_t
noise = std_dev_t * torch.randn(clean_latents.shape, dtype=clean_latents.dtype, device=clean_latents.device)
prev_latents = alpha_prod_t_prev ** (0.5) * clean_latents + dir_xt + noise
return prev_latents
def compute_noise(scheduler, prev_latents, latents, timestep, noise_pred, eta):
# 1. get previous step value (=t-1)
prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
# 2. compute alphas, betas
alpha_prod_t = scheduler.alphas_cumprod[timestep]
alpha_prod_t_prev = (
scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
)
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
# 4. Clip "predicted x_0"
if scheduler.config.clip_sample:
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 α_t1)/(1 α_t)) * sqrt(1 α_t/α_t1)
variance = scheduler._get_variance(timestep, prev_timestep)
std_dev_t = eta * variance ** (0.5)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * noise_pred
noise = (prev_latents - (alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction)) / (
variance ** (0.5) * eta
)
return noise
class CycleDiffusionPipeline(DiffusionPipeline):
r"""
Pipeline for text-guided image to image generation using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: DDIMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `set_attention_slice`
self.enable_attention_slicing(None)
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
source_prompt: Union[str, List[str]],
init_image: Union[torch.FloatTensor, PIL.Image.Image],
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
source_guidance_scale: Optional[float] = 1,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.1,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
init_image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
`init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
noise will be maximum and the denoising process will run for the full number of iterations specified in
`num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter will be modulated by `strength`.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
source_guidance_scale (`float`, *optional*, defaults to 1):
Guidance scale for the source prompt. This is useful to control the amount of influence the source
prompt for encoding.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.1):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if batch_size != 1:
raise ValueError(
"At the moment only `batch_size=1` is supported for prompts, but you seem to have passed multiple"
f" prompts: {prompt}. Please make sure to pass only a single prompt."
)
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
if isinstance(init_image, PIL.Image.Image):
init_image = preprocess(init_image)
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
)
source_text_inputs = self.tokenizer(
source_prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
source_text_input_ids = source_text_inputs.input_ids
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
if source_text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(source_text_input_ids[:, self.tokenizer.model_max_length :])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
source_text_input_ids = source_text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
source_text_embeddings = self.text_encoder(source_text_input_ids.to(self.device))[0]
# duplicate text embeddings for each generation per prompt
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
source_text_embeddings = source_text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
# get unconditional embeddings for classifier free guidance
uncond_tokens = [""]
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
# duplicate unconditional embeddings for each generation per prompt
uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
source_uncond_tokens = [""]
max_length = source_text_input_ids.shape[-1]
source_uncond_input = self.tokenizer(
source_uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
source_uncond_embeddings = self.text_encoder(source_uncond_input.input_ids.to(self.device))[0]
# duplicate unconditional embeddings for each generation per prompt
source_uncond_embeddings = source_uncond_embeddings.repeat_interleave(
batch_size * num_images_per_prompt, dim=0
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
source_text_embeddings = torch.cat([source_uncond_embeddings, source_text_embeddings])
# encode the init image into latents and scale the latents
latents_dtype = text_embeddings.dtype
init_image = init_image.to(device=self.device, dtype=latents_dtype)
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
if isinstance(prompt, str):
prompt = [prompt]
if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0:
# expand init_latents for batch_size
deprecation_message = (
f"You have passed {len(prompt)} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
" images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note"
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
" your script to pass as many init images as text prompts to suppress this warning."
)
deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False)
additional_image_per_prompt = len(prompt) // init_latents.shape[0]
init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
elif len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {len(prompt)} text prompts."
)
else:
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
timesteps = self.scheduler.timesteps[-init_timestep]
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
# add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
clean_latents = init_latents
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if not (accepts_eta and (0 < eta <= 1)):
raise ValueError(
"Currently, only the DDIM scheduler is supported. Please make sure that `pipeline.scheduler` is of"
f" type {DDIMScheduler.__class__} and not {self.scheduler.__class__}."
)
extra_step_kwargs["eta"] = eta
latents = init_latents
source_latents = init_latents
t_start = max(num_inference_steps - init_timestep + offset, 0)
# Some schedulers like PNDM have timesteps as arrays
# It's more optimized to move all timesteps to correct device beforehand
timesteps = self.scheduler.timesteps[t_start:].to(self.device)
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2)
source_latent_model_input = torch.cat([source_latents] * 2)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t)
# predict the noise residual
concat_latent_model_input = torch.stack(
[
source_latent_model_input[0],
latent_model_input[0],
source_latent_model_input[1],
latent_model_input[1],
],
dim=0,
)
concat_text_embeddings = torch.stack(
[
source_text_embeddings[0],
text_embeddings[0],
source_text_embeddings[1],
text_embeddings[1],
],
dim=0,
)
concat_noise_pred = self.unet(
concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings
).sample
# perform guidance
(
source_noise_pred_uncond,
noise_pred_uncond,
source_noise_pred_text,
noise_pred_text,
) = concat_noise_pred.chunk(4, dim=0)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
source_noise_pred_text - source_noise_pred_uncond
)
# Sample source_latents from the posterior distribution.
prev_source_latents = posterior_sample(
self.scheduler, source_latents, t, clean_latents, **extra_step_kwargs
)
# Compute noise.
noise = compute_noise(
self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs
)
source_latents = prev_source_latents
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs
).prev_sample
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
self.device
)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
)
else:
has_nsfw_concept = None
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

View File

@ -208,6 +208,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
variance_noise: Optional[torch.FloatTensor] = None,
return_dict: bool = True,
) -> Union[DDIMSchedulerOutput, Tuple]:
"""
@ -225,6 +226,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
`self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
coincide with the one provided as input and `use_clipped_model_output` will have not effect.
generator: random number generator.
variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
can directly provide the noise for the variance itself. This is useful for methods such as
CycleDiffusion. (https://arxiv.org/abs/2210.05559)
return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
Returns:
@ -284,8 +288,17 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
if eta > 0:
# randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
device = model_output.device if torch.is_tensor(model_output) else "cpu"
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
if variance_noise is not None and generator is not None:
raise ValueError(
"Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
" `variance_noise` stays `None`."
)
if variance_noise is None:
variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(
device
)
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise
prev_sample = prev_sample + variance

View File

@ -4,6 +4,21 @@
from ..utils import DummyObject, requires_backends
class CycleDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class LDMTextToImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@ -0,0 +1,348 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import random
import unittest
import numpy as np
import torch
from diffusers import AutoencoderKL, CycleDiffusionPipeline, DDIMScheduler, UNet2DConditionModel, UNet2DModel, VQModel
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from ...test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
class CycleDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
@property
def dummy_image(self):
batch_size = 1
num_channels = 3
sizes = (32, 32)
image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
return image
@property
def dummy_uncond_unet(self):
torch.manual_seed(0)
model = UNet2DModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=3,
out_channels=3,
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
return model
@property
def dummy_cond_unet(self):
torch.manual_seed(0)
model = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
)
return model
@property
def dummy_cond_unet_inpaint(self):
torch.manual_seed(0)
model = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=9,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
)
return model
@property
def dummy_vq_model(self):
torch.manual_seed(0)
model = VQModel(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=3,
)
return model
@property
def dummy_vae(self):
torch.manual_seed(0)
model = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
)
return model
@property
def dummy_text_encoder(self):
torch.manual_seed(0)
config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
return CLIPTextModel(config)
@property
def dummy_extractor(self):
def extract(*args, **kwargs):
class Out:
def __init__(self):
self.pixel_values = torch.ones([0])
def to(self, device):
self.pixel_values.to(device)
return self
return Out()
return extract
def test_stable_diffusion_cycle(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
clip_sample=False,
set_alpha_to_one=False,
)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
# make sure here that pndm scheduler skips prk
sd_pipe = CycleDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
source_prompt = "An astronaut riding a horse"
prompt = "An astronaut riding an elephant"
init_image = self.dummy_image.to(device)
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe(
prompt=prompt,
source_prompt=source_prompt,
generator=generator,
num_inference_steps=2,
init_image=init_image,
eta=0.1,
strength=0.8,
guidance_scale=3,
source_guidance_scale=1,
output_type="np",
)
images = output.images
image_slice = images[0, -3:, -3:, -1]
assert images.shape == (1, 32, 32, 3)
expected_slice = np.array([0.4459, 0.4943, 0.4544, 0.6643, 0.5474, 0.4327, 0.5701, 0.5959, 0.5179])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
def test_stable_diffusion_cycle_fp16(self):
unet = self.dummy_cond_unet
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
clip_sample=False,
set_alpha_to_one=False,
)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
unet = unet.half()
vae = vae.half()
bert = bert.half()
# make sure here that pndm scheduler skips prk
sd_pipe = CycleDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
source_prompt = "An astronaut riding a horse"
prompt = "An astronaut riding an elephant"
init_image = self.dummy_image.to(torch_device)
generator = torch.Generator(device=torch_device).manual_seed(0)
output = sd_pipe(
prompt=prompt,
source_prompt=source_prompt,
generator=generator,
num_inference_steps=2,
init_image=init_image,
eta=0.1,
strength=0.8,
guidance_scale=3,
source_guidance_scale=1,
output_type="np",
)
images = output.images
image_slice = images[0, -3:, -3:, -1]
assert images.shape == (1, 32, 32, 3)
expected_slice = np.array([0.3506, 0.4543, 0.446, 0.4575, 0.5195, 0.4155, 0.5273, 0.518, 0.4116])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow
@require_torch_gpu
class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_cycle_diffusion_pipeline_fp16(self):
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/cycle-diffusion/black_colored_car.png"
)
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/cycle-diffusion/blue_colored_car_fp16.npy"
)
init_image = init_image.resize((512, 512))
model_id = "CompVis/stable-diffusion-v1-4"
scheduler = DDIMScheduler.from_config(model_id, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(
model_id, scheduler=scheduler, safety_checker=None, torch_dtype=torch.float16, revision="fp16"
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
source_prompt = "A black colored car"
prompt = "A blue colored car"
torch.manual_seed(0)
output = pipe(
prompt=prompt,
source_prompt=source_prompt,
init_image=init_image,
num_inference_steps=100,
eta=0.1,
strength=0.85,
guidance_scale=3,
source_guidance_scale=1,
output_type="np",
)
image = output.images
# the values aren't exactly equal, but the images look the same visually
assert np.abs(image - expected_image).max() < 1e-2
def test_cycle_diffusion_pipeline(self):
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/cycle-diffusion/black_colored_car.png"
)
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/cycle-diffusion/blue_colored_car.npy"
)
init_image = init_image.resize((512, 512))
model_id = "CompVis/stable-diffusion-v1-4"
scheduler = DDIMScheduler.from_config(model_id, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, safety_checker=None)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
source_prompt = "A black colored car"
prompt = "A blue colored car"
torch.manual_seed(0)
output = pipe(
prompt=prompt,
source_prompt=source_prompt,
init_image=init_image,
num_inference_steps=100,
eta=0.1,
strength=0.85,
guidance_scale=3,
source_guidance_scale=1,
output_type="np",
)
image = output.images
assert np.abs(image - expected_image).max() < 1e-2