feat: add repaint (#974)

* feat: add repaint

* fix: fix quality check with `make fix-copies`

* fix: remove old unnecessary arg

* chore: change default to DDPM (looks better in experiments)

* ".to(device)" changed to "device="

Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>

* make generator device-specific

Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>

* make generator device-specific and change shape

Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>

* fix: add preprocessing for image and mask

Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>

* fix: update test

Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>

* Update src/diffusers/pipelines/repaint/pipeline_repaint.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Add docs and examples

* Fix toctree

Co-authored-by: fja <fja@zurich.ibm.com>
Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
This commit is contained in:
Revist 2022-11-03 15:42:46 +01:00 committed by GitHub
parent 4a38166afe
commit d38c804320
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 667 additions and 14 deletions

View File

@ -96,5 +96,7 @@
title: "Stochastic Karras VE"
- local: api/pipelines/dance_diffusion
title: "Dance Diffusion"
- local: api/pipelines/repaint
title: "RePaint"
title: "Pipelines"
title: "API"

View File

@ -41,19 +41,21 @@ If you are looking for *official* training examples, please have a look at [exam
The following table summarizes all officially supported pipelines, their corresponding paper, and if
available a colab notebook to directly try them out.
| Pipeline | Paper | Tasks | Colab
|---|---|:---:|:---:|
| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
| [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [latent_diffusion](./latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
| [latent_diffusion_uncond](./latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
| [pndm](./pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
| [score_sde_ve](./score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [score_sde_vp](./score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| Pipeline | Paper | Tasks | Colab
|------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------|:-------------------------------------:|:---:|
| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
| [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [latent_diffusion](./latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Text-to-Image Generation |
| [latent_diffusion_uncond](./latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
| [pndm](./pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
| [score_sde_ve](./score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [score_sde_vp](./score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| [repaint](./repaint) | [**RePaint: Inpainting using Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2201.09865) | Image Inpainting |
**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.

View File

@ -0,0 +1,77 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# RePaint
## Overview
[RePaint: Inpainting using Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2201.09865) (PNDM) by Andreas Lugmayr, Martin Danelljan, Andres Romero, Fisher Yu, Radu Timofte, Luc Van Gool.
The abstract of the paper is the following:
Free-form inpainting is the task of adding new content to an image in the regions specified by an arbitrary binary mask. Most existing approaches train for a certain distribution of masks, which limits their generalization capabilities to unseen mask types. Furthermore, training with pixel-wise and perceptual losses often leads to simple textural extensions towards the missing areas instead of semantically meaningful generation. In this work, we propose RePaint: A Denoising Diffusion Probabilistic Model (DDPM) based inpainting approach that is applicable to even extreme masks. We employ a pretrained unconditional DDPM as the generative prior. To condition the generation process, we only alter the reverse diffusion iterations by sampling the unmasked regions using the given image information. Since this technique does not modify or condition the original DDPM network itself, the model produces high-quality and diverse output images for any inpainting form. We validate our method for both faces and general-purpose image inpainting using standard and extreme masks.
RePaint outperforms state-of-the-art Autoregressive, and GAN approaches for at least five out of six mask distributions.
The original codebase can be found [here](https://github.com/andreas128/RePaint).
## Available Pipelines:
| Pipeline | Tasks | Colab
|-------------------------------------------------------------------------------------------------------------------------------|--------------------|:---:|
| [pipeline_repaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/repaint/pipeline_repaint.py) | *Image Inpainting* | - |
## Usage example
```python
from io import BytesIO
import torch
import PIL
import requests
from diffusers import RePaintPipeline, RePaintScheduler
def download_image(url):
response = requests.get(url)
return PIL.Image.open(BytesIO(response.content)).convert("RGB")
img_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/celeba_hq_256.png"
mask_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png"
# Load the original image and the mask as PIL images
original_image = download_image(img_url).resize((256, 256))
mask_image = download_image(mask_url).resize((256, 256))
# Load the RePaint scheduler and pipeline based on a pretrained DDPM model
scheduler = RePaintScheduler.from_config("google/ddpm-ema-celebahq-256")
pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler)
pipe = pipe.to("cuda")
generator = torch.Generator(device="cuda").manual_seed(0)
output = pipe(
original_image=original_image,
mask_image=mask_image,
num_inference_steps=250,
eta=0.0,
jump_length=10,
jump_n_sample=10,
generator=generator,
)
inpainted_image = output.images[0]
```
## RePaintPipeline
[[autodoc]] pipelines.repaint.pipeline_repaint.RePaintPipeline
- __call__

View File

@ -127,4 +127,14 @@ Fast scheduler which often times generates good outputs with 20-30 steps.
Ancestral sampling with Euler method steps. Based on the original (k-diffusion)[https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72] implementation by Katherine Crowson.
Fast scheduler which often times generates good outputs with 20-30 steps.
[[autodoc]] EulerAncestralDiscreteScheduler
[[autodoc]] EulerAncestralDiscreteScheduler
#### RePaint scheduler
DDPM-based inpainting scheduler for unsupervised inpainting with extreme masks.
Intended for use with [`RePaintPipeline`].
Based on the paper [RePaint: Inpainting using Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2201.09865)
and the original implementation by Andreas Lugmayr et al.: https://github.com/andreas128/RePaint
[[autodoc]] RePaintScheduler

View File

@ -36,6 +36,7 @@ if is_torch_available():
KarrasVePipeline,
LDMPipeline,
PNDMPipeline,
RePaintPipeline,
ScoreSdeVePipeline,
)
from .schedulers import (
@ -46,6 +47,7 @@ if is_torch_available():
IPNDMScheduler,
KarrasVeScheduler,
PNDMScheduler,
RePaintScheduler,
SchedulerMixin,
ScoreSdeVeScheduler,
)

View File

@ -7,6 +7,7 @@ if is_torch_available():
from .ddpm import DDPMPipeline
from .latent_diffusion_uncond import LDMPipeline
from .pndm import PNDMPipeline
from .repaint import RePaintPipeline
from .score_sde_ve import ScoreSdeVePipeline
from .stochastic_karras_ve import KarrasVePipeline
else:

View File

@ -0,0 +1 @@
from .pipeline_repaint import RePaintPipeline

View File

@ -0,0 +1,140 @@
# Copyright 2022 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import numpy as np
import torch
import PIL
from tqdm.auto import tqdm
from ...models import UNet2DModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import RePaintScheduler
def _preprocess_image(image: PIL.Image.Image):
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
return image
def _preprocess_mask(mask: PIL.Image.Image):
mask = np.array(mask.convert("L"))
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
return mask
class RePaintPipeline(DiffusionPipeline):
unet: UNet2DModel
scheduler: RePaintScheduler
def __init__(self, unet, scheduler):
super().__init__()
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(
self,
original_image: Union[torch.FloatTensor, PIL.Image.Image],
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
num_inference_steps: int = 250,
eta: float = 0.0,
jump_length: int = 10,
jump_n_sample: int = 10,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
) -> Union[ImagePipelineOutput, Tuple]:
r"""
Args:
original_image (`torch.FloatTensor` or `PIL.Image.Image`):
The original image to inpaint on.
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
The mask_image where 0.0 values define which part of the original image to inpaint (change).
num_inference_steps (`int`, *optional*, defaults to 1000):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
eta (`float`):
The weight of noise for added noise in a diffusion step. Its value is between 0.0 and 1.0 - 0.0 is DDIM
and 1.0 is DDPM scheduler respectively.
jump_length (`int`, *optional*, defaults to 10):
The number of steps taken forward in time before going backward in time for a single jump ("j" in
RePaint paper). Take a look at Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf.
jump_n_sample (`int`, *optional*, defaults to 10):
The number of times we will make forward time jump for a given chosen time sample. Take a look at
Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
Returns:
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
generated images.
"""
if not isinstance(original_image, torch.FloatTensor):
original_image = _preprocess_image(original_image)
original_image = original_image.to(self.device)
if not isinstance(mask_image, torch.FloatTensor):
mask_image = _preprocess_mask(mask_image)
mask_image = mask_image.to(self.device)
# sample gaussian noise to begin the loop
image = torch.randn(
original_image.shape,
generator=generator,
device=self.device,
)
image = image.to(self.device)
# set step values
self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device)
self.scheduler.eta = eta
t_last = self.scheduler.timesteps[0] + 1
for i, t in enumerate(tqdm(self.scheduler.timesteps)):
if t < t_last:
# predict the noise residual
model_output = self.unet(image, t).sample
# compute previous image: x_t -> x_t-1
image = self.scheduler.step(model_output, t, image, original_image, mask_image, generator).prev_sample
else:
# compute the reverse: x_t-1 -> x_t
image = self.scheduler.undo_step(image, t_last, generator)
t_last = t
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)

View File

@ -24,6 +24,7 @@ if is_torch_available():
from .scheduling_ipndm import IPNDMScheduler
from .scheduling_karras_ve import KarrasVeScheduler
from .scheduling_pndm import PNDMScheduler
from .scheduling_repaint import RePaintScheduler
from .scheduling_sde_ve import ScoreSdeVeScheduler
from .scheduling_sde_vp import ScoreSdeVpScheduler
from .scheduling_utils import SchedulerMixin

View File

@ -0,0 +1,322 @@
# Copyright 2022 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
@dataclass
class RePaintSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's step function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from
the current timestep. `pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.FloatTensor
pred_original_sample: torch.FloatTensor
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
class RePaintScheduler(SchedulerMixin, ConfigMixin):
"""
RePaint is a schedule for DDPM inpainting inside a given mask.
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
For more details, see the original paper: https://arxiv.org/pdf/2201.09865.pdf
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
eta (`float`):
The weight of noise for added noise in a diffusion step. Its value is between 0.0 and 1.0 -0.0 is DDIM and
1.0 is DDPM scheduler respectively.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
variance_type (`str`):
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
"""
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
eta: float = 0.0,
trained_betas: Optional[np.ndarray] = None,
clip_sample: bool = True,
):
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
elif beta_schedule == "sigmoid":
# GeoDiff sigmoid schedule
betas = torch.linspace(-6, 6, num_train_timesteps)
self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.one = torch.tensor(1.0)
self.final_alpha_cumprod = torch.tensor(1.0)
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
# setable values
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
self.eta = eta
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`): input sample
timestep (`int`, optional): current timestep
Returns:
`torch.FloatTensor`: scaled input sample
"""
return sample
def set_timesteps(
self,
num_inference_steps: int,
jump_length: int = 10,
jump_n_sample: int = 10,
device: Union[str, torch.device] = None,
):
num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
self.num_inference_steps = num_inference_steps
timesteps = []
jumps = {}
for j in range(0, num_inference_steps - jump_length, jump_length):
jumps[j] = jump_n_sample - 1
t = num_inference_steps
while t >= 1:
t = t - 1
timesteps.append(t)
if jumps.get(t, 0) > 0:
jumps[t] = jumps[t] - 1
for _ in range(jump_length):
t = t + 1
timesteps.append(t)
timesteps = np.array(timesteps) * (self.config.num_train_timesteps // self.num_inference_steps)
self.timesteps = torch.from_numpy(timesteps).to(device)
def _get_variance(self, t):
prev_timestep = t - self.config.num_train_timesteps // self.num_inference_steps
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
# For t > 0, compute predicted variance βt (see formula (6) and (7) from
# https://arxiv.org/pdf/2006.11239.pdf) and sample from it to get
# previous sample x_{t-1} ~ N(pred_prev_sample, variance) == add
# variance to pred_sample
# Is equivalent to formula (16) in https://arxiv.org/pdf/2010.02502.pdf
# without eta.
# variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
return variance
def step(
self,
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
original_image: torch.FloatTensor,
mask: torch.FloatTensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[RePaintSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`): direct output from learned
diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
original_image (`torch.FloatTensor`):
the original image to inpaint on.
mask (`torch.FloatTensor`):
the mask where 0.0 values define which part of the original image to inpaint (change).
generator (`torch.Generator`, *optional*): random number generator.
return_dict (`bool`): option for returning tuple rather than
DDPMSchedulerOutput class
Returns:
[`~schedulers.scheduling_utils.RePaintSchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.RePaintSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
t = timestep
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
# 1. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
# 3. Clip "predicted x_0"
if self.config.clip_sample:
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
# We choose to follow RePaint Algorithm 1 to get x_{t-1}, however we
# substitute formula (7) in the algorithm coming from DDPM paper
# (formula (4) Algorithm 2 - Sampling) with formula (12) from DDIM paper.
# DDIM schedule gives the same results as DDPM with eta = 1.0
# Noise is being reused in 7. and 8., but no impact on quality has
# been observed.
# 5. Add noise
noise = torch.randn(
model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device
)
std_dev_t = self.eta * self._get_variance(timestep) ** 0.5
variance = 0
if t > 0 and self.eta > 0:
variance = std_dev_t * noise
# 6. compute "direction pointing to x_t" of formula (12)
# from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * model_output
# 7. compute x_{t-1} of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_unknown_part = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + variance
# 8. Algorithm 1 Line 5 https://arxiv.org/pdf/2201.09865.pdf
prev_known_part = (alpha_prod_t**0.5) * original_image + ((1 - alpha_prod_t) ** 0.5) * noise
# 9. Algorithm 1 Line 8 https://arxiv.org/pdf/2201.09865.pdf
pred_prev_sample = mask * prev_known_part + (1.0 - mask) * prev_unknown_part
if not return_dict:
return (
pred_prev_sample,
pred_original_sample,
)
return RePaintSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
def undo_step(self, sample, timestep, generator=None):
n = self.config.num_train_timesteps // self.num_inference_steps
for i in range(n):
beta = self.betas[timestep + i]
noise = torch.randn(sample.shape, generator=generator, device=sample.device)
# 10. Algorithm 1 Line 10 https://arxiv.org/pdf/2201.09865.pdf
sample = (1 - beta) ** 0.5 * sample + beta**0.5 * noise
return sample
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
raise NotImplementedError("Use `DDPMScheduler.add_noise()` to train for sampling with RePaint.")
def __len__(self):
return self.config.num_train_timesteps

View File

@ -227,6 +227,21 @@ class PNDMPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class RePaintPipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class ScoreSdeVePipeline(metaclass=DummyObject):
_backends = ["torch"]
@ -347,6 +362,21 @@ class PNDMScheduler(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class RePaintScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class SchedulerMixin(metaclass=DummyObject):
_backends = ["torch"]

View File

View File

@ -0,0 +1,65 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from diffusers import RePaintPipeline, RePaintScheduler, UNet2DModel
from diffusers.utils.testing_utils import load_image, require_torch_gpu, slow, torch_device
torch.backends.cuda.matmul.allow_tf32 = False
@slow
@require_torch_gpu
class RepaintPipelineIntegrationTests(unittest.TestCase):
def test_celebahq(self):
original_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/"
"repaint/celeba_hq_256.png"
)
mask_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png"
)
expected_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/"
"repaint/celeba_hq_256_result.png"
)
expected_image = np.array(expected_image, dtype=np.float32) / 255.0
model_id = "google/ddpm-ema-celebahq-256"
unet = UNet2DModel.from_pretrained(model_id)
scheduler = RePaintScheduler.from_config(model_id)
repaint = RePaintPipeline(unet=unet, scheduler=scheduler).to(torch_device)
generator = torch.Generator(device=torch_device).manual_seed(0)
output = repaint(
original_image,
mask_image,
num_inference_steps=250,
eta=0.0,
jump_length=10,
jump_n_sample=10,
generator=generator,
output_type="np",
)
image = output.images[0]
assert image.shape == (256, 256, 3)
assert np.abs(expected_image - image).mean() < 1e-2