port first 1024 model

This commit is contained in:
Patrick von Platen 2022-06-24 19:44:17 +00:00
parent 78e99a997b
commit 49a81f9f1a
1 changed files with 150 additions and 200 deletions

350
run.py
View File

@ -1,104 +1,128 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import numpy as np import numpy as np
import PIL import PIL
import functools
import models
from models import utils as mutils
from models import ncsnv2
from models import ncsnpp
from models import ddpm as ddpm_model
from models import layerspp
from models import layers
from models import normalization
from models.ema import ExponentialMovingAverage
from losses import get_optimizer
from utils import restore_checkpoint
import sampling
from sde_lib import VESDE, VPSDE, subVPSDE
from sampling import (NoneCorrector,
ReverseDiffusionPredictor,
LangevinCorrector,
EulerMaruyamaPredictor,
AncestralSamplingPredictor,
NonePredictor,
AnnealedLangevinDynamics)
import datasets
import torch import torch
import ml_collections
#from configs.ve import ffhq_ncsnpp_continuous as configs
# from configs.ve import cifar10_ncsnpp_continuous as configs
# ffhq_ncsnpp_continuous config
def get_config():
config = ml_collections.ConfigDict()
# training
config.training = training = ml_collections.ConfigDict()
training.batch_size = 8
training.n_iters = 2400001
training.snapshot_freq = 50000
training.log_freq = 50
training.eval_freq = 100
training.snapshot_freq_for_preemption = 5000
training.snapshot_sampling = True
training.sde = 'vesde'
training.continuous = True
training.likelihood_weighting = False
training.reduce_mean = True
# sampling
config.sampling = sampling = ml_collections.ConfigDict()
sampling.method = 'pc'
sampling.predictor = 'reverse_diffusion'
sampling.corrector = 'langevin'
sampling.probability_flow = False
sampling.snr = 0.15
sampling.n_steps_each = 1
sampling.noise_removal = True
# eval
config.eval = evaluate = ml_collections.ConfigDict()
evaluate.batch_size = 1024
evaluate.num_samples = 50000
evaluate.begin_ckpt = 1
evaluate.end_ckpt = 96
# data
config.data = data = ml_collections.ConfigDict()
data.dataset = 'FFHQ'
data.image_size = 1024
data.centered = False
data.random_flip = True
data.uniform_dequantization = False
data.num_channels = 3
# Plug in your own path to the tfrecords file.
data.tfrecords_path = '/raid/song/ffhq-dataset/ffhq/ffhq-r10.tfrecords'
# model
config.model = model = ml_collections.ConfigDict()
model.name = 'ncsnpp'
model.scale_by_sigma = True
model.sigma_max = 1348
model.num_scales = 2000
model.ema_rate = 0.9999
model.sigma_min = 0.01
model.normalization = 'GroupNorm'
model.nonlinearity = 'swish'
model.nf = 16
model.ch_mult = (1, 2, 4, 8, 16, 32, 32, 32)
model.num_res_blocks = 1
model.attn_resolutions = (16,)
model.dropout = 0.
model.resamp_with_conv = True
model.conditional = True
model.fir = True
model.fir_kernel = [1, 3, 3, 1]
model.skip_rescale = True
model.resblock_type = 'biggan'
model.progressive = 'output_skip'
model.progressive_input = 'input_skip'
model.progressive_combine = 'sum'
model.attention_type = 'ddpm'
model.init_scale = 0.
model.fourier_scale = 16
model.conv_size = 3
model.embedding_type = 'fourier'
# optim
config.optim = optim = ml_collections.ConfigDict()
optim.weight_decay = 0
optim.optimizer = 'Adam'
optim.lr = 2e-4
optim.beta1 = 0.9
optim.amsgrad = False
optim.eps = 1e-8
optim.warmup = 5000
optim.grad_clip = 1.
config.seed = 42
config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
return config
torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False
torch.manual_seed(0) torch.manual_seed(3)
#class NewVESDE(SDE):
# def __init__(self, sigma_min=0.01, sigma_max=50, N=1000):
# """Construct a Variance Exploding SDE.
#
# Args:
# sigma_min: smallest sigma.
# sigma_max: largest sigma.
# N: number of discretization steps
# """
# super().__init__(N)
# self.sigma_min = sigma_min
# self.sigma_max = sigma_max
# self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
# self.N = N
#
# @property
# def T(self):
# return 1
#
# def sde(self, x, t):
# sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
# drift = torch.zeros_like(x)
# diffusion = sigma * torch.sqrt(torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)),
# device=t.device))
# return drift, diffusion
#
# def marginal_prob(self, x, t):
# std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
# mean = x
# return mean, std
#
# def prior_sampling(self, shape):
# return torch.randn(*shape) * self.sigma_max
#
# def prior_logp(self, z):
# shape = z.shape
# N = np.prod(shape[1:])
# return -N / 2. * np.log(2 * np.pi * self.sigma_max ** 2) - torch.sum(z ** 2, dim=(1, 2, 3)) / (2 * self.sigma_max ** 2)
#
# def discretize(self, x, t):
# """SMLD(NCSN) discretization."""
# timestep = (t * (self.N - 1) / self.T).long()
# sigma = self.discrete_sigmas.to(t.device)[timestep]
# adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
# self.discrete_sigmas[timestep - 1].to(t.device))
# f = torch.zeros_like(x)
# G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
# return f, G
class NewReverseDiffusionPredictor: class NewReverseDiffusionPredictor:
def __init__(self, sde, score_fn, probability_flow=False): def __init__(self, score_fn, probability_flow=False, sigma_min=0.0, sigma_max=0.0, N=0):
super().__init__() super().__init__()
self.sde = sde self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.N = N
self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
self.probability_flow = probability_flow self.probability_flow = probability_flow
self.score_fn = score_fn self.score_fn = score_fn
def discretize(self, x, t): def discretize(self, x, t):
timestep = (t * (self.sde.N - 1) / self.sde.T).long() timestep = (t * (self.N - 1)).long()
sigma = self.sde.discrete_sigmas.to(t.device)[timestep] sigma = self.discrete_sigmas.to(t.device)[timestep]
adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t), adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
self.sde.discrete_sigmas[timestep - 1].to(t.device)) self.discrete_sigmas[timestep - 1].to(t.device))
f = torch.zeros_like(x) f = torch.zeros_like(x)
G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2) G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
labels = self.sde.marginal_prob(torch.zeros_like(x), t)[1] labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
result = self.score_fn(x, labels) result = self.score_fn(x, labels)
rev_f = f - G[:, None, None, None] ** 2 * result * (0.5 if self.probability_flow else 1.) rev_f = f - G[:, None, None, None] ** 2 * result * (0.5 if self.probability_flow else 1.)
@ -114,26 +138,27 @@ class NewReverseDiffusionPredictor:
class NewLangevinCorrector: class NewLangevinCorrector:
def __init__(self, sde, score_fn, snr, n_steps): def __init__(self, score_fn, snr, n_steps, sigma_min=0.0, sigma_max=0.0):
super().__init__() super().__init__()
self.sde = sde
self.score_fn = score_fn self.score_fn = score_fn
self.snr = snr self.snr = snr
self.n_steps = n_steps self.n_steps = n_steps
self.sigma_min = sigma_min
self.sigma_max = sigma_max
def update_fn(self, x, t): def update_fn(self, x, t):
sde = self.sde
score_fn = self.score_fn score_fn = self.score_fn
n_steps = self.n_steps n_steps = self.n_steps
target_snr = self.snr target_snr = self.snr
if isinstance(sde, VPSDE) or isinstance(sde, subVPSDE): # if isinstance(sde, VPSDE) or isinstance(sde, subVPSDE):
timestep = (t * (sde.N - 1) / sde.T).long() # timestep = (t * (sde.N - 1) / sde.T).long()
alpha = sde.alphas.to(t.device)[timestep] # alpha = sde.alphas.to(t.device)[timestep]
else: # else:
alpha = torch.ones_like(t) alpha = torch.ones_like(t)
for i in range(n_steps): for i in range(n_steps):
labels = sde.marginal_prob(torch.zeros_like(x), t)[1] labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
grad = score_fn(x, labels) grad = score_fn(x, labels)
noise = torch.randn_like(x) noise = torch.randn_like(x)
grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean() grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
@ -152,64 +177,42 @@ def save_image(x):
image_pil.save("../images/hey.png") image_pil.save("../images/hey.png")
sde = 'VESDE' #@param ['VESDE', 'VPSDE', 'subVPSDE'] {"type": "string"}
if sde.lower() == 'vesde':
# from configs.ve import cifar10_ncsnpp_continuous as configs
# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth" # ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
from configs.ve import ffhq_ncsnpp_continuous as configs #ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth" # Note usually we need to restore ema etc...
config = configs.get_config() # ema restored checkpoint used from below
config.model.num_scales = 2
sde = VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
sampling_eps = 1e-5
elif sde.lower() == 'vpsde': config = get_config()
from configs.vp import cifar10_ddpmpp_continuous as configs
ckpt_filename = "exp/vp/cifar10_ddpmpp_continuous/checkpoint_8.pth" sigma_min, sigma_max = config.model.sigma_min, config.model.sigma_max
config = configs.get_config() N = config.model.num_scales
sde = VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
sampling_eps = 1e-3 sampling_eps = 1e-5
elif sde.lower() == 'subvpsde':
from configs.subvp import cifar10_ddpmpp_continuous as configs
ckpt_filename = "exp/subvp/cifar10_ddpmpp_continuous/checkpoint_26.pth"
config = configs.get_config()
sde = subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
sampling_eps = 1e-3
batch_size = 1 #@param {"type":"integer"} batch_size = 1 #@param {"type":"integer"}
config.training.batch_size = batch_size config.training.batch_size = batch_size
config.eval.batch_size = batch_size config.eval.batch_size = batch_size
random_seed = 0 #@param {"type": "integer"}
#sigmas = mutils.get_sigmas(config)
#scaler = datasets.get_data_scaler(config)
#inverse_scaler = datasets.get_data_inverse_scaler(config)
#score_model = mutils.create_model(config)
#
#optimizer = get_optimizer(config, score_model.parameters())
#ema = ExponentialMovingAverage(score_model.parameters(),
# decay=config.model.ema_rate)
#state = dict(step=0, optimizer=optimizer,
# model=score_model, ema=ema)
#
#state = restore_checkpoint(ckpt_filename, state, config.device)
#ema.copy_to(score_model.parameters())
#score_model = mutils.create_model(config)
from diffusers import NCSNpp from diffusers import NCSNpp
score_model = NCSNpp(config).to(config.device) model = NCSNpp(config).to(config.device)
score_model = torch.nn.DataParallel(score_model) model = torch.nn.DataParallel(model)
loaded_state = torch.load("./ffhq_1024_ncsnpp_continuous_ema.pt") loaded_state = torch.load("../score_sde_pytorch/ffhq_1024_ncsnpp_continuous_ema.pt")
del loaded_state["module.sigmas"] del loaded_state["module.sigmas"]
score_model.load_state_dict(loaded_state, strict=False) model.load_state_dict(loaded_state, strict=False)
inverse_scaler = datasets.get_data_inverse_scaler(config) def get_data_inverse_scaler(config):
predictor = ReverseDiffusionPredictor #@param ["EulerMaruyamaPredictor", "AncestralSamplingPredictor", "ReverseDiffusionPredictor", "None"] {"type": "raw"} """Inverse data normalizer."""
corrector = LangevinCorrector #@param ["LangevinCorrector", "AnnealedLangevinDynamics", "None"] {"type": "raw"} if config.data.centered:
# Rescale [-1, 1] to [0, 1]
return lambda x: (x + 1.) / 2.
else:
return lambda x: x
inverse_scaler = get_data_inverse_scaler(config)
#@title PC sampling
img_size = config.data.image_size img_size = config.data.image_size
channels = config.data.num_channels channels = config.data.num_channels
shape = (batch_size, channels, img_size, img_size) shape = (batch_size, channels, img_size, img_size)
@ -218,80 +221,27 @@ snr = 0.15 #@param {"type": "number"}
n_steps = 1#@param {"type": "integer"} n_steps = 1#@param {"type": "integer"}
#sampling_fn = sampling.get_pc_sampler(sde, shape, predictor, corrector,
# inverse_scaler, snr, n_steps=n_steps,
# probability_flow=probability_flow,
# continuous=config.training.continuous,
# eps=sampling_eps, device=config.device)
#
#x, n = sampling_fn(score_model)
#save_image(x)
def shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous):
"""A wrapper that configures and returns the update function of predictors."""
score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)
if predictor is None:
# Corrector-only sampler
predictor_obj = NonePredictor(sde, score_fn, probability_flow)
else:
predictor_obj = predictor(sde, score_fn, probability_flow)
return predictor_obj.update_fn(x, t)
def shared_corrector_update_fn(x, t, sde, model, corrector, continuous, snr, n_steps):
"""A wrapper tha configures and returns the update function of correctors."""
score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)
if corrector is None:
# Predictor-only sampler
corrector_obj = NoneCorrector(sde, score_fn, snr, n_steps)
else:
corrector_obj = corrector(sde, score_fn, snr, n_steps)
return corrector_obj.update_fn(x, t)
continuous = config.training.continuous
predictor_update_fn = functools.partial(shared_predictor_update_fn,
sde=sde,
predictor=predictor,
probability_flow=probability_flow,
continuous=continuous)
corrector_update_fn = functools.partial(shared_corrector_update_fn,
sde=sde,
corrector=corrector,
continuous=continuous,
snr=snr,
n_steps=n_steps)
device = config.device device = config.device
model = score_model
denoise = True
new_corrector = NewLangevinCorrector(sde=sde, score_fn=model, snr=snr, n_steps=n_steps) new_corrector = NewLangevinCorrector(score_fn=model, snr=snr, n_steps=n_steps, sigma_min=sigma_min, sigma_max=sigma_max)
new_predictor = NewReverseDiffusionPredictor(sde=sde, score_fn=model) new_predictor = NewReverseDiffusionPredictor(score_fn=model, sigma_min=sigma_min, sigma_max=sigma_max, N=N)
#
with torch.no_grad(): with torch.no_grad():
# Initial sample # Initial sample
x = sde.prior_sampling(shape).to(device) x = torch.randn(*shape) * sigma_max
timesteps = torch.linspace(sde.T, sampling_eps, sde.N, device=device) x = x.to(device)
timesteps = torch.linspace(1, sampling_eps, N, device=device)
for i in range(sde.N): for i in range(N):
t = timesteps[i] t = timesteps[i]
vec_t = torch.ones(shape[0], device=t.device) * t vec_t = torch.ones(shape[0], device=t.device) * t
# x, x_mean = corrector_update_fn(x, vec_t, model=model)
# x, x_mean = predictor_update_fn(x, vec_t, model=model)
x, x_mean = new_corrector.update_fn(x, vec_t) x, x_mean = new_corrector.update_fn(x, vec_t)
x, x_mean = new_predictor.update_fn(x, vec_t) x, x_mean = new_predictor.update_fn(x, vec_t)
x, n = inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1) x = inverse_scaler(x_mean)
save_image(x)
#save_image(x)
# for 5 cifar10 # for 5 cifar10
x_sum = 106071.9922 x_sum = 106071.9922
@ -310,4 +260,4 @@ def check_x_sum_x_mean(x, x_sum, x_mean):
assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
check_x_sum_x_mean(x, x_sum, x_mean) #check_x_sum_x_mean(x, x_sum, x_mean)