Merge branch 'main' of github.com:huggingface/diffusers
Conflicts: src/diffusers/__init__.py src/diffusers/models/__init__.py
This commit is contained in:
commit
bb98a5b709
|
@ -0,0 +1,28 @@
|
|||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Denoising Diffusion Implicit Models (DDIM)
|
||||
|
||||
## Overview
|
||||
|
||||
DDIM was proposed in [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) by *Jiaming Song, Chenlin Meng, Stefano Ermon*.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training, yet they require simulating a Markov chain for many steps to produce a sample. To accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models with the same training procedure as DDPMs. In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process. We construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from. We empirically demonstrate that DDIMs can produce high quality samples 10× to 50× faster in terms of wall-clock time compared to DDPMs, allow us to trade off computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space.*
|
||||
|
||||
Tips:
|
||||
|
||||
- ...
|
||||
- ...
|
||||
|
||||
This model was contributed by [???](https://huggingface.co/???). The original code can be found [here](https://github.com/ermongroup/ddim).
|
|
@ -0,0 +1,71 @@
|
|||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
import tqdm
|
||||
import torch
|
||||
|
||||
|
||||
def compute_alpha(beta, t):
    """Return the cumulative product of ``1 - beta`` at the given timesteps.

    A zero is prepended to ``beta`` and the lookup uses index ``t + 1``, so
    ``t == -1`` is a valid input and yields ``1.0``.

    :param beta: 1-D tensor of beta values.
    :param t: 1-D integer (long) tensor of timestep indices, each ``>= -1``.
    :return: tensor of shape ``(len(t), 1, 1, 1)`` holding the cumulative
             products, with the same dtype and device as ``beta``.
    """
    # new_zeros inherits dtype and device from beta, so the concatenation
    # never mixes dtypes or devices.
    padded = torch.cat([beta.new_zeros(1), beta], dim=0)
    alpha_cumprod = (1 - padded).cumprod(dim=0)
    # view(-1, 1, 1, 1) lets the result broadcast against 4-D tensors.
    return alpha_cumprod.index_select(0, t + 1).view(-1, 1, 1, 1)
|
||||
|
||||
|
||||
class DDIM(DiffusionPipeline):
    """Sampling pipeline using the DDIM update rule (https://arxiv.org/abs/2010.02502).

    Holds a noise-prediction network (``unet``) and a ``noise_scheduler`` that
    supplies ``betas``, ``num_timesteps`` and ``sample_noise``.
    """

    def __init__(self, unet, noise_scheduler):
        super().__init__()
        self.register_modules(unet=unet, noise_scheduler=noise_scheduler)

    def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, inference_time_steps=50):
        """Generate a batch of samples by iteratively denoising random noise.

        :param batch_size: number of samples to generate.
        :param generator: optional ``torch.Generator`` forwarded to
                          ``noise_scheduler.sample_noise`` for the initial noise
                          and for the per-step noise.
        :param torch_device: device to run on; defaults to ``"cuda"`` when
                             available, otherwise ``"cpu"``. ``self.unet`` is
                             moved to this device.
        :param eta: η in the paper; scales the coefficient of the fresh noise
                    added at each step (``0.0`` adds none).
        :param inference_time_steps: used to derive the timestep stride
                                     ``num_timesteps // inference_time_steps``.
                                     When the division is not exact the loop
                                     runs more iterations than this value.
        :return: the final sample tensor on ``torch_device``, created with shape
                 ``(batch_size, unet.in_channels, unet.resolution, unet.resolution)``
                 (assumes the unet output has the same shape as its input).
        :raises ValueError: if ``inference_time_steps`` is not in
                            ``[1, num_timesteps]``.
        """
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        num_timesteps = self.noise_scheduler.num_timesteps

        # A stride of zero (or an empty schedule) would otherwise surface as an
        # obscure range()/unbound-variable error further down.
        if not 0 < inference_time_steps <= num_timesteps:
            raise ValueError(
                f"inference_time_steps must be in [1, {num_timesteps}], got {inference_time_steps}"
            )

        seq = range(0, num_timesteps, num_timesteps // inference_time_steps)
        # Timestep each entry of `seq` steps to; -1 stands for the fully denoised sample.
        seq_next = [-1] + list(seq[:-1])

        betas = self.noise_scheduler.betas.to(torch_device)
        self.unet.to(torch_device)

        sample_shape = (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution)
        image = self.noise_scheduler.sample_noise(sample_shape, device=torch_device, generator=generator)

        with torch.no_grad():
            for i, j in zip(reversed(seq), reversed(seq_next)):
                t = (torch.ones(batch_size) * i).to(image.device)
                next_t = (torch.ones(batch_size) * j).to(image.device)
                at = compute_alpha(betas, t.long())
                at_next = compute_alpha(betas, next_t.long())

                # Predicted noise, then the implied estimate of the clean sample.
                et = self.unet(image, t)
                x0_t = (image - et * (1 - at).sqrt()) / at.sqrt()

                # c1 weights fresh noise (zero when eta == 0), c2 weights the predicted noise.
                c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
                c2 = ((1 - at_next) - c1 ** 2).sqrt()

                # Draw the step noise through the scheduler so `generator` is honoured.
                noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
                image = at_next.sqrt() * x0_t + c1 * noise + c2 * et

        return image
|
|
@ -0,0 +1,17 @@
|
|||
#!/usr/bin/env python3
|
||||
import torch
|
||||
|
||||
from diffusers import GaussianDDPMScheduler, UNetModel
|
||||
|
||||
|
||||
# Denoising network used by the diffusion wrapper below.
model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8))

