finish first version sde ve
This commit is contained in:
parent
bc2d586dcb
commit
de810814da
|
@ -10,7 +10,7 @@ from .modeling_utils import ModelMixin
|
|||
from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel
|
||||
from .pipeline_utils import DiffusionPipeline
|
||||
from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline
|
||||
from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
|
||||
from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin, VeSdeScheduler
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
|
|
|
@ -15,10 +15,6 @@
|
|||
|
||||
# helper functions
|
||||
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..configuration_utils import ConfigMixin
|
||||
|
||||
|
||||
import functools
|
||||
import math
|
||||
import string
|
||||
|
@ -28,16 +24,15 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
|
||||
|
||||
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    """Call `upfirdn2d_native` with identical settings for the x and y axes.

    `up` and `down` are each forwarded as both the x and the y factor, and
    `pad[0]` / `pad[1]` are forwarded as the first / second padding amount of
    both axes.
    """
    pad_first, pad_second = pad[0], pad[1]
    return upfirdn2d_native(
        input,
        kernel,
        up_x=up,
        up_y=up,
        down_x=down,
        down_y=down,
        pad_x0=pad_first,
        pad_x1=pad_second,
        pad_y0=pad_first,
        pad_y1=pad_second,
    )
|
||||
|
||||
|
||||
def upfirdn2d_native(
|
||||
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
|
||||
):
|
||||
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
|
||||
_, channel, in_h, in_w = input.shape
|
||||
input = input.reshape(-1, in_h, in_w, 1)
|
||||
|
||||
|
@ -48,9 +43,7 @@ def upfirdn2d_native(
|
|||
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
||||
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
|
||||
|
||||
out = F.pad(
|
||||
out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
|
||||
)
|
||||
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
|
||||
out = out[
|
||||
:,
|
||||
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
|
||||
|
@ -59,9 +52,7 @@ def upfirdn2d_native(
|
|||
]
|
||||
|
||||
out = out.permute(0, 3, 1, 2)
|
||||
out = out.reshape(
|
||||
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
|
||||
)
|
||||
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
|
||||
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
||||
out = F.conv2d(out, w)
|
||||
out = out.reshape(
|
||||
|
@ -350,7 +341,7 @@ conv3x3 = ddpm_conv3x3
|
|||
|
||||
|
||||
def _einsum(a, b, c, x, y):
|
||||
einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c))
|
||||
einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
|
||||
return torch.einsum(einsum_str, x, y)
|
||||
|
||||
|
||||
|
|
|
@ -5,6 +5,9 @@ from .pipeline_ddpm import DDPMPipeline
|
|||
from .pipeline_pndm import PNDMPipeline
|
||||
|
||||
|
||||
# from .pipeline_score_sde import NCSNppPipeline
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from .pipeline_glide import GlidePipeline
|
||||
from .pipeline_latent_diffusion import LatentDiffusionPipeline
|
||||
|
|
|
@ -0,0 +1,94 @@
|
|||
#!/usr/bin/env python3
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import PIL
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
|
||||
# from configs.ve import ffhq_ncsnpp_continuous as configs
|
||||
# from configs.ve import cifar10_ncsnpp_continuous as configs
|
||||
|
||||
# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
|
||||
# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
|
||||
# Note usually we need to restore ema etc...
|
||||
# ema restored checkpoint used from below
|
||||
# Turn off TF32 matmuls and fix the global RNG seed
# (presumably so the reference statistics checked further down are reproducible - confirm).
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
class NCSNppPipeline(DiffusionPipeline):
    """Sampling pipeline that drives a model with a scheduler exposing
    `get_sigma_t`, `step_correct` and `step_pred`.
    """

    def __init__(self, model, scheduler):
        super().__init__()
        self.register_modules(model=model, scheduler=scheduler)

    def __call__(self, generator=None):
        """Draw one sample by alternating corrector and predictor updates.

        Args:
            generator: optional `torch.Generator` used for the initial noise
                draw. `step_correct` / `step_pred` are called without it, so
                any noise the scheduler draws is not controlled by this
                argument.

        Returns:
            The mean returned by the last `step_pred` call (the initial noise
            if the scheduler is configured with `N == 0`).
        """
        num_steps = self.scheduler.config.N
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        img_size = self.model.config.image_size
        channels = self.model.config.num_channels
        shape = (1, channels, img_size, img_size)

        model = torch.nn.DataParallel(self.model.to(device))

        centered = False
        n_steps = 1  # corrector updates per predictor update

        # Initial sample. The `generator` argument used to be silently ignored;
        # it now seeds this draw. The noise is created on the generator's own
        # device (default device when no generator is given) and then moved.
        noise_device = None if generator is None else generator.device
        x = torch.randn(*shape, generator=generator, device=noise_device) * self.scheduler.config.sigma_max
        x = x.to(device)

        # Keeps the return value defined even when the loop body never runs.
        x_mean = x

        for i in range(num_steps):
            sigma_t = self.scheduler.get_sigma_t(i) * torch.ones(shape[0], device=device)

            # Corrector update(s) at the current noise level.
            for _ in range(n_steps):
                with torch.no_grad():
                    result = model(x, sigma_t)
                x = self.scheduler.step_correct(result, x)

            # Predictor update towards the next noise level.
            with torch.no_grad():
                result = model(x, sigma_t)

            x, x_mean = self.scheduler.step_pred(result, x, i)

        # Return the mean of the final predictor step rather than its noised sample.
        x = x_mean

        if centered:
            x = (x + 1.0) / 2.0

        return x
|
||||
|
||||
|
||||
# Load the pipeline from a hard-coded local directory and draw one sample.
# NOTE(review): these are top-level statements, so they run whenever this file is executed or imported,
# and the path is machine-specific.
pipeline = NCSNppPipeline.from_pretrained("/home/patrick/ffhq_ncsnpp")
|
||||
x = pipeline()
|
||||
|
||||
|
||||
# for 5 cifar10
|
||||
# x_sum = 106071.9922
|
||||
# x_mean = 34.52864456176758
|
||||
|
||||
# for 1000 cifar10
|
||||
# x_sum = 461.9700
|
||||
# x_mean = 0.1504
|
||||
|
||||
# for N=2 for 1024
|
||||
# Reference values for the absolute sum and absolute mean of the sample `x`;
# they are passed to check_x_sum_x_mean below.
x_sum = 3382810112.0
|
||||
x_mean = 1075.366455078125
|
||||
|
||||
|
||||
def check_x_sum_x_mean(x, x_sum, x_mean):
    """Verify that the absolute sum and absolute mean of `x` match reference values.

    Args:
        x: tensor to check.
        x_sum: expected value of `x.abs().sum()`; must match to within 1e-2.
        x_mean: expected value of `x.abs().mean()`; must match to within 1e-4.

    Raises:
        AssertionError: if either statistic misses its reference by at least
            the corresponding tolerance.
    """
    abs_x = x.abs()
    actual_sum = abs_x.sum()
    actual_mean = abs_x.mean()

    # Raise explicitly instead of using `assert`, so the checks still run under `python -O`.
    # The `not (... < tol)` form also rejects NaN differences, as the original asserts did.
    if not (actual_sum - x_sum).abs().cpu().item() < 1e-2:
        raise AssertionError(f"sum wrong {actual_sum}")
    if not (actual_mean - x_mean).abs().cpu().item() < 1e-4:
        raise AssertionError(f"mean wrong {actual_mean}")
|
||||
|
||||
|
||||
# Compare the generated sample against the reference values; a mismatch raises AssertionError.
check_x_sum_x_mean(x, x_sum, x_mean)
|
||||
|
||||
|
||||
def save_image(x, path="../images/hey.png"):
    """Save the first image of a batch to `path`.

    Args:
        x: tensor that is permuted with `(0, 2, 3, 1)`, scaled by 255, clipped
            to [0, 255] and cast to uint8. Assumes a channel-first 4-D batch
            with values roughly in [0, 1] - TODO confirm against the caller.
        path: output file path; defaults to the location that was previously
            hard-coded.
    """
    # A bare `import PIL` does not load the `PIL.Image` submodule, so import it explicitly here.
    from PIL import Image

    image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
    Image.fromarray(image_processed[0]).save(path)
|
||||
|
||||
|
||||
# save_image(x)
|
|
@ -21,3 +21,4 @@ from .scheduling_ddpm import DDPMScheduler
|
|||
from .scheduling_grad_tts import GradTTSScheduler
|
||||
from .scheduling_pndm import PNDMScheduler
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
from .scheduling_ve_sde import VeSdeScheduler
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
class VeSdeScheduler(SchedulerMixin, ConfigMixin):
    """Scheduler for VE-SDE sampling, offering a predictor step (`step_pred`)
    and a corrector step (`step_correct`).

    The noise levels form a geometric sequence between `sigma_min` and
    `sigma_max`; the continuous times run from 1 down to `sampling_eps` in
    `N` steps.
    """

    def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, N=2, sampling_eps=1e-5, tensor_format="np"):
        """Store the configuration and precompute the sigma and time grids.

        Args:
            snr: ratio used by `step_correct` to size the corrector step.
            sigma_min: smallest noise level.
            sigma_max: largest noise level.
            N: number of discretization steps.
            sampling_eps: final (smallest) continuous time.
            tensor_format: accepted for signature compatibility; not used in this class body.
        """
        super().__init__()
        self.register_to_config(
            snr=snr,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            N=N,
            sampling_eps=sampling_eps,
        )
        # (PVP) - clean up with .config.
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.snr = snr
        self.N = N
        # N sigmas spaced evenly in log-space from sigma_min to sigma_max.
        self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
        # N continuous times going from 1 down to sampling_eps.
        self.timesteps = torch.linspace(1, sampling_eps, N)

    def get_sigma_t(self, t):
        """Return `sigma_min * (sigma_max / sigma_min) ** timesteps[t]` for step index `t`."""
        return self.sigma_min * (self.sigma_max / self.sigma_min) ** self.timesteps[t]

    def step_pred(self, result, x, t):
        """Run one predictor update.

        Args:
            result: model output for `x` (presumably the score estimate - confirm against the model).
            x: current sample; indexed as a 4-D batch.
            t: integer index into `self.timesteps`.

        Returns:
            `(x, x_mean)`: the updated sample with fresh noise added, and the same update without the noise.
        """
        t = self.timesteps[t] * torch.ones(x.shape[0], device=x.device)
        timestep = (t * (self.N - 1)).long()

        # Move the sigma table to the working device once and index only that copy.
        # Previously `adjacent_sigma` indexed the table on its original device with
        # indices living on `x`'s device.
        discrete_sigmas = self.discrete_sigmas.to(t.device)
        sigma = discrete_sigmas[timestep]
        # For timestep == 0 the index -1 wraps to the last entry, but `torch.where`
        # discards that value in favour of zero.
        adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t), discrete_sigmas[timestep - 1])

        # Per-sample coefficient, reshaped to broadcast over the three trailing dims of `x`.
        G = torch.sqrt(sigma**2 - adjacent_sigma**2)[:, None, None, None]

        z = torch.randn_like(x)
        # Equivalent to the former `x - (0 - G**2 * result)`.
        x_mean = x + G**2 * result
        x = x_mean + G * z
        return x, x_mean

    def step_correct(self, result, x):
        """Run one corrector update.

        The step size is `2 * (snr * noise_norm / grad_norm) ** 2`, where the
        norms are per-sample L2 norms of the fresh noise and of `result`,
        averaged over the batch.

        Args:
            result: model output for `x`, used as the update direction.
            x: current sample; indexed as a 4-D batch.

        Returns:
            The corrected sample, including the added noise.
        """
        noise = torch.randn_like(x)
        grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean()
        noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()

        step_size = (self.snr * noise_norm / grad_norm) ** 2 * 2
        # Expand the scalar step size to one entry per batch element.
        step_size = step_size * torch.ones(x.shape[0], device=x.device)

        x_mean = x + step_size[:, None, None, None] * result
        x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise

        return x
|
Loading…
Reference in New Issue