diff --git a/run.py b/run.py
new file mode 100755
index 00000000..61e29603
--- /dev/null
+++ b/run.py
@@ -0,0 +1,288 @@
+#!/usr/bin/env python3
+import numpy as np
+import PIL.Image
+import functools
+
+# Importing the model submodules registers the architectures with
+# mutils.create_model via their @register_model decorators.
+import models
+from models import utils as mutils
+from models import ncsnv2
+from models import ncsnpp
+from models import ddpm as ddpm_model
+from models import layerspp
+from models import layers
+from models import normalization
+
+from utils import restore_checkpoint
+
+import sampling
+from sde_lib import VESDE, VPSDE, subVPSDE
+from sampling import (NoneCorrector,
+                      ReverseDiffusionPredictor,
+                      LangevinCorrector,
+                      EulerMaruyamaPredictor,
+                      AncestralSamplingPredictor,
+                      NonePredictor,
+                      AnnealedLangevinDynamics)
+import datasets
+import torch
+
+
+torch.manual_seed(0)
+
+
+class NewReverseDiffusionPredictor:
+  """Reverse-diffusion predictor with the SMLD (NCSN) discretization of the
+  VE SDE inlined, so it runs without the sampling-module plumbing."""
+
+  def __init__(self, sde, score_fn, probability_flow=False):
+    super().__init__()
+    self.sde = sde
+    self.probability_flow = probability_flow
+    self.score_fn = score_fn
+
+  def discretize(self, x, t):
+    # SMLD (NCSN) discretization: f = 0, G = sqrt(sigma_i^2 - sigma_{i-1}^2).
+    timestep = (t * (self.sde.N - 1) / self.sde.T).long()
+    sigma = self.sde.discrete_sigmas.to(t.device)[timestep]
+    adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
+                                 self.sde.discrete_sigmas[timestep - 1].to(t.device))
+    f = torch.zeros_like(x)
+    G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
+
+    # For continuous VE models the conditioning label is the marginal std at t.
+    labels = self.sde.marginal_prob(torch.zeros_like(x), t)[1]
+    result = self.score_fn(x, labels)
+
+    rev_f = f - G[:, None, None, None] ** 2 * result * (0.5 if self.probability_flow else 1.)
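+    # With probability_flow=True the score term above is halved and rev_G
+    # below is zeroed, turning the stochastic reverse-diffusion step into a
+    # step of the deterministic probability flow ODE.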
+    rev_G = torch.zeros_like(G) if self.probability_flow else G
+    return rev_f, rev_G
+
+  def update_fn(self, x, t):
+    f, G = self.discretize(x, t)
+    z = torch.randn_like(x)
+    x_mean = x - f
+    x = x_mean + G[:, None, None, None] * z
+    return x, x_mean
+
+
+class NewLangevinCorrector:
+  """Langevin MCMC corrector: n_steps of score-based Langevin dynamics with a
+  step size chosen to match the target signal-to-noise ratio."""
+
+  def __init__(self, sde, score_fn, snr, n_steps):
+    super().__init__()
+    self.sde = sde
+    self.score_fn = score_fn
+    self.snr = snr
+    self.n_steps = n_steps
+
+  def update_fn(self, x, t):
+    sde = self.sde
+    score_fn = self.score_fn
+    n_steps = self.n_steps
+    target_snr = self.snr
+    if isinstance(sde, VPSDE) or isinstance(sde, subVPSDE):
+      timestep = (t * (sde.N - 1) / sde.T).long()
+      alpha = sde.alphas.to(t.device)[timestep]
+    else:
+      alpha = torch.ones_like(t)
+
+    for i in range(n_steps):
+      labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
+      grad = score_fn(x, labels)
+      noise = torch.randn_like(x)
+      grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
+      noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
+      step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha
+      x_mean = x + step_size[:, None, None, None] * grad
+      x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise
+
+    return x, x_mean
+
+
+def save_image(x):
+  # x is expected in [0, 1] after inverse_scaler; map to uint8 pixels.
+  image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
+  image_pil = PIL.Image.fromarray(image_processed[0])
+  image_pil.save("../images/hey.png")
+
+
+#@title Load the score-based model
+sde = 'VESDE' #@param ['VESDE', 'VPSDE', 'subVPSDE'] {"type": "string"}
+if sde.lower() == 'vesde':
+  from configs.ve import cifar10_ncsnpp_continuous as configs
+  ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
+#  from configs.ve import ffhq_ncsnpp_continuous as configs
+#  ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
+  config = configs.get_config()
+  config.model.num_scales = 1000
+  sde = VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
+  sampling_eps = 1e-5
+elif sde.lower() == 'vpsde':
+  from configs.vp import cifar10_ddpmpp_continuous as configs
+  ckpt_filename = "exp/vp/cifar10_ddpmpp_continuous/checkpoint_8.pth"
+  config = configs.get_config()
+  sde = VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
+  sampling_eps = 1e-3
+elif sde.lower() == 'subvpsde':
+  from configs.subvp import cifar10_ddpmpp_continuous as configs
+  ckpt_filename = "exp/subvp/cifar10_ddpmpp_continuous/checkpoint_26.pth"
+  config = configs.get_config()
+  sde = subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
+  sampling_eps = 1e-3
+
+batch_size = 1 #@param {"type": "integer"}
+config.training.batch_size = batch_size
+config.eval.batch_size = batch_size
+
+random_seed = 0 #@param {"type": "integer"}
+
+score_model = mutils.create_model(config)
+
+loaded_state = torch.load(ckpt_filename, map_location=config.device)
+score_model.load_state_dict(loaded_state["model"], strict=False)
+
+inverse_scaler = datasets.get_data_inverse_scaler(config)
+predictor = ReverseDiffusionPredictor #@param ["EulerMaruyamaPredictor", "AncestralSamplingPredictor", "ReverseDiffusionPredictor", "None"] {"type": "raw"}
+corrector = LangevinCorrector #@param ["LangevinCorrector", "AnnealedLangevinDynamics", "None"] {"type": "raw"}
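+
+# If the loop below is switched back to the commented-out shared update
+# functions, a predictor-only or corrector-only sampler can be selected by
+# setting the other component to None; the wrappers map None to
+# NonePredictor / NoneCorrector. Sketch:
+#   predictor = None   # corrector-only (annealed Langevin dynamics)
+#   corrector = None   # predictor-only (plain reverse diffusion)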
["LangevinCorrector", "AnnealedLangevinDynamics", "None"] {"type": "raw"} + +def image_grid(x): + size = config.data.image_size + channels = config.data.num_channels + img = x.reshape(-1, size, size, channels) + w = int(np.sqrt(img.shape[0])) + img = img.reshape((w, w, size, size, channels)).transpose((0, 2, 1, 3, 4)).reshape((w * size, w * size, channels)) + return img + +#@title PC sampling +img_size = config.data.image_size +channels = config.data.num_channels +shape = (batch_size, channels, img_size, img_size) +probability_flow = False +snr = 0.16 #@param {"type": "number"} +n_steps = 1#@param {"type": "integer"} + + +def shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous): + """A wrapper that configures and returns the update function of predictors.""" + score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous) + if predictor is None: + # Corrector-only sampler + predictor_obj = NonePredictor(sde, score_fn, probability_flow) + else: + predictor_obj = predictor(sde, score_fn, probability_flow) + return predictor_obj.update_fn(x, t) + + +def shared_corrector_update_fn(x, t, sde, model, corrector, continuous, snr, n_steps): + """A wrapper tha configures and returns the update function of correctors.""" + score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous) + if corrector is None: + # Predictor-only sampler + corrector_obj = NoneCorrector(sde, score_fn, snr, n_steps) + else: + corrector_obj = corrector(sde, score_fn, snr, n_steps) + return corrector_obj.update_fn(x, t) + + +continuous = config.training.continuous + + +predictor_update_fn = functools.partial(shared_predictor_update_fn, + sde=sde, + predictor=predictor, + probability_flow=probability_flow, + continuous=continuous) + +corrector_update_fn = functools.partial(shared_corrector_update_fn, + sde=sde, + corrector=corrector, + continuous=continuous, + snr=snr, + n_steps=n_steps) + +device = "cuda" +model = score_model.to(device) +denoise = False + +new_corrector = NewLangevinCorrector(sde=sde, score_fn=model, snr=snr, n_steps=n_steps) +new_predictor = NewReverseDiffusionPredictor(sde=sde, score_fn=model) + + +with torch.no_grad(): + # Initial sample + x = sde.prior_sampling(shape).to(device) + timesteps = torch.linspace(sde.T, sampling_eps, sde.N, device=device) + + for i in range(sde.N): + t = timesteps[i] + vec_t = torch.ones(shape[0], device=t.device) * t +# x, x_mean = corrector_update_fn(x, vec_t, model=model) +# x, x_mean = predictor_update_fn(x, vec_t, model=model) + x, x_mean = new_corrector.update_fn(x, vec_t) + x, x_mean = new_predictor.update_fn(x, vec_t) + + x, n = inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1) + + +# for 5 +#assert x.abs().sum().cpu().item() - 106114.90625 < 1e-2, "sum wrong" +#assert x.abs().mean().cpu().item() - 34.5426139831543 < 1e-4, "mean wrong" + +# for 1000 +assert x.abs().sum().cpu().item() - 436.5811 < 1e-2, "sum wrong" +assert x.abs().mean().cpu().item() - 0.1421 < 1e-4, "mean wrong" + +save_image(x)