This commit is contained in:
Patrick von Platen 2022-06-27 10:46:18 +02:00
commit 0183bf13c7
11 changed files with 280 additions and 119 deletions

View File

@ -226,6 +226,56 @@ image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("test.png")
```
#### **Example 1024x1024 image generation with SDE VE**
See [paper](https://arxiv.org/abs/2011.13456) for more information on SDE VE.
```python
from diffusers import DiffusionPipeline
import torch
import PIL.Image
import numpy as np
torch.manual_seed(32)
score_sde_sv = DiffusionPipeline.from_pretrained("fusing/ffhq_ncsnpp")
# Note this might take up to 3 minutes on a GPU
image = score_sde_sv(num_inference_steps=2000)
image = image.permute(0, 2, 3, 1).cpu().numpy()
image = np.clip(image * 255, 0, 255).astype(np.uint8)
image_pil = PIL.Image.fromarray(image[0])
# save image
image_pil.save("test.png")
```
#### **Example 32x32 image generation with SDE VP**
See [paper](https://arxiv.org/abs/2011.13456) for more information on SDE VE.
```python
from diffusers import DiffusionPipeline
import torch
import PIL.Image
import numpy as np
torch.manual_seed(32)
score_sde_sv = DiffusionPipeline.from_pretrained("fusing/cifar10-ddpmpp-deep-vp")
# Note this might take up to 3 minutes on a GPU
image = score_sde_sv(num_inference_steps=1000)
image = image.permute(0, 2, 3, 1).cpu().numpy()
image = np.clip(image * 255, 0, 255).astype(np.uint8)
image_pil = PIL.Image.fromarray(image[0])
# save image
image_pil.save("test.png")
```
#### **Text to Image generation with Latent Diffusion**
_Note: To use latent diffusion install transformers from [this branch](https://github.com/patil-suraj/transformers/tree/ldm-bert)._

View File

@ -9,8 +9,16 @@ __version__ = "0.0.4"
from .modeling_utils import ModelMixin
from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel
from .pipeline_utils import DiffusionPipeline
from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline
from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin, VeSdeScheduler
from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline, ScoreSdeVePipeline, ScoreSdeVpPipeline
from .schedulers import (
DDIMScheduler,
DDPMScheduler,
GradTTSScheduler,
PNDMScheduler,
SchedulerMixin,
ScoreSdeVeScheduler,
ScoreSdeVpScheduler,
)
if is_transformers_available():

View File

@ -766,7 +766,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
continuous=continuous,
)
self.act = act = get_act(nonlinearity)
# self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config)))
self.nf = nf
self.num_res_blocks = num_res_blocks
@ -939,7 +938,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
self.all_modules = nn.ModuleList(modules)
def forward(self, x, time_cond):
def forward(self, x, time_cond, sigmas=None):
# timestep/noise_level embedding; only for continuous training
modules = self.all_modules
m_idx = 0
@ -952,7 +951,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
elif self.embedding_type == "positional":
# Sinusoidal positional embeddings.
timesteps = time_cond
used_sigmas = self.sigmas[time_cond.long()]
used_sigmas = sigmas
temb = get_timestep_embedding(timesteps, self.nf)
else:

View File

@ -3,9 +3,11 @@ from .pipeline_bddm import BDDMPipeline
from .pipeline_ddim import DDIMPipeline
from .pipeline_ddpm import DDPMPipeline
from .pipeline_pndm import PNDMPipeline
from .pipeline_score_sde_ve import ScoreSdeVePipeline
from .pipeline_score_sde_vp import ScoreSdeVpPipeline
# from .pipeline_score_sde import NCSNppPipeline
# from .pipeline_score_sde import ScoreSdeVePipeline
if is_transformers_available():

View File

@ -1,94 +0,0 @@
#!/usr/bin/env python3
import numpy as np
import torch
import PIL
from diffusers import DiffusionPipeline
# from configs.ve import ffhq_ncsnpp_continuous as configs
# from configs.ve import cifar10_ncsnpp_continuous as configs
# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
# Note usually we need to restore ema etc...
# ema restored checkpoint used from below
torch.backends.cuda.matmul.allow_tf32 = False
torch.manual_seed(0)
class NCSNppPipeline(DiffusionPipeline):
def __init__(self, model, scheduler):
super().__init__()
self.register_modules(model=model, scheduler=scheduler)
def __call__(self, generator=None):
N = self.scheduler.config.N
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
img_size = self.model.config.image_size
channels = self.model.config.num_channels
shape = (1, channels, img_size, img_size)
model = torch.nn.DataParallel(self.model.to(device))
centered = False
n_steps = 1
# Initial sample
x = torch.randn(*shape) * self.scheduler.config.sigma_max
x = x.to(device)
for i in range(N):
sigma_t = self.scheduler.get_sigma_t(i) * torch.ones(shape[0], device=device)
for _ in range(n_steps):
with torch.no_grad():
result = model(x, sigma_t)
x = self.scheduler.step_correct(result, x)
with torch.no_grad():
result = model(x, sigma_t)
x, x_mean = self.scheduler.step_pred(result, x, i)
x = x_mean
if centered:
x = (x + 1.0) / 2.0
return x
pipeline = NCSNppPipeline.from_pretrained("/home/patrick/ffhq_ncsnpp")
x = pipeline()
# for 5 cifar10
# x_sum = 106071.9922
# x_mean = 34.52864456176758
# for 1000 cifar10
# x_sum = 461.9700
# x_mean = 0.1504
# for N=2 for 1024
x_sum = 3382810112.0
x_mean = 1075.366455078125
def check_x_sum_x_mean(x, x_sum, x_mean):
assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
check_x_sum_x_mean(x, x_sum, x_mean)
def save_image(x):
image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("../images/hey.png")
# save_image(x)

View File

@ -0,0 +1,44 @@
#!/usr/bin/env python3
import torch
from diffusers import DiffusionPipeline
# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
class ScoreSdeVePipeline(DiffusionPipeline):
def __init__(self, model, scheduler):
super().__init__()
self.register_modules(model=model, scheduler=scheduler)
def __call__(self, num_inference_steps=2000, generator=None):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
img_size = self.model.config.image_size
channels = self.model.config.num_channels
shape = (1, channels, img_size, img_size)
model = self.model.to(device)
# TODO(Patrick) move to scheduler config
n_steps = 1
x = torch.randn(*shape) * self.scheduler.config.sigma_max
x = x.to(device)
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.set_sigmas(num_inference_steps)
for i, t in enumerate(self.scheduler.timesteps):
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)
for _ in range(n_steps):
with torch.no_grad():
result = self.model(x, sigma_t)
x = self.scheduler.step_correct(result, x)
with torch.no_grad():
result = model(x, sigma_t)
x, x_mean = self.scheduler.step_pred(result, x, t)
return x_mean

View File

@ -0,0 +1,37 @@
#!/usr/bin/env python3
import torch
from diffusers import DiffusionPipeline
# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
class ScoreSdeVpPipeline(DiffusionPipeline):
def __init__(self, model, scheduler):
super().__init__()
self.register_modules(model=model, scheduler=scheduler)
def __call__(self, num_inference_steps=1000, generator=None):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
img_size = self.model.config.image_size
channels = self.model.config.num_channels
shape = (1, channels, img_size, img_size)
model = self.model.to(device)
x = torch.randn(*shape).to(device)
self.scheduler.set_timesteps(num_inference_steps)
for t in self.scheduler.timesteps:
t = t * torch.ones(shape[0], device=device)
scaled_t = t * (num_inference_steps - 1)
with torch.no_grad():
result = model(x, scaled_t)
x, x_mean = self.scheduler.step_pred(result, x, t)
x_mean = (x_mean + 1.0) / 2.0
return x_mean

View File

@ -20,5 +20,6 @@ from .scheduling_ddim import DDIMScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_grad_tts import GradTTSScheduler
from .scheduling_pndm import PNDMScheduler
from .scheduling_sde_ve import ScoreSdeVeScheduler
from .scheduling_sde_vp import ScoreSdeVpScheduler
from .scheduling_utils import SchedulerMixin
from .scheduling_ve_sde import VeSdeScheduler

View File

@ -1,4 +1,4 @@
# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved.
# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit
import numpy as np
import torch
@ -21,34 +23,42 @@ from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin
class VeSdeScheduler(SchedulerMixin, ConfigMixin):
def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, N=2, sampling_eps=1e-5, tensor_format="np"):
class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, sampling_eps=1e-5, tensor_format="np"):
super().__init__()
self.register_to_config(
snr=snr,
sigma_min=sigma_min,
sigma_max=sigma_max,
N=N,
sampling_eps=sampling_eps,
)
# (PVP) - clean up with .config.
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.snr = snr
self.N = N
self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
self.timesteps = torch.linspace(1, sampling_eps, N)
def get_sigma_t(self, t):
return self.sigma_min * (self.sigma_max / self.sigma_min) ** self.timesteps[t]
self.sigmas = None
self.discrete_sigmas = None
self.timesteps = None
def set_timesteps(self, num_inference_steps):
self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
def set_sigmas(self, num_inference_steps):
if self.timesteps is None:
self.set_timesteps(num_inference_steps)
self.discrete_sigmas = torch.exp(
torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps)
)
self.sigmas = torch.tensor(
[self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps]
)
def step_pred(self, result, x, t):
t = self.timesteps[t] * torch.ones(x.shape[0], device=x.device)
# TODO(Patrick) better comments + non-PyTorch
t = t * torch.ones(x.shape[0], device=x.device)
timestep = (t * (len(self.timesteps) - 1)).long()
timestep = (t * (self.N - 1)).long()
sigma = self.discrete_sigmas.to(t.device)[timestep]
adjacent_sigma = torch.where(
timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(t.device)
timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(timestep.device)
)
f = torch.zeros_like(x)
G = torch.sqrt(sigma**2 - adjacent_sigma**2)
@ -61,10 +71,11 @@ class VeSdeScheduler(SchedulerMixin, ConfigMixin):
return x, x_mean
def step_correct(self, result, x):
# TODO(Patrick) better comments + non-PyTorch
noise = torch.randn_like(x)
grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean()
noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
step_size = (self.snr * noise_norm / grad_norm) ** 2 * 2
step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
step_size = step_size * torch.ones(x.shape[0], device=x.device)
x_mean = x + step_size[:, None, None, None] * result

View File

@ -0,0 +1,64 @@
# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit
import numpy as np
import torch
from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin
class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
def __init__(self, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"):
super().__init__()
self.register_to_config(
beta_min=beta_min,
beta_max=beta_max,
sampling_eps=sampling_eps,
)
self.sigmas = None
self.discrete_sigmas = None
self.timesteps = None
def set_timesteps(self, num_inference_steps):
self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
def step_pred(self, result, x, t):
# TODO(Patrick) better comments + non-PyTorch
# postprocess model result
log_mean_coeff = (
-0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
)
std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
result = -result / std[:, None, None, None]
# compute
dt = -1.0 / len(self.timesteps)
beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min)
drift = -0.5 * beta_t[:, None, None, None] * x
diffusion = torch.sqrt(beta_t)
drift = drift - diffusion[:, None, None, None] ** 2 * result
x_mean = x + drift * dt
# add noise
z = torch.randn_like(x)
x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * z
return x, x_mean

View File

@ -33,8 +33,13 @@ from diffusers import (
GradTTSPipeline,
GradTTSScheduler,
LatentDiffusionPipeline,
NCSNpp,
PNDMPipeline,
PNDMScheduler,
ScoreSdeVePipeline,
ScoreSdeVeScheduler,
ScoreSdeVpPipeline,
ScoreSdeVpScheduler,
UNetGradTTSModel,
UNetLDMModel,
UNetModel,
@ -721,6 +726,40 @@ class PipelineTesterMixin(unittest.TestCase):
)
assert (mel_spec[0, :3, :3].cpu().flatten() - expected_slice).abs().max() < 1e-2
@slow
def test_score_sde_ve_pipeline(self):
torch.manual_seed(0)
model = NCSNpp.from_pretrained("fusing/ffhq_ncsnpp")
scheduler = ScoreSdeVeScheduler.from_config("fusing/ffhq_ncsnpp")
sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)
image = sde_ve(num_inference_steps=2)
expected_image_sum = 3382810112.0
expected_image_mean = 1075.366455078125
assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4
@slow
def test_score_sde_vp_pipeline(self):
model = NCSNpp.from_pretrained("fusing/cifar10-ddpmpp-vp")
scheduler = ScoreSdeVpScheduler.from_config("fusing/cifar10-ddpmpp-vp")
sde_vp = ScoreSdeVpPipeline(model=model, scheduler=scheduler)
torch.manual_seed(0)
image = sde_vp(num_inference_steps=10)
expected_image_sum = 4183.2012
expected_image_mean = 1.3617
assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4
def test_module_from_pipeline(self):
model = DiffWave(num_res_layers=4)
noise_scheduler = DDPMScheduler(timesteps=12)