clean up sde ve more

Patrick von Platen 2022-06-25 18:25:43 +00:00
parent de810814da
commit 433cb3f801
7 changed files with 120 additions and 58 deletions


@ -226,6 +226,30 @@ image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("test.png")
```
#### **Example 1024x1024 image generation with SDE VE**
See [paper](https://arxiv.org/abs/2011.13456) for more information on SDE VE.
```python
from diffusers import DiffusionPipeline
import torch
import numpy as np
import PIL.Image
torch.manual_seed(32)
score_sde_ve = DiffusionPipeline.from_pretrained("fusing/ffhq_ncsnpp")
# Note this might take up to 3 minutes on a GPU
image = score_sde_ve(num_inference_steps=2000)
image = image.permute(0, 2, 3, 1).cpu().numpy()
image = np.clip(image * 255, 0, 255).astype(np.uint8)
image_pil = PIL.Image.fromarray(image[0])
# save image
image_pil.save("test.png")
```
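For reference, the `ScoreSdeVePipeline` added in this commit implements the predictor-corrector sampler from the paper above. Below is a rough sketch of that loop using the scheduler API introduced here (`set_timesteps`, `set_sigmas`, `step_correct`, `step_pred`); variable names such as `sample` and `score` are chosen for illustration, and device handling, seeding, and image post-processing are omitted.
```python
import torch
from diffusers import NCSNpp, ScoreSdeVeScheduler

# load the score model and the VE SDE scheduler separately
model = NCSNpp.from_pretrained("fusing/ffhq_ncsnpp")
scheduler = ScoreSdeVeScheduler.from_config("fusing/ffhq_ncsnpp")

num_inference_steps = 2000
scheduler.set_timesteps(num_inference_steps)
scheduler.set_sigmas(num_inference_steps)

# start from pure noise scaled to sigma_max
shape = (1, model.config.num_channels, model.config.image_size, model.config.image_size)
sample = torch.randn(*shape) * scheduler.config.sigma_max

for i, t in enumerate(scheduler.timesteps):
    sigma_t = scheduler.sigmas[i] * torch.ones(shape[0])

    # corrector step (Langevin dynamics)
    with torch.no_grad():
        score = model(sample, sigma_t)
    sample = scheduler.step_correct(score, sample)

    # predictor step (reverse diffusion)
    with torch.no_grad():
        score = model(sample, sigma_t)
    sample, sample_mean = scheduler.step_pred(score, sample, t)

# the denoised mean of the final step is the generated image
image = sample_mean
```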
#### **Text to Image generation with Latent Diffusion**
_Note: To use latent diffusion install transformers from [this branch](https://github.com/patil-suraj/transformers/tree/ldm-bert)._
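If you are unsure how to install from that branch, a likely command (assuming a standard git-based pip install, not verified here) is `pip install git+https://github.com/patil-suraj/transformers.git@ldm-bert`.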


@ -9,8 +9,15 @@ __version__ = "0.0.4"
from .modeling_utils import ModelMixin
from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel
from .pipeline_utils import DiffusionPipeline
from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline
from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin, VeSdeScheduler
from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline, ScoreSdeVePipeline
from .schedulers import (
DDIMScheduler,
DDPMScheduler,
GradTTSScheduler,
PNDMScheduler,
SchedulerMixin,
ScoreSdeVeScheduler,
)
if is_transformers_available():


@ -3,9 +3,10 @@ from .pipeline_bddm import BDDMPipeline
from .pipeline_ddim import DDIMPipeline
from .pipeline_ddpm import DDPMPipeline
from .pipeline_pndm import PNDMPipeline
from .pipeline_score_sde_ve import ScoreSdeVePipeline
# from .pipeline_score_sde import NCSNppPipeline
# from .pipeline_score_sde import ScoreSdeVePipeline
if is_transformers_available():


@ -6,51 +6,44 @@ import PIL
from diffusers import DiffusionPipeline
# from configs.ve import ffhq_ncsnpp_continuous as configs
# from configs.ve import cifar10_ncsnpp_continuous as configs
# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
# Note usually we need to restore ema etc...
# ema restored checkpoint used from below
torch.backends.cuda.matmul.allow_tf32 = False
torch.manual_seed(0)
# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
class NCSNppPipeline(DiffusionPipeline):
class ScoreSdeVePipeline(DiffusionPipeline):
def __init__(self, model, scheduler):
super().__init__()
self.register_modules(model=model, scheduler=scheduler)
def __call__(self, generator=None):
N = self.scheduler.config.N
def __call__(self, num_inference_steps=2000, generator=None):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
img_size = self.model.config.image_size
channels = self.model.config.num_channels
shape = (1, channels, img_size, img_size)
model = torch.nn.DataParallel(self.model.to(device))
model = self.model.to(device)
centered = False
n_steps = 1
# Initial sample
x = torch.randn(*shape) * self.scheduler.config.sigma_max
x = x.to(device)
for i in range(N):
sigma_t = self.scheduler.get_sigma_t(i) * torch.ones(shape[0], device=device)
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.set_sigmas(num_inference_steps)
for i, t in enumerate(self.scheduler.timesteps):
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)
for _ in range(n_steps):
with torch.no_grad():
result = model(x, sigma_t)
result = self.model(x, sigma_t)
x = self.scheduler.step_correct(result, x)
with torch.no_grad():
result = model(x, sigma_t)
x, x_mean = self.scheduler.step_pred(result, x, i)
x, x_mean = self.scheduler.step_pred(result, x, t)
x = x_mean
@ -60,9 +53,16 @@ class NCSNppPipeline(DiffusionPipeline):
return x
pipeline = NCSNppPipeline.from_pretrained("/home/patrick/ffhq_ncsnpp")
x = pipeline()
# from configs.ve import ffhq_ncsnpp_continuous as configs
# from configs.ve import cifar10_ncsnpp_continuous as configs
# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
# Note usually we need to restore ema etc...
# ema restored checkpoint used from below
# pipeline = ScoreSdeVePipeline.from_pretrained("/home/patrick/ffhq_ncsnpp")
# x = pipeline(num_inference_steps=2)
# for 5 cifar10
# x_sum = 106071.9922
@ -73,22 +73,22 @@ x = pipeline()
# x_mean = 0.1504
# for N=2 for 1024
x_sum = 3382810112.0
x_mean = 1075.366455078125
def check_x_sum_x_mean(x, x_sum, x_mean):
assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
check_x_sum_x_mean(x, x_sum, x_mean)
def save_image(x):
image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("../images/hey.png")
# x_sum = 3382810112.0
# x_mean = 1075.366455078125
#
#
# def check_x_sum_x_mean(x, x_sum, x_mean):
# assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
# assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
#
#
# check_x_sum_x_mean(x, x_sum, x_mean)
#
#
# def save_image(x):
# image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
# image_pil = PIL.Image.fromarray(image_processed[0])
# image_pil.save("../images/hey.png")
#
#
# save_image(x)


@ -21,4 +21,4 @@ from .scheduling_ddpm import DDPMScheduler
from .scheduling_grad_tts import GradTTSScheduler
from .scheduling_pndm import PNDMScheduler
from .scheduling_utils import SchedulerMixin
from .scheduling_ve_sde import VeSdeScheduler
from .scheduling_sde_ve import ScoreSdeVeScheduler


@ -1,4 +1,4 @@
# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved.
# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean up a bit
import numpy as np
import torch
@ -21,7 +23,7 @@ from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin
class VeSdeScheduler(SchedulerMixin, ConfigMixin):
class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, N=2, sampling_eps=1e-5, tensor_format="np"):
super().__init__()
self.register_to_config(
@ -31,24 +33,32 @@ class VeSdeScheduler(SchedulerMixin, ConfigMixin):
N=N,
sampling_eps=sampling_eps,
)
# (PVP) - clean up with .config.
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.snr = snr
self.N = N
self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
self.timesteps = torch.linspace(1, sampling_eps, N)
def get_sigma_t(self, t):
return self.sigma_min * (self.sigma_max / self.sigma_min) ** self.timesteps[t]
self.sigmas = None
self.discrete_sigmas = None
self.timesteps = None
def set_timesteps(self, num_inference_steps):
self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
def set_sigmas(self, num_inference_steps):
if self.timesteps is None:
self.set_timesteps(num_inference_steps)
self.discrete_sigmas = torch.exp(
torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps)
)
self.sigmas = torch.tensor(
[self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps]
)
def step_pred(self, result, x, t):
t = self.timesteps[t] * torch.ones(x.shape[0], device=x.device)
t = t * torch.ones(x.shape[0], device=x.device)
timestep = (t * (2 - 1)).long()
timestep = (t * (self.N - 1)).long()
sigma = self.discrete_sigmas.to(t.device)[timestep]
adjacent_sigma = torch.where(
timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(t.device)
timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(timestep.device)
)
f = torch.zeros_like(x)
G = torch.sqrt(sigma**2 - adjacent_sigma**2)
@ -64,7 +74,7 @@ class VeSdeScheduler(SchedulerMixin, ConfigMixin):
noise = torch.randn_like(x)
grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean()
noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
step_size = (self.snr * noise_norm / grad_norm) ** 2 * 2
step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
step_size = step_size * torch.ones(x.shape[0], device=x.device)
x_mean = x + step_size[:, None, None, None] * result


@ -33,8 +33,11 @@ from diffusers import (
GradTTSPipeline,
GradTTSScheduler,
LatentDiffusionPipeline,
NCSNpp,
PNDMPipeline,
PNDMScheduler,
ScoreSdeVePipeline,
ScoreSdeVeScheduler,
UNetGradTTSModel,
UNetLDMModel,
UNetModel,
@ -721,6 +724,23 @@ class PipelineTesterMixin(unittest.TestCase):
)
assert (mel_spec[0, :3, :3].cpu().flatten() - expected_slice).abs().max() < 1e-2
@slow
def test_score_sde_ve_pipeline(self):
torch.manual_seed(0)
model = NCSNpp.from_pretrained("fusing/ffhq_ncsnpp")
scheduler = ScoreSdeVeScheduler.from_config("fusing/ffhq_ncsnpp")
sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)
image = sde_ve(num_inference_steps=2)
expected_image_sum = 3382810112.0
expected_image_mean = 1075.366455078125
assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4
def test_module_from_pipeline(self):
model = DiffWave(num_res_layers=4)
noise_scheduler = DDPMScheduler(timesteps=12)