finish first version sde ve

This commit is contained in:
Patrick von Platen 2022-06-25 02:50:42 +00:00
parent bc2d586dcb
commit de810814da
6 changed files with 180 additions and 18 deletions

View File

@ -10,7 +10,7 @@ from .modeling_utils import ModelMixin
from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel
from .pipeline_utils import DiffusionPipeline
from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline
from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin, VeSdeScheduler
if is_transformers_available():

View File

@ -15,10 +15,6 @@
# helpers functions
from ..modeling_utils import ModelMixin
from ..configuration_utils import ConfigMixin
import functools
import math
import string
@ -28,16 +24,15 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
return upfirdn2d_native(
input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
)
return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
def upfirdn2d_native(
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
_, channel, in_h, in_w = input.shape
input = input.reshape(-1, in_h, in_w, 1)
@ -48,9 +43,7 @@ def upfirdn2d_native(
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(
out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
)
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out[
:,
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
@ -59,9 +52,7 @@ def upfirdn2d_native(
]
out = out.permute(0, 3, 1, 2)
out = out.reshape(
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
)
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(
@ -350,7 +341,7 @@ conv3x3 = ddpm_conv3x3
def _einsum(a, b, c, x, y):
einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c))
einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
return torch.einsum(einsum_str, x, y)

View File

@ -5,6 +5,9 @@ from .pipeline_ddpm import DDPMPipeline
from .pipeline_pndm import PNDMPipeline
# from .pipeline_score_sde import NCSNppPipeline
if is_transformers_available():
from .pipeline_glide import GlidePipeline
from .pipeline_latent_diffusion import LatentDiffusionPipeline

View File

@ -0,0 +1,94 @@
#!/usr/bin/env python3
import numpy as np
import torch
import PIL
from diffusers import DiffusionPipeline
# from configs.ve import ffhq_ncsnpp_continuous as configs
# from configs.ve import cifar10_ncsnpp_continuous as configs
# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
# Note usually we need to restore ema etc...
# ema restored checkpoint used from below
torch.backends.cuda.matmul.allow_tf32 = False
torch.manual_seed(0)
class NCSNppPipeline(DiffusionPipeline):
def __init__(self, model, scheduler):
super().__init__()
self.register_modules(model=model, scheduler=scheduler)
def __call__(self, generator=None):
N = self.scheduler.config.N
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
img_size = self.model.config.image_size
channels = self.model.config.num_channels
shape = (1, channels, img_size, img_size)
model = torch.nn.DataParallel(self.model.to(device))
centered = False
n_steps = 1
# Initial sample
x = torch.randn(*shape) * self.scheduler.config.sigma_max
x = x.to(device)
for i in range(N):
sigma_t = self.scheduler.get_sigma_t(i) * torch.ones(shape[0], device=device)
for _ in range(n_steps):
with torch.no_grad():
result = model(x, sigma_t)
x = self.scheduler.step_correct(result, x)
with torch.no_grad():
result = model(x, sigma_t)
x, x_mean = self.scheduler.step_pred(result, x, i)
x = x_mean
if centered:
x = (x + 1.0) / 2.0
return x
pipeline = NCSNppPipeline.from_pretrained("/home/patrick/ffhq_ncsnpp")
x = pipeline()
# for 5 cifar10
# x_sum = 106071.9922
# x_mean = 34.52864456176758
# for 1000 cifar10
# x_sum = 461.9700
# x_mean = 0.1504
# for N=2 for 1024
x_sum = 3382810112.0
x_mean = 1075.366455078125
def check_x_sum_x_mean(x, x_sum, x_mean):
assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
check_x_sum_x_mean(x, x_sum, x_mean)
def save_image(x):
image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("../images/hey.png")
# save_image(x)

View File

@ -21,3 +21,4 @@ from .scheduling_ddpm import DDPMScheduler
from .scheduling_grad_tts import GradTTSScheduler
from .scheduling_pndm import PNDMScheduler
from .scheduling_utils import SchedulerMixin
from .scheduling_ve_sde import VeSdeScheduler

View File

@ -0,0 +1,73 @@
# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import numpy as np
import torch
from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin
class VeSdeScheduler(SchedulerMixin, ConfigMixin):
def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, N=2, sampling_eps=1e-5, tensor_format="np"):
super().__init__()
self.register_to_config(
snr=snr,
sigma_min=sigma_min,
sigma_max=sigma_max,
N=N,
sampling_eps=sampling_eps,
)
# (PVP) - clean up with .config.
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.snr = snr
self.N = N
self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
self.timesteps = torch.linspace(1, sampling_eps, N)
def get_sigma_t(self, t):
return self.sigma_min * (self.sigma_max / self.sigma_min) ** self.timesteps[t]
def step_pred(self, result, x, t):
t = self.timesteps[t] * torch.ones(x.shape[0], device=x.device)
timestep = (t * (self.N - 1)).long()
sigma = self.discrete_sigmas.to(t.device)[timestep]
adjacent_sigma = torch.where(
timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(t.device)
)
f = torch.zeros_like(x)
G = torch.sqrt(sigma**2 - adjacent_sigma**2)
f = f - G[:, None, None, None] ** 2 * result
z = torch.randn_like(x)
x_mean = x - f
x = x_mean + G[:, None, None, None] * z
return x, x_mean
def step_correct(self, result, x):
noise = torch.randn_like(x)
grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean()
noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
step_size = (self.snr * noise_norm / grad_norm) ** 2 * 2
step_size = step_size * torch.ones(x.shape[0], device=x.device)
x_mean = x + step_size[:, None, None, None] * result
x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise
return x