diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 7e46d95a..331d4fff 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -96,5 +96,7 @@ title: "Stochastic Karras VE" - local: api/pipelines/dance_diffusion title: "Dance Diffusion" + - local: api/pipelines/repaint + title: "RePaint" title: "Pipelines" title: "API" diff --git a/docs/source/api/pipelines/overview.mdx b/docs/source/api/pipelines/overview.mdx index 9bb351d9..a53a2f8b 100644 --- a/docs/source/api/pipelines/overview.mdx +++ b/docs/source/api/pipelines/overview.mdx @@ -41,19 +41,21 @@ If you are looking for *official* training examples, please have a look at [exam The following table summarizes all officially supported pipelines, their corresponding paper, and if available a colab notebook to directly try them out. -| Pipeline | Paper | Tasks | Colab -|---|---|:---:|:---:| -| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation | -| [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) -| [latent_diffusion](./latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation | -| [latent_diffusion_uncond](./latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation | -| [pndm](./pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation | -| [score_sde_ve](./score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | -| [score_sde_vp](./score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | -| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) -| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) -| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) -| [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | +| Pipeline | Paper | Tasks | Colab +|------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------|:-------------------------------------:|:---:| +| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation | +| [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) +| [latent_diffusion](./latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Text-to-Image Generation | +| [latent_diffusion_uncond](./latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation | +| [pndm](./pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation | +| [score_sde_ve](./score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | +| [score_sde_vp](./score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | +| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) +| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) +| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) +| [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | +| [repaint](./repaint) | [**RePaint: Inpainting using Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2201.09865) | Image Inpainting | + **Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers. diff --git a/docs/source/api/pipelines/repaint.mdx b/docs/source/api/pipelines/repaint.mdx new file mode 100644 index 00000000..0b7de8a4 --- /dev/null +++ b/docs/source/api/pipelines/repaint.mdx @@ -0,0 +1,77 @@ + + +# RePaint + +## Overview + +[RePaint: Inpainting using Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2201.09865) (PNDM) by Andreas Lugmayr, Martin Danelljan, Andres Romero, Fisher Yu, Radu Timofte, Luc Van Gool. + +The abstract of the paper is the following: + +Free-form inpainting is the task of adding new content to an image in the regions specified by an arbitrary binary mask. Most existing approaches train for a certain distribution of masks, which limits their generalization capabilities to unseen mask types. Furthermore, training with pixel-wise and perceptual losses often leads to simple textural extensions towards the missing areas instead of semantically meaningful generation. In this work, we propose RePaint: A Denoising Diffusion Probabilistic Model (DDPM) based inpainting approach that is applicable to even extreme masks. We employ a pretrained unconditional DDPM as the generative prior. To condition the generation process, we only alter the reverse diffusion iterations by sampling the unmasked regions using the given image information. Since this technique does not modify or condition the original DDPM network itself, the model produces high-quality and diverse output images for any inpainting form. We validate our method for both faces and general-purpose image inpainting using standard and extreme masks. +RePaint outperforms state-of-the-art Autoregressive, and GAN approaches for at least five out of six mask distributions. + +The original codebase can be found [here](https://github.com/andreas128/RePaint). + +## Available Pipelines: + +| Pipeline | Tasks | Colab +|-------------------------------------------------------------------------------------------------------------------------------|--------------------|:---:| +| [pipeline_repaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/repaint/pipeline_repaint.py) | *Image Inpainting* | - | + +## Usage example + +```python +from io import BytesIO + +import torch + +import PIL +import requests +from diffusers import RePaintPipeline, RePaintScheduler + + +def download_image(url): + response = requests.get(url) + return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + +img_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/celeba_hq_256.png" +mask_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png" + +# Load the original image and the mask as PIL images +original_image = download_image(img_url).resize((256, 256)) +mask_image = download_image(mask_url).resize((256, 256)) + +# Load the RePaint scheduler and pipeline based on a pretrained DDPM model +scheduler = RePaintScheduler.from_config("google/ddpm-ema-celebahq-256") +pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler) +pipe = pipe.to("cuda") + +generator = torch.Generator(device="cuda").manual_seed(0) +output = pipe( + original_image=original_image, + mask_image=mask_image, + num_inference_steps=250, + eta=0.0, + jump_length=10, + jump_n_sample=10, + generator=generator, +) +inpainted_image = output.images[0] +``` + +## RePaintPipeline +[[autodoc]] pipelines.repaint.pipeline_repaint.RePaintPipeline + - __call__ + diff --git a/docs/source/api/schedulers.mdx b/docs/source/api/schedulers.mdx index 6616a3e5..6e7da10e 100644 --- a/docs/source/api/schedulers.mdx +++ b/docs/source/api/schedulers.mdx @@ -127,4 +127,14 @@ Fast scheduler which often times generates good outputs with 20-30 steps. Ancestral sampling with Euler method steps. Based on the original (k-diffusion)[https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72] implementation by Katherine Crowson. Fast scheduler which often times generates good outputs with 20-30 steps. -[[autodoc]] EulerAncestralDiscreteScheduler \ No newline at end of file +[[autodoc]] EulerAncestralDiscreteScheduler + + +#### RePaint scheduler + +DDPM-based inpainting scheduler for unsupervised inpainting with extreme masks. +Intended for use with [`RePaintPipeline`]. +Based on the paper [RePaint: Inpainting using Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2201.09865) +and the original implementation by Andreas Lugmayr et al.: https://github.com/andreas128/RePaint + +[[autodoc]] RePaintScheduler \ No newline at end of file diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 49c3e82b..1a9a7d74 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -36,6 +36,7 @@ if is_torch_available(): KarrasVePipeline, LDMPipeline, PNDMPipeline, + RePaintPipeline, ScoreSdeVePipeline, ) from .schedulers import ( @@ -46,6 +47,7 @@ if is_torch_available(): IPNDMScheduler, KarrasVeScheduler, PNDMScheduler, + RePaintScheduler, SchedulerMixin, ScoreSdeVeScheduler, ) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index b3124af3..8015d4e1 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -7,6 +7,7 @@ if is_torch_available(): from .ddpm import DDPMPipeline from .latent_diffusion_uncond import LDMPipeline from .pndm import PNDMPipeline + from .repaint import RePaintPipeline from .score_sde_ve import ScoreSdeVePipeline from .stochastic_karras_ve import KarrasVePipeline else: diff --git a/src/diffusers/pipelines/repaint/__init__.py b/src/diffusers/pipelines/repaint/__init__.py new file mode 100644 index 00000000..16bc86d1 --- /dev/null +++ b/src/diffusers/pipelines/repaint/__init__.py @@ -0,0 +1 @@ +from .pipeline_repaint import RePaintPipeline diff --git a/src/diffusers/pipelines/repaint/pipeline_repaint.py b/src/diffusers/pipelines/repaint/pipeline_repaint.py new file mode 100644 index 00000000..7af88f62 --- /dev/null +++ b/src/diffusers/pipelines/repaint/pipeline_repaint.py @@ -0,0 +1,140 @@ +# Copyright 2022 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +import PIL +from tqdm.auto import tqdm + +from ...models import UNet2DModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import RePaintScheduler + + +def _preprocess_image(image: PIL.Image.Image): + image = np.array(image.convert("RGB")) + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + return image + + +def _preprocess_mask(mask: PIL.Image.Image): + mask = np.array(mask.convert("L")) + mask = mask.astype(np.float32) / 255.0 + mask = mask[None, None] + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + return mask + + +class RePaintPipeline(DiffusionPipeline): + unet: UNet2DModel + scheduler: RePaintScheduler + + def __init__(self, unet, scheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + original_image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image], + num_inference_steps: int = 250, + eta: float = 0.0, + jump_length: int = 10, + jump_n_sample: int = 10, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Args: + original_image (`torch.FloatTensor` or `PIL.Image.Image`): + The original image to inpaint on. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + The mask_image where 0.0 values define which part of the original image to inpaint (change). + num_inference_steps (`int`, *optional*, defaults to 1000): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + eta (`float`): + The weight of noise for added noise in a diffusion step. Its value is between 0.0 and 1.0 - 0.0 is DDIM + and 1.0 is DDPM scheduler respectively. + jump_length (`int`, *optional*, defaults to 10): + The number of steps taken forward in time before going backward in time for a single jump ("j" in + RePaint paper). Take a look at Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf. + jump_n_sample (`int`, *optional*, defaults to 10): + The number of times we will make forward time jump for a given chosen time sample. Take a look at + Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. + """ + + if not isinstance(original_image, torch.FloatTensor): + original_image = _preprocess_image(original_image) + original_image = original_image.to(self.device) + if not isinstance(mask_image, torch.FloatTensor): + mask_image = _preprocess_mask(mask_image) + mask_image = mask_image.to(self.device) + + # sample gaussian noise to begin the loop + image = torch.randn( + original_image.shape, + generator=generator, + device=self.device, + ) + image = image.to(self.device) + + # set step values + self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device) + self.scheduler.eta = eta + + t_last = self.scheduler.timesteps[0] + 1 + for i, t in enumerate(tqdm(self.scheduler.timesteps)): + if t < t_last: + # predict the noise residual + model_output = self.unet(image, t).sample + # compute previous image: x_t -> x_t-1 + image = self.scheduler.step(model_output, t, image, original_image, mask_image, generator).prev_sample + + else: + # compute the reverse: x_t-1 -> x_t + image = self.scheduler.undo_step(image, t_last, generator) + t_last = t + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index c3999d2c..a1915ed8 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -24,6 +24,7 @@ if is_torch_available(): from .scheduling_ipndm import IPNDMScheduler from .scheduling_karras_ve import KarrasVeScheduler from .scheduling_pndm import PNDMScheduler + from .scheduling_repaint import RePaintScheduler from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_sde_vp import ScoreSdeVpScheduler from .scheduling_utils import SchedulerMixin diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py new file mode 100644 index 00000000..1751f41c --- /dev/null +++ b/src/diffusers/schedulers/scheduling_repaint.py @@ -0,0 +1,322 @@ +# Copyright 2022 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class RePaintSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from + the current timestep. `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: torch.FloatTensor + + +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class RePaintScheduler(SchedulerMixin, ConfigMixin): + """ + RePaint is a schedule for DDPM inpainting inside a given mask. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + For more details, see the original paper: https://arxiv.org/pdf/2201.09865.pdf + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + eta (`float`): + The weight of noise for added noise in a diffusion step. Its value is between 0.0 and 1.0 -0.0 is DDIM and + 1.0 is DDPM scheduler respectively. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + variance_type (`str`): + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + clip_sample (`bool`, default `True`): + option to clip predicted sample between -1 and 1 for numerical stability. + + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + eta: float = 0.0, + trained_betas: Optional[np.ndarray] = None, + clip_sample: bool = True, + ): + if trained_betas is not None: + self.betas = torch.from_numpy(trained_betas) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + elif beta_schedule == "sigmoid": + # GeoDiff sigmoid schedule + betas = torch.linspace(-6, 6, num_train_timesteps) + self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + self.one = torch.tensor(1.0) + + self.final_alpha_cumprod = torch.tensor(1.0) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + + self.eta = eta + + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def set_timesteps( + self, + num_inference_steps: int, + jump_length: int = 10, + jump_n_sample: int = 10, + device: Union[str, torch.device] = None, + ): + num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps) + self.num_inference_steps = num_inference_steps + + timesteps = [] + + jumps = {} + for j in range(0, num_inference_steps - jump_length, jump_length): + jumps[j] = jump_n_sample - 1 + + t = num_inference_steps + while t >= 1: + t = t - 1 + timesteps.append(t) + + if jumps.get(t, 0) > 0: + jumps[t] = jumps[t] - 1 + for _ in range(jump_length): + t = t + 1 + timesteps.append(t) + + timesteps = np.array(timesteps) * (self.config.num_train_timesteps // self.num_inference_steps) + self.timesteps = torch.from_numpy(timesteps).to(device) + + def _get_variance(self, t): + prev_timestep = t - self.config.num_train_timesteps // self.num_inference_steps + + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # For t > 0, compute predicted variance βt (see formula (6) and (7) from + # https://arxiv.org/pdf/2006.11239.pdf) and sample from it to get + # previous sample x_{t-1} ~ N(pred_prev_sample, variance) == add + # variance to pred_sample + # Is equivalent to formula (16) in https://arxiv.org/pdf/2010.02502.pdf + # without eta. + # variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t] + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + original_image: torch.FloatTensor, + mask: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[RePaintSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned + diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + original_image (`torch.FloatTensor`): + the original image to inpaint on. + mask (`torch.FloatTensor`): + the mask where 0.0 values define which part of the original image to inpaint (change). + generator (`torch.Generator`, *optional*): random number generator. + return_dict (`bool`): option for returning tuple rather than + DDPMSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.RePaintSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.RePaintSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + t = timestep + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 1. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 + + # 3. Clip "predicted x_0" + if self.config.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # We choose to follow RePaint Algorithm 1 to get x_{t-1}, however we + # substitute formula (7) in the algorithm coming from DDPM paper + # (formula (4) Algorithm 2 - Sampling) with formula (12) from DDIM paper. + # DDIM schedule gives the same results as DDPM with eta = 1.0 + # Noise is being reused in 7. and 8., but no impact on quality has + # been observed. + + # 5. Add noise + noise = torch.randn( + model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device + ) + std_dev_t = self.eta * self._get_variance(timestep) ** 0.5 + + variance = 0 + if t > 0 and self.eta > 0: + variance = std_dev_t * noise + + # 6. compute "direction pointing to x_t" of formula (12) + # from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * model_output + + # 7. compute x_{t-1} of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_unknown_part = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + variance + + # 8. Algorithm 1 Line 5 https://arxiv.org/pdf/2201.09865.pdf + prev_known_part = (alpha_prod_t**0.5) * original_image + ((1 - alpha_prod_t) ** 0.5) * noise + + # 9. Algorithm 1 Line 8 https://arxiv.org/pdf/2201.09865.pdf + pred_prev_sample = mask * prev_known_part + (1.0 - mask) * prev_unknown_part + + if not return_dict: + return ( + pred_prev_sample, + pred_original_sample, + ) + + return RePaintSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) + + def undo_step(self, sample, timestep, generator=None): + n = self.config.num_train_timesteps // self.num_inference_steps + + for i in range(n): + beta = self.betas[timestep + i] + noise = torch.randn(sample.shape, generator=generator, device=sample.device) + + # 10. Algorithm 1 Line 10 https://arxiv.org/pdf/2201.09865.pdf + sample = (1 - beta) ** 0.5 * sample + beta**0.5 * noise + + return sample + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + raise NotImplementedError("Use `DDPMScheduler.add_noise()` to train for sampling with RePaint.") + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 5dd58327..63aa2096 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -227,6 +227,21 @@ class PNDMPipeline(metaclass=DummyObject): requires_backends(cls, ["torch"]) +class RePaintPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ScoreSdeVePipeline(metaclass=DummyObject): _backends = ["torch"] @@ -347,6 +362,21 @@ class PNDMScheduler(metaclass=DummyObject): requires_backends(cls, ["torch"]) +class RePaintScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class SchedulerMixin(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/pipelines/repaint/__init__.py b/tests/pipelines/repaint/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/pipelines/repaint/test_repaint.py b/tests/pipelines/repaint/test_repaint.py new file mode 100644 index 00000000..23544dfd --- /dev/null +++ b/tests/pipelines/repaint/test_repaint.py @@ -0,0 +1,65 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from diffusers import RePaintPipeline, RePaintScheduler, UNet2DModel +from diffusers.utils.testing_utils import load_image, require_torch_gpu, slow, torch_device + + +torch.backends.cuda.matmul.allow_tf32 = False + + +@slow +@require_torch_gpu +class RepaintPipelineIntegrationTests(unittest.TestCase): + def test_celebahq(self): + original_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/" + "repaint/celeba_hq_256.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png" + ) + expected_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/" + "repaint/celeba_hq_256_result.png" + ) + expected_image = np.array(expected_image, dtype=np.float32) / 255.0 + + model_id = "google/ddpm-ema-celebahq-256" + unet = UNet2DModel.from_pretrained(model_id) + scheduler = RePaintScheduler.from_config(model_id) + + repaint = RePaintPipeline(unet=unet, scheduler=scheduler).to(torch_device) + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = repaint( + original_image, + mask_image, + num_inference_steps=250, + eta=0.0, + jump_length=10, + jump_n_sample=10, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (256, 256, 3) + assert np.abs(expected_image - image).mean() < 1e-2