port first 1024 model

This commit is contained in:
Patrick von Platen 2022-06-24 19:44:17 +00:00
parent 78e99a997b
commit 49a81f9f1a
1 changed files with 150 additions and 200 deletions

350
run.py
View File

@ -1,104 +1,128 @@
#!/usr/bin/env python3
import numpy as np
import PIL
import functools
import models
from models import utils as mutils
from models import ncsnv2
from models import ncsnpp
from models import ddpm as ddpm_model
from models import layerspp
from models import layers
from models import normalization
from models.ema import ExponentialMovingAverage
from losses import get_optimizer
from utils import restore_checkpoint
import sampling
from sde_lib import VESDE, VPSDE, subVPSDE
from sampling import (NoneCorrector,
ReverseDiffusionPredictor,
LangevinCorrector,
EulerMaruyamaPredictor,
AncestralSamplingPredictor,
NonePredictor,
AnnealedLangevinDynamics)
import datasets
import torch
import ml_collections
#from configs.ve import ffhq_ncsnpp_continuous as configs
# from configs.ve import cifar10_ncsnpp_continuous as configs
# ffhq_ncsnpp_continuous config
def get_config():
config = ml_collections.ConfigDict()
# training
config.training = training = ml_collections.ConfigDict()
training.batch_size = 8
training.n_iters = 2400001
training.snapshot_freq = 50000
training.log_freq = 50
training.eval_freq = 100
training.snapshot_freq_for_preemption = 5000
training.snapshot_sampling = True
training.sde = 'vesde'
training.continuous = True
training.likelihood_weighting = False
training.reduce_mean = True
# sampling
config.sampling = sampling = ml_collections.ConfigDict()
sampling.method = 'pc'
sampling.predictor = 'reverse_diffusion'
sampling.corrector = 'langevin'
sampling.probability_flow = False
sampling.snr = 0.15
sampling.n_steps_each = 1
sampling.noise_removal = True
# eval
config.eval = evaluate = ml_collections.ConfigDict()
evaluate.batch_size = 1024
evaluate.num_samples = 50000
evaluate.begin_ckpt = 1
evaluate.end_ckpt = 96
# data
config.data = data = ml_collections.ConfigDict()
data.dataset = 'FFHQ'
data.image_size = 1024
data.centered = False
data.random_flip = True
data.uniform_dequantization = False
data.num_channels = 3
# Plug in your own path to the tfrecords file.
data.tfrecords_path = '/raid/song/ffhq-dataset/ffhq/ffhq-r10.tfrecords'
# model
config.model = model = ml_collections.ConfigDict()
model.name = 'ncsnpp'
model.scale_by_sigma = True
model.sigma_max = 1348
model.num_scales = 2000
model.ema_rate = 0.9999
model.sigma_min = 0.01
model.normalization = 'GroupNorm'
model.nonlinearity = 'swish'
model.nf = 16
model.ch_mult = (1, 2, 4, 8, 16, 32, 32, 32)
model.num_res_blocks = 1
model.attn_resolutions = (16,)
model.dropout = 0.
model.resamp_with_conv = True
model.conditional = True
model.fir = True
model.fir_kernel = [1, 3, 3, 1]
model.skip_rescale = True
model.resblock_type = 'biggan'
model.progressive = 'output_skip'
model.progressive_input = 'input_skip'
model.progressive_combine = 'sum'
model.attention_type = 'ddpm'
model.init_scale = 0.
model.fourier_scale = 16
model.conv_size = 3
model.embedding_type = 'fourier'
# optim
config.optim = optim = ml_collections.ConfigDict()
optim.weight_decay = 0
optim.optimizer = 'Adam'
optim.lr = 2e-4
optim.beta1 = 0.9
optim.amsgrad = False
optim.eps = 1e-8
optim.warmup = 5000
optim.grad_clip = 1.
config.seed = 42
config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
return config
torch.backends.cuda.matmul.allow_tf32 = False
torch.manual_seed(0)
#class NewVESDE(SDE):
# def __init__(self, sigma_min=0.01, sigma_max=50, N=1000):
# """Construct a Variance Exploding SDE.
#
# Args:
# sigma_min: smallest sigma.
# sigma_max: largest sigma.
# N: number of discretization steps
# """
# super().__init__(N)
# self.sigma_min = sigma_min
# self.sigma_max = sigma_max
# self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
# self.N = N
#
# @property
# def T(self):
# return 1
#
# def sde(self, x, t):
# sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
# drift = torch.zeros_like(x)
# diffusion = sigma * torch.sqrt(torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)),
# device=t.device))
# return drift, diffusion
#
# def marginal_prob(self, x, t):
# std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
# mean = x
# return mean, std
#
# def prior_sampling(self, shape):
# return torch.randn(*shape) * self.sigma_max
#
# def prior_logp(self, z):
# shape = z.shape
# N = np.prod(shape[1:])
# return -N / 2. * np.log(2 * np.pi * self.sigma_max ** 2) - torch.sum(z ** 2, dim=(1, 2, 3)) / (2 * self.sigma_max ** 2)
#
# def discretize(self, x, t):
# """SMLD(NCSN) discretization."""
# timestep = (t * (self.N - 1) / self.T).long()
# sigma = self.discrete_sigmas.to(t.device)[timestep]
# adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
# self.discrete_sigmas[timestep - 1].to(t.device))
# f = torch.zeros_like(x)
# G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
# return f, G
torch.manual_seed(3)
class NewReverseDiffusionPredictor:
def __init__(self, sde, score_fn, probability_flow=False):
def __init__(self, score_fn, probability_flow=False, sigma_min=0.0, sigma_max=0.0, N=0):
super().__init__()
self.sde = sde
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.N = N
self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
self.probability_flow = probability_flow
self.score_fn = score_fn
def discretize(self, x, t):
timestep = (t * (self.sde.N - 1) / self.sde.T).long()
sigma = self.sde.discrete_sigmas.to(t.device)[timestep]
timestep = (t * (self.N - 1)).long()
sigma = self.discrete_sigmas.to(t.device)[timestep]
adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
self.sde.discrete_sigmas[timestep - 1].to(t.device))
self.discrete_sigmas[timestep - 1].to(t.device))
f = torch.zeros_like(x)
G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
labels = self.sde.marginal_prob(torch.zeros_like(x), t)[1]
labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
result = self.score_fn(x, labels)
rev_f = f - G[:, None, None, None] ** 2 * result * (0.5 if self.probability_flow else 1.)
@ -114,26 +138,27 @@ class NewReverseDiffusionPredictor:
class NewLangevinCorrector:
def __init__(self, sde, score_fn, snr, n_steps):
def __init__(self, score_fn, snr, n_steps, sigma_min=0.0, sigma_max=0.0):
super().__init__()
self.sde = sde
self.score_fn = score_fn
self.snr = snr
self.n_steps = n_steps
self.sigma_min = sigma_min
self.sigma_max = sigma_max
def update_fn(self, x, t):
sde = self.sde
score_fn = self.score_fn
n_steps = self.n_steps
target_snr = self.snr
if isinstance(sde, VPSDE) or isinstance(sde, subVPSDE):
timestep = (t * (sde.N - 1) / sde.T).long()
alpha = sde.alphas.to(t.device)[timestep]
else:
alpha = torch.ones_like(t)
# if isinstance(sde, VPSDE) or isinstance(sde, subVPSDE):
# timestep = (t * (sde.N - 1) / sde.T).long()
# alpha = sde.alphas.to(t.device)[timestep]
# else:
alpha = torch.ones_like(t)
for i in range(n_steps):
labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
grad = score_fn(x, labels)
noise = torch.randn_like(x)
grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
@ -152,64 +177,42 @@ def save_image(x):
image_pil.save("../images/hey.png")
sde = 'VESDE' #@param ['VESDE', 'VPSDE', 'subVPSDE'] {"type": "string"}
if sde.lower() == 'vesde':
# from configs.ve import cifar10_ncsnpp_continuous as configs
# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
from configs.ve import ffhq_ncsnpp_continuous as configs
ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
config = configs.get_config()
config.model.num_scales = 2
sde = VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
sampling_eps = 1e-5
elif sde.lower() == 'vpsde':
from configs.vp import cifar10_ddpmpp_continuous as configs
ckpt_filename = "exp/vp/cifar10_ddpmpp_continuous/checkpoint_8.pth"
config = configs.get_config()
sde = VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
sampling_eps = 1e-3
elif sde.lower() == 'subvpsde':
from configs.subvp import cifar10_ddpmpp_continuous as configs
ckpt_filename = "exp/subvp/cifar10_ddpmpp_continuous/checkpoint_26.pth"
config = configs.get_config()
sde = subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
sampling_eps = 1e-3
#ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
# Note usually we need to restore ema etc...
# ema restored checkpoint used from below
config = get_config()
sigma_min, sigma_max = config.model.sigma_min, config.model.sigma_max
N = config.model.num_scales
sampling_eps = 1e-5
batch_size = 1 #@param {"type":"integer"}
config.training.batch_size = batch_size
config.eval.batch_size = batch_size
random_seed = 0 #@param {"type": "integer"}
#sigmas = mutils.get_sigmas(config)
#scaler = datasets.get_data_scaler(config)
#inverse_scaler = datasets.get_data_inverse_scaler(config)
#score_model = mutils.create_model(config)
#
#optimizer = get_optimizer(config, score_model.parameters())
#ema = ExponentialMovingAverage(score_model.parameters(),
# decay=config.model.ema_rate)
#state = dict(step=0, optimizer=optimizer,
# model=score_model, ema=ema)
#
#state = restore_checkpoint(ckpt_filename, state, config.device)
#ema.copy_to(score_model.parameters())
#score_model = mutils.create_model(config)
from diffusers import NCSNpp
score_model = NCSNpp(config).to(config.device)
score_model = torch.nn.DataParallel(score_model)
model = NCSNpp(config).to(config.device)
model = torch.nn.DataParallel(model)
loaded_state = torch.load("./ffhq_1024_ncsnpp_continuous_ema.pt")
loaded_state = torch.load("../score_sde_pytorch/ffhq_1024_ncsnpp_continuous_ema.pt")
del loaded_state["module.sigmas"]
score_model.load_state_dict(loaded_state, strict=False)
model.load_state_dict(loaded_state, strict=False)
inverse_scaler = datasets.get_data_inverse_scaler(config)
predictor = ReverseDiffusionPredictor #@param ["EulerMaruyamaPredictor", "AncestralSamplingPredictor", "ReverseDiffusionPredictor", "None"] {"type": "raw"}
corrector = LangevinCorrector #@param ["LangevinCorrector", "AnnealedLangevinDynamics", "None"] {"type": "raw"}
def get_data_inverse_scaler(config):
"""Inverse data normalizer."""
if config.data.centered:
# Rescale [-1, 1] to [0, 1]
return lambda x: (x + 1.) / 2.
else:
return lambda x: x
inverse_scaler = get_data_inverse_scaler(config)
#@title PC sampling
img_size = config.data.image_size
channels = config.data.num_channels
shape = (batch_size, channels, img_size, img_size)
@ -218,80 +221,27 @@ snr = 0.15 #@param {"type": "number"}
n_steps = 1#@param {"type": "integer"}
#sampling_fn = sampling.get_pc_sampler(sde, shape, predictor, corrector,
# inverse_scaler, snr, n_steps=n_steps,
# probability_flow=probability_flow,
# continuous=config.training.continuous,
# eps=sampling_eps, device=config.device)
#
#x, n = sampling_fn(score_model)
#save_image(x)
def shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous):
"""A wrapper that configures and returns the update function of predictors."""
score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)
if predictor is None:
# Corrector-only sampler
predictor_obj = NonePredictor(sde, score_fn, probability_flow)
else:
predictor_obj = predictor(sde, score_fn, probability_flow)
return predictor_obj.update_fn(x, t)
def shared_corrector_update_fn(x, t, sde, model, corrector, continuous, snr, n_steps):
"""A wrapper tha configures and returns the update function of correctors."""
score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)
if corrector is None:
# Predictor-only sampler
corrector_obj = NoneCorrector(sde, score_fn, snr, n_steps)
else:
corrector_obj = corrector(sde, score_fn, snr, n_steps)
return corrector_obj.update_fn(x, t)
continuous = config.training.continuous
predictor_update_fn = functools.partial(shared_predictor_update_fn,
sde=sde,
predictor=predictor,
probability_flow=probability_flow,
continuous=continuous)
corrector_update_fn = functools.partial(shared_corrector_update_fn,
sde=sde,
corrector=corrector,
continuous=continuous,
snr=snr,
n_steps=n_steps)
device = config.device
model = score_model
denoise = True
new_corrector = NewLangevinCorrector(sde=sde, score_fn=model, snr=snr, n_steps=n_steps)
new_predictor = NewReverseDiffusionPredictor(sde=sde, score_fn=model)
new_corrector = NewLangevinCorrector(score_fn=model, snr=snr, n_steps=n_steps, sigma_min=sigma_min, sigma_max=sigma_max)
new_predictor = NewReverseDiffusionPredictor(score_fn=model, sigma_min=sigma_min, sigma_max=sigma_max, N=N)
#
with torch.no_grad():
# Initial sample
x = sde.prior_sampling(shape).to(device)
timesteps = torch.linspace(sde.T, sampling_eps, sde.N, device=device)
x = torch.randn(*shape) * sigma_max
x = x.to(device)
timesteps = torch.linspace(1, sampling_eps, N, device=device)
for i in range(sde.N):
for i in range(N):
t = timesteps[i]
vec_t = torch.ones(shape[0], device=t.device) * t
# x, x_mean = corrector_update_fn(x, vec_t, model=model)
# x, x_mean = predictor_update_fn(x, vec_t, model=model)
x, x_mean = new_corrector.update_fn(x, vec_t)
x, x_mean = new_predictor.update_fn(x, vec_t)
x, n = inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1)
x = inverse_scaler(x_mean)
#save_image(x)
save_image(x)
# for 5 cifar10
x_sum = 106071.9922
@ -310,4 +260,4 @@ def check_x_sum_x_mean(x, x_sum, x_mean):
assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
check_x_sum_x_mean(x, x_sum, x_mean)
#check_x_sum_x_mean(x, x_sum, x_mean)