diff --git a/run.py b/run.py
index 7a55acba..0180c348 100755
--- a/run.py
+++ b/run.py
@@ -1,104 +1,128 @@
 #!/usr/bin/env python3
 import numpy as np
 import PIL
-import functools
-
-import models
-from models import utils as mutils
-from models import ncsnv2
-from models import ncsnpp
-from models import ddpm as ddpm_model
-from models import layerspp
-from models import layers
-from models import normalization
-from models.ema import ExponentialMovingAverage
-from losses import get_optimizer
-
-from utils import restore_checkpoint
-
-import sampling
-from sde_lib import VESDE, VPSDE, subVPSDE
-from sampling import (NoneCorrector,
-                      ReverseDiffusionPredictor,
-                      LangevinCorrector,
-                      EulerMaruyamaPredictor,
-                      AncestralSamplingPredictor,
-                      NonePredictor,
-                      AnnealedLangevinDynamics)
-import datasets
 import torch
+import ml_collections
+# from configs.ve import ffhq_ncsnpp_continuous as configs
+# from configs.ve import cifar10_ncsnpp_continuous as configs


+# ffhq_ncsnpp_continuous config, inlined so the script no longer depends on score_sde_pytorch
+def get_config():
+    config = ml_collections.ConfigDict()
+
+    # training
+    config.training = training = ml_collections.ConfigDict()
+    training.batch_size = 8
+    training.n_iters = 2400001
+    training.snapshot_freq = 50000
+    training.log_freq = 50
+    training.eval_freq = 100
+    training.snapshot_freq_for_preemption = 5000
+    training.snapshot_sampling = True
+    training.sde = 'vesde'
+    training.continuous = True
+    training.likelihood_weighting = False
+    training.reduce_mean = True
+
+    # sampling
+    config.sampling = sampling = ml_collections.ConfigDict()
+    sampling.method = 'pc'
+    sampling.predictor = 'reverse_diffusion'
+    sampling.corrector = 'langevin'
+    sampling.probability_flow = False
+    sampling.snr = 0.15
+    sampling.n_steps_each = 1
+    sampling.noise_removal = True
+
+    # eval
+    config.eval = evaluate = ml_collections.ConfigDict()
+    evaluate.batch_size = 1024
+    evaluate.num_samples = 50000
+    evaluate.begin_ckpt = 1
+    evaluate.end_ckpt = 96
+
+    # data
+    config.data = data = ml_collections.ConfigDict()
+    data.dataset = 'FFHQ'
+    data.image_size = 1024
+    data.centered = False
+    data.random_flip = True
+    data.uniform_dequantization = False
+    data.num_channels = 3
+    # Plug in your own path to the tfrecords file.
+    data.tfrecords_path = '/raid/song/ffhq-dataset/ffhq/ffhq-r10.tfrecords'
+
+    # model
+    config.model = model = ml_collections.ConfigDict()
+    model.name = 'ncsnpp'
+    model.scale_by_sigma = True
+    model.sigma_max = 1348
+    model.num_scales = 2000
+    model.ema_rate = 0.9999
+    model.sigma_min = 0.01
+    model.normalization = 'GroupNorm'
+    model.nonlinearity = 'swish'
+    model.nf = 16
+    model.ch_mult = (1, 2, 4, 8, 16, 32, 32, 32)
+    model.num_res_blocks = 1
+    model.attn_resolutions = (16,)
+    model.dropout = 0.
+    model.resamp_with_conv = True
+    model.conditional = True
+    model.fir = True
+    model.fir_kernel = [1, 3, 3, 1]
+    model.skip_rescale = True
+    model.resblock_type = 'biggan'
+    model.progressive = 'output_skip'
+    model.progressive_input = 'input_skip'
+    model.progressive_combine = 'sum'
+    model.attention_type = 'ddpm'
+    model.init_scale = 0.
+    model.fourier_scale = 16
+    model.conv_size = 3
+    model.embedding_type = 'fourier'
+
+    # optim
+    config.optim = optim = ml_collections.ConfigDict()
+    optim.weight_decay = 0
+    optim.optimizer = 'Adam'
+    optim.lr = 2e-4
+    optim.beta1 = 0.9
+    optim.amsgrad = False
+    optim.eps = 1e-8
+    optim.warmup = 5000
+    optim.grad_clip = 1.
+
+    config.seed = 42
+    config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+
+    return config


 torch.backends.cuda.matmul.allow_tf32 = False
-torch.manual_seed(0)
-
-
-#class NewVESDE(SDE):
-#    def __init__(self, sigma_min=0.01, sigma_max=50, N=1000):
-#        """Construct a Variance Exploding SDE.
-#
-#        Args:
-#          sigma_min: smallest sigma.
-#          sigma_max: largest sigma.
-#          N: number of discretization steps
-#        """
-#        super().__init__(N)
-#        self.sigma_min = sigma_min
-#        self.sigma_max = sigma_max
-#        self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
-#        self.N = N
-#
-#    @property
-#    def T(self):
-#        return 1
-#
-#    def sde(self, x, t):
-#        sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
-#        drift = torch.zeros_like(x)
-#        diffusion = sigma * torch.sqrt(torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)),
-#                                                    device=t.device))
-#        return drift, diffusion
-#
-#    def marginal_prob(self, x, t):
-#        std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
-#        mean = x
-#        return mean, std
-#
-#    def prior_sampling(self, shape):
-#        return torch.randn(*shape) * self.sigma_max
-#
-#    def prior_logp(self, z):
-#        shape = z.shape
-#        N = np.prod(shape[1:])
-#        return -N / 2. * np.log(2 * np.pi * self.sigma_max ** 2) - torch.sum(z ** 2, dim=(1, 2, 3)) / (2 * self.sigma_max ** 2)
-#
-#    def discretize(self, x, t):
-#        """SMLD(NCSN) discretization."""
-#        timestep = (t * (self.N - 1) / self.T).long()
-#        sigma = self.discrete_sigmas.to(t.device)[timestep]
-#        adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
-#                                     self.discrete_sigmas[timestep - 1].to(t.device))
-#        f = torch.zeros_like(x)
-#        G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
-#        return f, G
+torch.manual_seed(3)


 class NewReverseDiffusionPredictor:
-    def __init__(self, sde, score_fn, probability_flow=False):
+    def __init__(self, score_fn, probability_flow=False, sigma_min=0.0, sigma_max=0.0, N=0):
         super().__init__()
-        self.sde = sde
+        self.sigma_min = sigma_min
+        self.sigma_max = sigma_max
+        self.N = N
+        self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
+
         self.probability_flow = probability_flow
         self.score_fn = score_fn

     def discretize(self, x, t):
-        timestep = (t * (self.sde.N - 1) / self.sde.T).long()
-        sigma = self.sde.discrete_sigmas.to(t.device)[timestep]
+        timestep = (t * (self.N - 1)).long()
+        sigma = self.discrete_sigmas.to(t.device)[timestep]
         adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
-                                     self.sde.discrete_sigmas[timestep - 1].to(t.device))
+                                     self.discrete_sigmas[timestep - 1].to(t.device))
         f = torch.zeros_like(x)
         G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
-        labels = self.sde.marginal_prob(torch.zeros_like(x), t)[1]
+        # sigma(t) of the continuous VE SDE, used as the noise-level label.
+        labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
         result = self.score_fn(x, labels)

         rev_f = f - G[:, None, None, None] ** 2 * result * (0.5 if self.probability_flow else 1.)
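+        # rev_f is the discretized reverse-time drift f - G^2 * score; the extra
+        # factor 1/2 applies when targeting the probability-flow ODE instead of the SDE.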
@@ -114,26 +138,27 @@ class NewReverseDiffusionPredictor:


 class NewLangevinCorrector:
-    def __init__(self, sde, score_fn, snr, n_steps):
+    def __init__(self, score_fn, snr, n_steps, sigma_min=0.0, sigma_max=0.0):
         super().__init__()
-        self.sde = sde
         self.score_fn = score_fn
         self.snr = snr
         self.n_steps = n_steps
+        self.sigma_min = sigma_min
+        self.sigma_max = sigma_max
+
     def update_fn(self, x, t):
-        sde = self.sde
         score_fn = self.score_fn
         n_steps = self.n_steps
         target_snr = self.snr
-        if isinstance(sde, VPSDE) or isinstance(sde, subVPSDE):
-            timestep = (t * (sde.N - 1) / sde.T).long()
-            alpha = sde.alphas.to(t.device)[timestep]
-        else:
-            alpha = torch.ones_like(t)
+        # The VP/subVP branch is dropped: this script only handles the VE SDE,
+        # for which alpha is identically one.
+        alpha = torch.ones_like(t)

         for i in range(n_steps):
-            labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
+            labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
             grad = score_fn(x, labels)
             noise = torch.randn_like(x)
             grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
@@ -152,64 +177,42 @@ def save_image(x):
     image_pil.save("../images/hey.png")


-sde = 'VESDE' #@param ['VESDE', 'VPSDE', 'subVPSDE'] {"type": "string"}
-if sde.lower() == 'vesde':
-#    from configs.ve import cifar10_ncsnpp_continuous as configs
-#    ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
-    from configs.ve import ffhq_ncsnpp_continuous as configs
-    ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
-    config = configs.get_config()
-    config.model.num_scales = 2
-    sde = VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
-    sampling_eps = 1e-5
-elif sde.lower() == 'vpsde':
-    from configs.vp import cifar10_ddpmpp_continuous as configs
-    ckpt_filename = "exp/vp/cifar10_ddpmpp_continuous/checkpoint_8.pth"
-    config = configs.get_config()
-    sde = VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
-    sampling_eps = 1e-3
-elif sde.lower() == 'subvpsde':
-    from configs.subvp import cifar10_ddpmpp_continuous as configs
-    ckpt_filename = "exp/subvp/cifar10_ddpmpp_continuous/checkpoint_26.pth"
-    config = configs.get_config()
-    sde = subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
-    sampling_eps = 1e-3
+# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
+# Note: normally we would need to restore the EMA parameters into the model here.
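+# With the score_sde_pytorch utilities that would look roughly like the
+# commented-out block removed below:
+#     ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
+#     state = restore_checkpoint(ckpt_filename, state, config.device)
+#     ema.copy_to(score_model.parameters())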
+# The checkpoint loaded below already contains the EMA-averaged weights.
+
+
+config = get_config()
+
+sigma_min, sigma_max = config.model.sigma_min, config.model.sigma_max
+N = config.model.num_scales
+
+sampling_eps = 1e-5

 batch_size = 1 #@param {"type":"integer"}
 config.training.batch_size = batch_size
 config.eval.batch_size = batch_size

-random_seed = 0 #@param {"type": "integer"}
-
-#sigmas = mutils.get_sigmas(config)
-#scaler = datasets.get_data_scaler(config)
-#inverse_scaler = datasets.get_data_inverse_scaler(config)
-#score_model = mutils.create_model(config)
-#
-#optimizer = get_optimizer(config, score_model.parameters())
-#ema = ExponentialMovingAverage(score_model.parameters(),
-#                               decay=config.model.ema_rate)
-#state = dict(step=0, optimizer=optimizer,
-#             model=score_model, ema=ema)
-#
-#state = restore_checkpoint(ckpt_filename, state, config.device)
-#ema.copy_to(score_model.parameters())
-
-#score_model = mutils.create_model(config)
-
 from diffusers import NCSNpp
-score_model = NCSNpp(config).to(config.device)
-score_model = torch.nn.DataParallel(score_model)
+model = NCSNpp(config).to(config.device)
+model = torch.nn.DataParallel(model)

-loaded_state = torch.load("./ffhq_1024_ncsnpp_continuous_ema.pt")
+loaded_state = torch.load("../score_sde_pytorch/ffhq_1024_ncsnpp_continuous_ema.pt")
+# Drop the stored sigmas buffer; sigma(t) is recomputed from the config instead.
 del loaded_state["module.sigmas"]
-score_model.load_state_dict(loaded_state, strict=False)
+model.load_state_dict(loaded_state, strict=False)

-inverse_scaler = datasets.get_data_inverse_scaler(config)
-predictor = ReverseDiffusionPredictor #@param ["EulerMaruyamaPredictor", "AncestralSamplingPredictor", "ReverseDiffusionPredictor", "None"] {"type": "raw"}
-corrector = LangevinCorrector #@param ["LangevinCorrector", "AnnealedLangevinDynamics", "None"] {"type": "raw"}
+# Inlined from datasets.get_data_inverse_scaler in score_sde_pytorch.
+def get_data_inverse_scaler(config):
+    """Inverse data normalizer."""
+    if config.data.centered:
+        # Rescale [-1, 1] to [0, 1]
+        return lambda x: (x + 1.) / 2.
+    else:
+        return lambda x: x
+
+inverse_scaler = get_data_inverse_scaler(config)

-#@title PC sampling
 img_size = config.data.image_size
 channels = config.data.num_channels
 shape = (batch_size, channels, img_size, img_size)
@@ -218,80 +221,27 @@
 snr = 0.15 #@param {"type": "number"}
 n_steps = 1 #@param {"type": "integer"}

-#sampling_fn = sampling.get_pc_sampler(sde, shape, predictor, corrector,
-#                                      inverse_scaler, snr, n_steps=n_steps,
-#                                      probability_flow=probability_flow,
-#                                      continuous=config.training.continuous,
-#                                      eps=sampling_eps, device=config.device)
-#
-#x, n = sampling_fn(score_model)
-#save_image(x)
-
-
-def shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous):
-    """A wrapper that configures and returns the update function of predictors."""
-    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)
-    if predictor is None:
-        # Corrector-only sampler
-        predictor_obj = NonePredictor(sde, score_fn, probability_flow)
-    else:
-        predictor_obj = predictor(sde, score_fn, probability_flow)
-    return predictor_obj.update_fn(x, t)
-
-
-def shared_corrector_update_fn(x, t, sde, model, corrector, continuous, snr, n_steps):
-    """A wrapper that configures and returns the update function of correctors."""
-    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)
-    if corrector is None:
-        # Predictor-only sampler
-        corrector_obj = NoneCorrector(sde, score_fn, snr, n_steps)
-    else:
-        corrector_obj = corrector(sde, score_fn, snr, n_steps)
-    return corrector_obj.update_fn(x, t)
-
-
-continuous = config.training.continuous
-
-
-predictor_update_fn = functools.partial(shared_predictor_update_fn,
-                                        sde=sde,
-                                        predictor=predictor,
-                                        probability_flow=probability_flow,
-                                        continuous=continuous)
-
-corrector_update_fn = functools.partial(shared_corrector_update_fn,
-                                        sde=sde,
-                                        corrector=corrector,
-                                        continuous=continuous,
-                                        snr=snr,
-                                        n_steps=n_steps)
-
 device = config.device
-model = score_model
-denoise = True

-new_corrector = NewLangevinCorrector(sde=sde, score_fn=model, snr=snr, n_steps=n_steps)
-new_predictor = NewReverseDiffusionPredictor(sde=sde, score_fn=model)
+new_corrector = NewLangevinCorrector(score_fn=model, snr=snr, n_steps=n_steps, sigma_min=sigma_min, sigma_max=sigma_max)
+new_predictor = NewReverseDiffusionPredictor(score_fn=model, sigma_min=sigma_min, sigma_max=sigma_max, N=N)

-#
 with torch.no_grad():
     # Initial sample
-    x = sde.prior_sampling(shape).to(device)
-    timesteps = torch.linspace(sde.T, sampling_eps, sde.N, device=device)
+    # Draw from the VE prior N(0, sigma_max^2 I).
+    x = torch.randn(*shape) * sigma_max
+    x = x.to(device)
+    timesteps = torch.linspace(1, sampling_eps, N, device=device)

-    for i in range(sde.N):
+    for i in range(N):
         t = timesteps[i]
         vec_t = torch.ones(shape[0], device=t.device) * t
-#        x, x_mean = corrector_update_fn(x, vec_t, model=model)
-#        x, x_mean = predictor_update_fn(x, vec_t, model=model)
         x, x_mean = new_corrector.update_fn(x, vec_t)
         x, x_mean = new_predictor.update_fn(x, vec_t)

-    x, n = inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1)
+    # Use the denoised mean for the final sample ("noise removal").
+    x = inverse_scaler(x_mean)

-
-#save_image(x)
+save_image(x)

 # for 5 cifar10
 x_sum = 106071.9922
@@ -310,4 +260,4 @@ def check_x_sum_x_mean(x, x_sum, x_mean):
     assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"


-check_x_sum_x_mean(x, x_sum, x_mean)
+#check_x_sum_x_mean(x, x_sum, x_mean)
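For reference, the hard-coded VE predictor-corrector loop above can be exercised end to end without a checkpoint. The sketch below is illustrative, not part of the patch: `toy_score_fn` (the exact score of N(0, sigma^2 I)) stands in for the NCSN++ network, and the shapes, `sigma_max = 50.0`, and `N = 100` are shrunk so it runs in seconds; the update rules themselves mirror run.py.

import numpy as np
import torch

sigma_min, sigma_max, N, snr = 0.01, 50.0, 100, 0.15
discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), N))

def toy_score_fn(x, sigma):
    # Exact score of N(0, sigma^2 I); stands in for the NCSN++ model.
    return -x / (sigma ** 2)[:, None, None, None]

shape = (2, 3, 8, 8)
x = torch.randn(*shape) * sigma_max  # sample from the VE prior
for t in torch.linspace(1.0, 1e-5, N):
    vec_t = torch.ones(shape[0]) * t
    sigma_t = sigma_min * (sigma_max / sigma_min) ** vec_t  # the "labels" above

    # Langevin corrector: pick the step size so the update/noise ratio is ~snr.
    grad = toy_score_fn(x, sigma_t)
    noise = torch.randn_like(x)
    grad_norm = torch.norm(grad.reshape(shape[0], -1), dim=-1).mean()
    noise_norm = torch.norm(noise.reshape(shape[0], -1), dim=-1).mean()
    step_size = (snr * noise_norm / grad_norm) ** 2 * 2
    x = x + step_size * grad + torch.sqrt(2 * step_size) * noise

    # Reverse-diffusion predictor: x <- x + G^2 * score + G * z,
    # with G^2 = sigma_i^2 - sigma_{i-1}^2 (the SMLD discretization).
    timestep = (vec_t * (N - 1)).long()
    sigma = discrete_sigmas[timestep]
    adjacent = torch.where(timestep == 0, torch.zeros_like(vec_t), discrete_sigmas[timestep - 1])
    G = torch.sqrt(sigma ** 2 - adjacent ** 2)[:, None, None, None]
    x_mean = x + G ** 2 * toy_score_fn(x, sigma_t)
    x = x_mean + G * torch.randn_like(x)

print(x_mean.abs().mean())  # ends up near zero: the toy target collapses to N(0, sigma_min^2 I)

Because the toy score is exact, the final `x_mean` should shrink to roughly `sigma_min` scale, which makes the loop easy to sanity-check before swapping the real network back in.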