# NOTE(review): these arguments (model, image_size, loss_type) do not match the
# GaussianDDPMScheduler.__init__ signature in schedulers/gaussian_ddpm.py —
# confirm this example is still current.
diffusion = GaussianDDPMScheduler(
    model,
    image_size=128,
    timesteps=1000,  # number of steps
    loss_type="l1",  # L1 or L2
)

# your images need to be normalized from a range of -1 to +1
training_images = torch.randn(8, 3, 128, 128)

loss = diffusion(training_images)
loss.backward()

# after a lot of training
sampled_images = diffusion.sample(batch_size=4)
sampled_images.shape  # (4, 3, 128, 128)
|
|
@ -0,0 +1,23 @@
|
|||
#!/usr/bin/env python3
|
||||
# !pip install diffusers
|
||||
from modeling_ddim import DDIM
|
||||
import PIL.Image
|
||||
import numpy as np
|
||||
|
||||
# Checkpoint to sample from ("fusing/ddpm-cifar10" is an alternative).
model_id = "fusing/ddpm-lsun-bedroom"

# load model and scheduler
ddim = DDIM.from_pretrained(model_id)

# run pipeline in inference (sample random noise and denoise)
image = ddim()

# process image to PIL: move channels last, as PIL expects (height, width, channel)
image_processed = image.cpu().permute(0, 2, 3, 1)
# rescale from [-1, 1] to [0, 255]; clamp so slightly out-of-range values
# do not wrap around when cast to uint8
image_processed = ((image_processed + 1.0) * 127.5).clamp(0, 255)
image_processed = image_processed.numpy().astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])

