port first 1024 model
parent 78e99a997b
commit 49a81f9f1a

run.py | 350
@@ -1,104 +1,128 @@
 #!/usr/bin/env python3
 import numpy as np
 import PIL
-import functools
 
-import models
-from models import utils as mutils
-from models import ncsnv2
-from models import ncsnpp
-from models import ddpm as ddpm_model
-from models import layerspp
-from models import layers
-from models import normalization
-from models.ema import ExponentialMovingAverage
-from losses import get_optimizer
-
-from utils import restore_checkpoint
-
-import sampling
-from sde_lib import VESDE, VPSDE, subVPSDE
-from sampling import (NoneCorrector,
-                      ReverseDiffusionPredictor,
-                      LangevinCorrector,
-                      EulerMaruyamaPredictor,
-                      AncestralSamplingPredictor,
-                      NonePredictor,
-                      AnnealedLangevinDynamics)
-import datasets
 import torch
+import ml_collections
+
+#from configs.ve import ffhq_ncsnpp_continuous as configs
+# from configs.ve import cifar10_ncsnpp_continuous as configs
+
+
+# ffhq_ncsnpp_continuous config
+def get_config():
+  config = ml_collections.ConfigDict()
+
+  # training
+  config.training = training = ml_collections.ConfigDict()
+  training.batch_size = 8
+  training.n_iters = 2400001
+  training.snapshot_freq = 50000
+  training.log_freq = 50
+  training.eval_freq = 100
+  training.snapshot_freq_for_preemption = 5000
+  training.snapshot_sampling = True
+  training.sde = 'vesde'
+  training.continuous = True
+  training.likelihood_weighting = False
+  training.reduce_mean = True
+
+  # sampling
+  config.sampling = sampling = ml_collections.ConfigDict()
+  sampling.method = 'pc'
+  sampling.predictor = 'reverse_diffusion'
+  sampling.corrector = 'langevin'
+  sampling.probability_flow = False
+  sampling.snr = 0.15
+  sampling.n_steps_each = 1
+  sampling.noise_removal = True
+
+  # eval
+  config.eval = evaluate = ml_collections.ConfigDict()
+  evaluate.batch_size = 1024
+  evaluate.num_samples = 50000
+  evaluate.begin_ckpt = 1
+  evaluate.end_ckpt = 96
+
+  # data
+  config.data = data = ml_collections.ConfigDict()
+  data.dataset = 'FFHQ'
+  data.image_size = 1024
+  data.centered = False
+  data.random_flip = True
+  data.uniform_dequantization = False
+  data.num_channels = 3
+  # Plug in your own path to the tfrecords file.
+  data.tfrecords_path = '/raid/song/ffhq-dataset/ffhq/ffhq-r10.tfrecords'
+
+  # model
+  config.model = model = ml_collections.ConfigDict()
+  model.name = 'ncsnpp'
+  model.scale_by_sigma = True
+  model.sigma_max = 1348
+  model.num_scales = 2000
+  model.ema_rate = 0.9999
+  model.sigma_min = 0.01
+  model.normalization = 'GroupNorm'
+  model.nonlinearity = 'swish'
+  model.nf = 16
+  model.ch_mult = (1, 2, 4, 8, 16, 32, 32, 32)
+  model.num_res_blocks = 1
+  model.attn_resolutions = (16,)
+  model.dropout = 0.
+  model.resamp_with_conv = True
+  model.conditional = True
+  model.fir = True
+  model.fir_kernel = [1, 3, 3, 1]
+  model.skip_rescale = True
+  model.resblock_type = 'biggan'
+  model.progressive = 'output_skip'
+  model.progressive_input = 'input_skip'
+  model.progressive_combine = 'sum'
+  model.attention_type = 'ddpm'
+  model.init_scale = 0.
+  model.fourier_scale = 16
+  model.conv_size = 3
+  model.embedding_type = 'fourier'
+
+  # optim
+  config.optim = optim = ml_collections.ConfigDict()
+  optim.weight_decay = 0
+  optim.optimizer = 'Adam'
+  optim.lr = 2e-4
+  optim.beta1 = 0.9
+  optim.amsgrad = False
+  optim.eps = 1e-8
+  optim.warmup = 5000
+  optim.grad_clip = 1.
+
+  config.seed = 42
+  config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+
+  return config
+
+
 torch.backends.cuda.matmul.allow_tf32 = False
-torch.manual_seed(0)
+torch.manual_seed(3)
 
 
-#class NewVESDE(SDE):
-#  def __init__(self, sigma_min=0.01, sigma_max=50, N=1000):
-#    """Construct a Variance Exploding SDE.
-#
-#    Args:
-#      sigma_min: smallest sigma.
-#      sigma_max: largest sigma.
-#      N: number of discretization steps
-#    """
-#    super().__init__(N)
-#    self.sigma_min = sigma_min
-#    self.sigma_max = sigma_max
-#    self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
-#    self.N = N
-#
-#  @property
-#  def T(self):
-#    return 1
-#
-#  def sde(self, x, t):
-#    sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
-#    drift = torch.zeros_like(x)
-#    diffusion = sigma * torch.sqrt(torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)),
-#                                                device=t.device))
-#    return drift, diffusion
-#
-#  def marginal_prob(self, x, t):
-#    std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
-#    mean = x
-#    return mean, std
-#
-#  def prior_sampling(self, shape):
-#    return torch.randn(*shape) * self.sigma_max
-#
-#  def prior_logp(self, z):
-#    shape = z.shape
-#    N = np.prod(shape[1:])
-#    return -N / 2. * np.log(2 * np.pi * self.sigma_max ** 2) - torch.sum(z ** 2, dim=(1, 2, 3)) / (2 * self.sigma_max ** 2)
-#
-#  def discretize(self, x, t):
-#    """SMLD(NCSN) discretization."""
-#    timestep = (t * (self.N - 1) / self.T).long()
-#    sigma = self.discrete_sigmas.to(t.device)[timestep]
-#    adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
-#                                 self.discrete_sigmas[timestep - 1].to(t.device))
-#    f = torch.zeros_like(x)
-#    G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
-#    return f, G
 
 
 class NewReverseDiffusionPredictor:
-  def __init__(self, sde, score_fn, probability_flow=False):
+  def __init__(self, score_fn, probability_flow=False, sigma_min=0.0, sigma_max=0.0, N=0):
     super().__init__()
-    self.sde = sde
+    self.sigma_min = sigma_min
+    self.sigma_max = sigma_max
+    self.N = N
+    self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
+
     self.probability_flow = probability_flow
     self.score_fn = score_fn
 
   def discretize(self, x, t):
-    timestep = (t * (self.sde.N - 1) / self.sde.T).long()
-    sigma = self.sde.discrete_sigmas.to(t.device)[timestep]
+    timestep = (t * (self.N - 1)).long()
+    sigma = self.discrete_sigmas.to(t.device)[timestep]
     adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
-                                 self.sde.discrete_sigmas[timestep - 1].to(t.device))
+                                 self.discrete_sigmas[timestep - 1].to(t.device))
     f = torch.zeros_like(x)
     G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
 
-    labels = self.sde.marginal_prob(torch.zeros_like(x), t)[1]
+    labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
     result = self.score_fn(x, labels)
 
     rev_f = f - G[:, None, None, None] ** 2 * result * (0.5 if self.probability_flow else 1.)
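Note (a sketch, not part of the diff): the new `labels` expression is the closed-form VE marginal standard deviation, and it agrees with the geometric `discrete_sigmas` grid built in `__init__` at the grid points t = i / (N - 1). A quick check using the values from `get_config()` above:

    import numpy as np
    import torch

    sigma_min, sigma_max, N = 0.01, 1348, 2000  # from get_config() above
    discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), N))
    t = torch.linspace(0, 1, N)  # t = i / (N - 1)
    # continuous schedule matches the discrete grid at the grid points
    assert torch.allclose(discrete_sigmas, sigma_min * (sigma_max / sigma_min) ** t, rtol=1e-4)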
@@ -114,26 +138,27 @@ class NewReverseDiffusionPredictor:
 
 
 class NewLangevinCorrector:
-  def __init__(self, sde, score_fn, snr, n_steps):
+  def __init__(self, score_fn, snr, n_steps, sigma_min=0.0, sigma_max=0.0):
     super().__init__()
-    self.sde = sde
     self.score_fn = score_fn
     self.snr = snr
     self.n_steps = n_steps
+    self.sigma_min = sigma_min
+    self.sigma_max = sigma_max
 
   def update_fn(self, x, t):
-    sde = self.sde
     score_fn = self.score_fn
     n_steps = self.n_steps
     target_snr = self.snr
-    if isinstance(sde, VPSDE) or isinstance(sde, subVPSDE):
-      timestep = (t * (sde.N - 1) / sde.T).long()
-      alpha = sde.alphas.to(t.device)[timestep]
-    else:
-      alpha = torch.ones_like(t)
+    # if isinstance(sde, VPSDE) or isinstance(sde, subVPSDE):
+    #   timestep = (t * (sde.N - 1) / sde.T).long()
+    #   alpha = sde.alphas.to(t.device)[timestep]
+    # else:
+    alpha = torch.ones_like(t)
 
     for i in range(n_steps):
-      labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
+      labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
       grad = score_fn(x, labels)
       noise = torch.randn_like(x)
       grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
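Note: the hunk cuts off before the Langevin step itself. For orientation, in upstream score_sde the loop body completes roughly as below; this is a sketch of the elided lines, which this commit may phrase differently:

    # Sketch following upstream score_sde's LangevinCorrector (assumed, not shown in this diff)
    noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
    step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha
    x_mean = x + step_size[:, None, None, None] * grad
    x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise

With the VP/subVP branch commented out, `alpha` is always 1 here, which matches the VE SDE this script targets.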
@@ -152,64 +177,42 @@ def save_image(x):
   image_pil.save("../images/hey.png")
 
 
-sde = 'VESDE' #@param ['VESDE', 'VPSDE', 'subVPSDE'] {"type": "string"}
-if sde.lower() == 'vesde':
-  # from configs.ve import cifar10_ncsnpp_continuous as configs
-  # ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
-  from configs.ve import ffhq_ncsnpp_continuous as configs
-  ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
-  config = configs.get_config()
-  config.model.num_scales = 2
-  sde = VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
-  sampling_eps = 1e-5
-elif sde.lower() == 'vpsde':
-  from configs.vp import cifar10_ddpmpp_continuous as configs
-  ckpt_filename = "exp/vp/cifar10_ddpmpp_continuous/checkpoint_8.pth"
-  config = configs.get_config()
-  sde = VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
-  sampling_eps = 1e-3
-elif sde.lower() == 'subvpsde':
-  from configs.subvp import cifar10_ddpmpp_continuous as configs
-  ckpt_filename = "exp/subvp/cifar10_ddpmpp_continuous/checkpoint_26.pth"
-  config = configs.get_config()
-  sde = subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
-  sampling_eps = 1e-3
+# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
+#ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
+# Note usually we need to restore ema etc...
+# ema restored checkpoint used from below
+
+config = get_config()
+
+sigma_min, sigma_max = config.model.sigma_min, config.model.sigma_max
+N = config.model.num_scales
+
+sampling_eps = 1e-5
 
 batch_size = 1 #@param {"type":"integer"}
 config.training.batch_size = batch_size
 config.eval.batch_size = batch_size
 
-random_seed = 0 #@param {"type": "integer"}
-
-#sigmas = mutils.get_sigmas(config)
-#scaler = datasets.get_data_scaler(config)
-#inverse_scaler = datasets.get_data_inverse_scaler(config)
-#score_model = mutils.create_model(config)
-#
-#optimizer = get_optimizer(config, score_model.parameters())
-#ema = ExponentialMovingAverage(score_model.parameters(),
-#                               decay=config.model.ema_rate)
-#state = dict(step=0, optimizer=optimizer,
-#             model=score_model, ema=ema)
-#
-#state = restore_checkpoint(ckpt_filename, state, config.device)
-#ema.copy_to(score_model.parameters())
-
-#score_model = mutils.create_model(config)
-
 from diffusers import NCSNpp
-score_model = NCSNpp(config).to(config.device)
-score_model = torch.nn.DataParallel(score_model)
+model = NCSNpp(config).to(config.device)
+model = torch.nn.DataParallel(model)
 
-loaded_state = torch.load("./ffhq_1024_ncsnpp_continuous_ema.pt")
+loaded_state = torch.load("../score_sde_pytorch/ffhq_1024_ncsnpp_continuous_ema.pt")
 del loaded_state["module.sigmas"]
-score_model.load_state_dict(loaded_state, strict=False)
+model.load_state_dict(loaded_state, strict=False)
 
-inverse_scaler = datasets.get_data_inverse_scaler(config)
-predictor = ReverseDiffusionPredictor #@param ["EulerMaruyamaPredictor", "AncestralSamplingPredictor", "ReverseDiffusionPredictor", "None"] {"type": "raw"}
-corrector = LangevinCorrector #@param ["LangevinCorrector", "AnnealedLangevinDynamics", "None"] {"type": "raw"}
-
-#@title PC sampling
+def get_data_inverse_scaler(config):
+  """Inverse data normalizer."""
+  if config.data.centered:
+    # Rescale [-1, 1] to [0, 1]
+    return lambda x: (x + 1.) / 2.
+  else:
+    return lambda x: x
+
+inverse_scaler = get_data_inverse_scaler(config)
+
 img_size = config.data.image_size
 channels = config.data.num_channels
 shape = (batch_size, channels, img_size, img_size)
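Note: with `data.centered = False` in this config, the inlined scaler is the identity, so `inverse_scaler(x_mean)` later in the script returns the tensor unchanged. A minimal check, assuming the definitions above:

    inverse_scaler = get_data_inverse_scaler(config)
    x = torch.rand(2, 3, 8, 8)
    assert torch.equal(inverse_scaler(x), x)  # identity when data.centered is False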
@@ -218,80 +221,27 @@ snr = 0.15 #@param {"type": "number"}
 n_steps = 1 #@param {"type": "integer"}
 
 
-#sampling_fn = sampling.get_pc_sampler(sde, shape, predictor, corrector,
-#                                      inverse_scaler, snr, n_steps=n_steps,
-#                                      probability_flow=probability_flow,
-#                                      continuous=config.training.continuous,
-#                                      eps=sampling_eps, device=config.device)
-#
-#x, n = sampling_fn(score_model)
-#save_image(x)
-
-
-def shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous):
-  """A wrapper that configures and returns the update function of predictors."""
-  score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)
-  if predictor is None:
-    # Corrector-only sampler
-    predictor_obj = NonePredictor(sde, score_fn, probability_flow)
-  else:
-    predictor_obj = predictor(sde, score_fn, probability_flow)
-  return predictor_obj.update_fn(x, t)
-
-
-def shared_corrector_update_fn(x, t, sde, model, corrector, continuous, snr, n_steps):
-  """A wrapper that configures and returns the update function of correctors."""
-  score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)
-  if corrector is None:
-    # Predictor-only sampler
-    corrector_obj = NoneCorrector(sde, score_fn, snr, n_steps)
-  else:
-    corrector_obj = corrector(sde, score_fn, snr, n_steps)
-  return corrector_obj.update_fn(x, t)
-
-
-continuous = config.training.continuous
-
-
-predictor_update_fn = functools.partial(shared_predictor_update_fn,
-                                        sde=sde,
-                                        predictor=predictor,
-                                        probability_flow=probability_flow,
-                                        continuous=continuous)
-
-corrector_update_fn = functools.partial(shared_corrector_update_fn,
-                                        sde=sde,
-                                        corrector=corrector,
-                                        continuous=continuous,
-                                        snr=snr,
-                                        n_steps=n_steps)
-
 device = config.device
-model = score_model
-denoise = True
 
-new_corrector = NewLangevinCorrector(sde=sde, score_fn=model, snr=snr, n_steps=n_steps)
-new_predictor = NewReverseDiffusionPredictor(sde=sde, score_fn=model)
+new_corrector = NewLangevinCorrector(score_fn=model, snr=snr, n_steps=n_steps, sigma_min=sigma_min, sigma_max=sigma_max)
+new_predictor = NewReverseDiffusionPredictor(score_fn=model, sigma_min=sigma_min, sigma_max=sigma_max, N=N)
 
-#
 with torch.no_grad():
   # Initial sample
-  x = sde.prior_sampling(shape).to(device)
-  timesteps = torch.linspace(sde.T, sampling_eps, sde.N, device=device)
+  x = torch.randn(*shape) * sigma_max
+  x = x.to(device)
+  timesteps = torch.linspace(1, sampling_eps, N, device=device)
 
-  for i in range(sde.N):
+  for i in range(N):
     t = timesteps[i]
     vec_t = torch.ones(shape[0], device=t.device) * t
-    # x, x_mean = corrector_update_fn(x, vec_t, model=model)
-    # x, x_mean = predictor_update_fn(x, vec_t, model=model)
     x, x_mean = new_corrector.update_fn(x, vec_t)
     x, x_mean = new_predictor.update_fn(x, vec_t)
 
-  x, n = inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1)
+  x = inverse_scaler(x_mean)
 
-#save_image(x)
+save_image(x)
 
 # for 5 cifar10
 x_sum = 106071.9922
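Note: the sampling loop calls `update_fn` on the ported predictor, whose body falls outside these hunks. For orientation, the reverse-diffusion update in upstream score_sde reads roughly as below; a sketch only, the ported `NewReverseDiffusionPredictor` may differ in detail:

    # Upstream score_sde's ReverseDiffusionPredictor.update_fn (assumed, not shown in this diff)
    def update_fn(self, x, t):
      f, G = self.discretize(x, t)
      z = torch.randn_like(x)
      x_mean = x - f                                # deterministic reverse step
      x = x_mean + G[:, None, None, None] * z       # add reverse-diffusion noise
      return x, x_mean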
@@ -310,4 +260,4 @@ def check_x_sum_x_mean(x, x_sum, x_mean):
   assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
 
 
-check_x_sum_x_mean(x, x_sum, x_mean)
+#check_x_sum_x_mean(x, x_sum, x_mean)