# save image next to where the script is run instead of a machine-specific path
image_pil.save("show.png")
|
|
@ -21,8 +21,6 @@ import torch
|
|||
|
||||
class DDPM(DiffusionPipeline):
|
||||
|
||||
modeling_file = "modeling_ddpm.py"
|
||||
|
||||
def __init__(self, unet, noise_scheduler):
|
||||
super().__init__()
|
||||
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
|
||||
|
|
10
setup.py
10
setup.py
|
@ -28,11 +28,11 @@ To create the package for pypi.
|
|||
3. Unpin specific versions from setup.py that use a git install.
|
||||
|
||||
4. Checkout the release branch (v<RELEASE>-release, for example v4.19-release), and commit these changes with the
|
||||
message: "Release: <VERSION>" and push.
|
||||
message: "Release: <RELEASE>" and push.
|
||||
|
||||
5. Wait for the tests on main to be completed and be green (otherwise revert and fix bugs)
|
||||
|
||||
6. Add a tag in git to mark the release: "git tag v<VERSION> -m 'Adds tag v<VERSION> for pypi' "
|
||||
6. Add a tag in git to mark the release: "git tag v<RELEASE> -m 'Adds tag v<RELEASE> for pypi' "
|
||||
Push the tag to git: git push --tags origin v<RELEASE>-release
|
||||
|
||||
7. Build both the sources and the wheel. Do not change anything in setup.py between
|
||||
|
@ -189,7 +189,7 @@ extras["sagemaker"] = [
|
|||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.0.1",
|
||||
version="0.0.2",
|
||||
description="Diffusers",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
@ -222,8 +222,8 @@ setup(
|
|||
|
||||
# Release checklist
|
||||
# 1. Change the version in __init__.py and setup.py.
|
||||
# 2. Commit these changes with the message: "Release: VERSION"
|
||||
# 3. Add a tag in git to mark the release: "git tag VERSION -m 'Adds tag VERSION for pypi' "
|
||||
# 2. Commit these changes with the message: "Release: RELEASE"
|
||||
# 3. Add a tag in git to mark the release: "git tag RELEASE -m 'Adds tag RELEASE for pypi' "
|
||||
# Push the tag to git: git push --tags origin main
|
||||
# 4. Run the following commands in the top-level directory:
|
||||
# python setup.py bdist_wheel
|
||||
|
|
|
@ -7,6 +7,7 @@ __version__ = "0.0.1"
|
|||
from .modeling_utils import ModelMixin
|
||||
from .models.unet import UNetModel
|
||||
from .models.unet_glide import UNetGLIDEModel
|
||||
from .models.unet_ldm import UNetLDMModel
|
||||
from .models.clip_text_transformer import CLIPTextModel
|
||||
from .pipeline_utils import DiffusionPipeline
|
||||
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
|
||||
|
|
|
@ -18,4 +18,5 @@
|
|||
|
||||
from .unet import UNetModel
|
||||
from .unet_glide import UNetGLIDEModel
|
||||
from .unet_ldm import UNetLDMModel
|
||||
from .clip_text_transformer import CLIPTextModel
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -57,14 +57,13 @@ class DiffusionPipeline(ConfigMixin):
|
|||
class_name = module.__class__.__name__
|
||||
|
||||
register_dict = {name: (library, class_name)}
|
||||
|
||||
|
||||
# save model index config
|
||||
self.register(**register_dict)
|
||||
|
||||
# set models
|
||||
setattr(self, name, module)
|
||||
|
||||
|
||||
register_dict = {"_module" : self.__module__.split(".")[-1] + ".py"}
|
||||
self.register(**register_dict)
|
||||
|
||||
|
@ -103,15 +102,17 @@ class DiffusionPipeline(ConfigMixin):
|
|||
cached_folder = pretrained_model_name_or_path
|
||||
|
||||
config_dict = cls.get_config_dict(cached_folder)
|
||||
|
||||
module = config_dict["_module"]
|
||||
class_name_ = config_dict["_class_name"]
|
||||
|
||||
if class_name_ == cls.__name__:
|
||||
|
||||
module_candidate = config_dict["_module"]
|
||||
|
||||
# if we load from explicit class, let's use it
|
||||
if cls != DiffusionPipeline:
|
||||
pipeline_class = cls
|
||||
else:
|
||||
# else we need to load the correct module from the Hub
|
||||
class_name_ = config_dict["_class_name"]
|
||||
module = module_candidate
|
||||
pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
|
||||
|
||||
|
||||
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
|
||||
|
@ -120,8 +121,9 @@ class DiffusionPipeline(ConfigMixin):
|
|||
for name, (library_name, class_name) in init_dict.items():
|
||||
importable_classes = LOADABLE_CLASSES[library_name]
|
||||
|
||||
if library_name == module:
|
||||
if library_name == module_candidate:
|
||||
# TODO(Suraj)
|
||||
# for vq
|
||||
pass
|
||||
|
||||
library = importlib.import_module(library_name)
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
import math
|
||||
from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar
|
||||
|
||||
|
||||
SAMPLING_CONFIG_NAME = "scheduler_config.json"
|
||||
|
||||
|
||||
class GaussianDDPMScheduler(nn.Module, ConfigMixin):
    """Noise schedule holder for Gaussian DDPM sampling.

    Builds a beta schedule and registers ``betas``, ``alphas``,
    ``alphas_cumprod`` and ``log_variance`` as float32 buffers.
    """

    config_name = SAMPLING_CONFIG_NAME

    def __init__(
        self,
        timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        variance_type="fixed_small",
    ):
        """Create the schedule.

        :param timesteps: number of diffusion steps.
        :param beta_start: first beta of the linear schedule.
        :param beta_end: last beta of the linear schedule.
        :param beta_schedule: ``"linear"`` or ``"squaredcos_cap_v2"``.
        :param variance_type: ``"fixed_small"`` or ``"fixed_large"``.
        :raises NotImplementedError: for an unknown ``beta_schedule`` or
                                     ``variance_type``.
        """
        super().__init__()
        self.register(
            timesteps=timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            variance_type=variance_type,
        )
        self.num_timesteps = int(timesteps)

        if beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
        elif beta_schedule == "squaredcos_cap_v2":
            # GLIDE cosine schedule
            betas = betas_for_alpha_bar(
                timesteps,
                lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
            )
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Cumulative product shifted right by one step, with 1.0 as the first entry.
        alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

        if variance_type == "fixed_small":
            # variance[0] is exactly 0 (alphas_cumprod_prev[0] == 1), so clamp before the log.
            log_variance = torch.log(variance.clamp(min=1e-20))
        elif variance_type == "fixed_large":
            # Use betas, but replace the first entry with variance[1].
            log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
        else:
            # Fail here with a clear message rather than on an unbound local below.
            raise NotImplementedError(f"{variance_type} is not implemented for {self.__class__}")

        self.register_buffer("betas", betas.to(torch.float32))
        self.register_buffer("alphas", alphas.to(torch.float32))
        self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))

        self.register_buffer("log_variance", log_variance.to(torch.float32))

    def get_alpha(self, time_step):
        """Return ``alphas[time_step]``."""
        return self.alphas[time_step]

    def get_beta(self, time_step):
        """Return ``betas[time_step]``."""
        return self.betas[time_step]

    def get_alpha_prod(self, time_step):
        """Return ``alphas_cumprod[time_step]``, or 1.0 for a negative ``time_step``."""
        if time_step < 0:
            # new_tensor puts the constant on the same device/dtype as the buffer.
            return self.alphas_cumprod.new_tensor(1.0)
        return self.alphas_cumprod[time_step]

    def sample_variance(self, time_step, shape, device, generator=None):
        """Return noise of ``shape`` scaled by the standard deviation at ``time_step``.

        The result is all zeros when ``time_step == 0``.

        :param time_step: timestep index (presumably a Python int — TODO confirm).
        :param shape: shape of the noise to draw.
        :param device: device for the returned tensor.
        :param generator: optional ``torch.Generator`` for the noise.
        """
        log_variance = self.log_variance[time_step]
        nonzero_mask = torch.tensor([1 - (time_step == 0)], device=device).float()[None, :]

        noise = self.sample_noise(shape, device=device, generator=generator)

        # exp(0.5 * log_variance) is the standard deviation.
        std = nonzero_mask * (0.5 * log_variance).exp()
        return std * noise

    def sample_noise(self, shape, device, generator=None):
        """Draw standard normal noise of ``shape`` and move it to ``device``."""
        # always sample on CPU to be deterministic
        return torch.randn(shape, generator=generator).to(device)

    def __len__(self):
        return self.num_timesteps
|
|
@ -16,35 +16,12 @@ import math
|
|||
from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar
|
||||
|
||||
|
||||
SAMPLING_CONFIG_NAME = "scheduler_config.json"
|
||||
|
||||
|
||||
def linear_beta_schedule(timesteps, beta_start, beta_end):
|
||||
return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function,
|
||||
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
||||
|
||||
:param num_diffusion_timesteps: the number of betas to produce.
|
||||
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
||||
produces the cumulative product of (1-beta) up to that
|
||||
part of the diffusion process.
|
||||
:param max_beta: the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
"""
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float64)
|
||||
|
||||
|
||||
class GaussianDDPMScheduler(nn.Module, ConfigMixin):
|
||||
|
||||
config_name = SAMPLING_CONFIG_NAME
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
|
||||
def linear_beta_schedule(timesteps, beta_start, beta_end, dtype=torch.float64):
    """
    Create a beta schedule that increases linearly from beta_start to beta_end.

    :param timesteps: the number of betas to produce.
    :param beta_start: the first beta.
    :param beta_end: the last beta (inclusive).
    :param dtype: dtype of the returned tensor; defaults to float64.
    :return: a 1-D tensor of `timesteps` evenly spaced betas.
    """
    return torch.linspace(beta_start, beta_end, timesteps, dtype=dtype)
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    :return: a 1-D float64 tensor with `num_diffusion_timesteps` betas.
    """
    steps = num_diffusion_timesteps
    # Beta for step i is one minus the ratio of alpha_bar at the right and left
    # edge of the interval [i / steps, (i + 1) / steps], capped at max_beta.
    betas = [
        min(1 - alpha_bar((i + 1) / steps) / alpha_bar(i / steps), max_beta)
        for i in range(steps)
    ]
    return torch.tensor(betas, dtype=torch.float64)
|
|
@ -25,6 +25,7 @@ import torch
|
|||
from diffusers import GaussianDDPMScheduler, UNetModel
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from models.vision.ddpm.modeling_ddpm import DDPM
|
||||
from models.vision.ddim.modeling_ddim import DDIM
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
@ -205,6 +206,7 @@ class SamplerTesterMixin(unittest.TestCase):
|
|||
|
||||
|
||||
class PipelineTesterMixin(unittest.TestCase):
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
# 1. Load models
|
||||
model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
|
||||
|
@ -241,3 +243,31 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
new_image = ddpm_from_hub(generator=generator)
|
||||
|
||||
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
|
||||
|
||||
@slow
def test_ddpm_cifar10(self):
    """Sample from the pretrained CIFAR-10 DDPM pipeline and compare against reference values."""
    generator = torch.manual_seed(0)

    pipeline = DDPM.from_pretrained("fusing/ddpm-cifar10")
    image = pipeline(generator=generator)

    # Bottom-right 3x3 patch of the last channel of the first sample.
    corner = image[0, -1, -3:, -3:].cpu()

    assert image.shape == (1, 3, 32, 32)
    reference = torch.tensor([0.2250, 0.3375, 0.2360, 0.0930, 0.3440, 0.3156, 0.1937, 0.3585, 0.1761])
    max_abs_diff = (corner.flatten() - reference).abs().max()
    assert max_abs_diff < 1e-2
|
||||
|
||||
@slow
def test_ddim_cifar10(self):
    """Sample from the pretrained CIFAR-10 checkpoint with DDIM (eta=0) and compare against reference values."""
    generator = torch.manual_seed(0)

    pipeline = DDIM.from_pretrained("fusing/ddpm-cifar10")
    image = pipeline(generator=generator, eta=0.0)

    # Bottom-right 3x3 patch of the last channel of the first sample.
    corner = image[0, -1, -3:, -3:].cpu()

    assert image.shape == (1, 3, 32, 32)
    reference = torch.tensor([-0.7688, -0.7690, -0.7597, -0.7660, -0.7713, -0.7531, -0.7009, -0.7098, -0.7350])
    max_abs_diff = (corner.flatten() - reference).abs().max()
    assert max_abs_diff < 1e-2
|
||||
|
|
Loading…
Reference in New Issue