remove more dependencies
This commit is contained in:
parent
49a81f9f1a
commit
bc2d586dcb
146
run.py
146
run.py
|
@ -2,105 +2,14 @@
|
|||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
import ml_collections
|
||||
#from configs.ve import ffhq_ncsnpp_continuous as configs
|
||||
# from configs.ve import cifar10_ncsnpp_continuous as configs
|
||||
|
||||
|
||||
# ffhq_ncsnpp_continuous config
|
||||
def get_config():
|
||||
config = ml_collections.ConfigDict()
|
||||
# training
|
||||
config.training = training = ml_collections.ConfigDict()
|
||||
training.batch_size = 8
|
||||
training.n_iters = 2400001
|
||||
training.snapshot_freq = 50000
|
||||
training.log_freq = 50
|
||||
training.eval_freq = 100
|
||||
training.snapshot_freq_for_preemption = 5000
|
||||
training.snapshot_sampling = True
|
||||
training.sde = 'vesde'
|
||||
training.continuous = True
|
||||
training.likelihood_weighting = False
|
||||
training.reduce_mean = True
|
||||
|
||||
# sampling
|
||||
config.sampling = sampling = ml_collections.ConfigDict()
|
||||
sampling.method = 'pc'
|
||||
sampling.predictor = 'reverse_diffusion'
|
||||
sampling.corrector = 'langevin'
|
||||
sampling.probability_flow = False
|
||||
sampling.snr = 0.15
|
||||
sampling.n_steps_each = 1
|
||||
sampling.noise_removal = True
|
||||
|
||||
# eval
|
||||
config.eval = evaluate = ml_collections.ConfigDict()
|
||||
evaluate.batch_size = 1024
|
||||
evaluate.num_samples = 50000
|
||||
evaluate.begin_ckpt = 1
|
||||
evaluate.end_ckpt = 96
|
||||
|
||||
# data
|
||||
config.data = data = ml_collections.ConfigDict()
|
||||
data.dataset = 'FFHQ'
|
||||
data.image_size = 1024
|
||||
data.centered = False
|
||||
data.random_flip = True
|
||||
data.uniform_dequantization = False
|
||||
data.num_channels = 3
|
||||
# Plug in your own path to the tfrecords file.
|
||||
data.tfrecords_path = '/raid/song/ffhq-dataset/ffhq/ffhq-r10.tfrecords'
|
||||
|
||||
# model
|
||||
config.model = model = ml_collections.ConfigDict()
|
||||
model.name = 'ncsnpp'
|
||||
model.scale_by_sigma = True
|
||||
model.sigma_max = 1348
|
||||
model.num_scales = 2000
|
||||
model.ema_rate = 0.9999
|
||||
model.sigma_min = 0.01
|
||||
model.normalization = 'GroupNorm'
|
||||
model.nonlinearity = 'swish'
|
||||
model.nf = 16
|
||||
model.ch_mult = (1, 2, 4, 8, 16, 32, 32, 32)
|
||||
model.num_res_blocks = 1
|
||||
model.attn_resolutions = (16,)
|
||||
model.dropout = 0.
|
||||
model.resamp_with_conv = True
|
||||
model.conditional = True
|
||||
model.fir = True
|
||||
model.fir_kernel = [1, 3, 3, 1]
|
||||
model.skip_rescale = True
|
||||
model.resblock_type = 'biggan'
|
||||
model.progressive = 'output_skip'
|
||||
model.progressive_input = 'input_skip'
|
||||
model.progressive_combine = 'sum'
|
||||
model.attention_type = 'ddpm'
|
||||
model.init_scale = 0.
|
||||
model.fourier_scale = 16
|
||||
model.conv_size = 3
|
||||
model.embedding_type = 'fourier'
|
||||
|
||||
# optim
|
||||
config.optim = optim = ml_collections.ConfigDict()
|
||||
optim.weight_decay = 0
|
||||
optim.optimizer = 'Adam'
|
||||
optim.lr = 2e-4
|
||||
optim.beta1 = 0.9
|
||||
optim.amsgrad = False
|
||||
optim.eps = 1e-8
|
||||
optim.warmup = 5000
|
||||
optim.grad_clip = 1.
|
||||
|
||||
config.seed = 42
|
||||
config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
|
||||
|
||||
return config
|
||||
|
||||
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
torch.manual_seed(3)
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
class NewReverseDiffusionPredictor:
|
||||
|
@ -182,47 +91,26 @@ def save_image(x):
|
|||
# Note usually we need to restore ema etc...
|
||||
# ema restored checkpoint used from below
|
||||
|
||||
|
||||
|
||||
config = get_config()
|
||||
|
||||
sigma_min, sigma_max = config.model.sigma_min, config.model.sigma_max
|
||||
N = config.model.num_scales
|
||||
|
||||
N = 2
|
||||
sigma_min = 0.01
|
||||
sigma_max = 1348
|
||||
sampling_eps = 1e-5
|
||||
|
||||
batch_size = 1 #@param {"type":"integer"}
|
||||
config.training.batch_size = batch_size
|
||||
config.eval.batch_size = batch_size
|
||||
batch_size = 1
|
||||
centered = False
|
||||
|
||||
from diffusers import NCSNpp
|
||||
model = NCSNpp(config).to(config.device)
|
||||
|
||||
model = NCSNpp.from_pretrained("/home/patrick/ffhq_ncsnpp").to(device)
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
loaded_state = torch.load("../score_sde_pytorch/ffhq_1024_ncsnpp_continuous_ema.pt")
|
||||
del loaded_state["module.sigmas"]
|
||||
model.load_state_dict(loaded_state, strict=False)
|
||||
|
||||
def get_data_inverse_scaler(config):
|
||||
"""Inverse data normalizer."""
|
||||
if config.data.centered:
|
||||
# Rescale [-1, 1] to [0, 1]
|
||||
return lambda x: (x + 1.) / 2.
|
||||
else:
|
||||
return lambda x: x
|
||||
|
||||
inverse_scaler = get_data_inverse_scaler(config)
|
||||
|
||||
img_size = config.data.image_size
|
||||
channels = config.data.num_channels
|
||||
img_size = model.module.config.image_size
|
||||
channels = model.module.config.num_channels
|
||||
shape = (batch_size, channels, img_size, img_size)
|
||||
probability_flow = False
|
||||
snr = 0.15 #@param {"type": "number"}
|
||||
n_steps = 1#@param {"type": "integer"}
|
||||
snr = 0.15
|
||||
n_steps = 1
|
||||
|
||||
|
||||
device = config.device
|
||||
|
||||
new_corrector = NewLangevinCorrector(score_fn=model, snr=snr, n_steps=n_steps, sigma_min=sigma_min, sigma_max=sigma_max)
|
||||
new_predictor = NewReverseDiffusionPredictor(score_fn=model, sigma_min=sigma_min, sigma_max=sigma_max, N=N)
|
||||
|
||||
|
@ -238,10 +126,12 @@ with torch.no_grad():
|
|||
x, x_mean = new_corrector.update_fn(x, vec_t)
|
||||
x, x_mean = new_predictor.update_fn(x, vec_t)
|
||||
|
||||
x = inverse_scaler(x_mean)
|
||||
x = x_mean
|
||||
if centered:
|
||||
x = (x + 1.) / 2.
|
||||
|
||||
|
||||
save_image(x)
|
||||
# save_image(x)
|
||||
|
||||
# for 5 cifar10
|
||||
x_sum = 106071.9922
|
||||
|
@ -260,4 +150,4 @@ def check_x_sum_x_mean(x, x_sum, x_mean):
|
|||
assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
|
||||
|
||||
|
||||
#check_x_sum_x_mean(x, x_sum, x_mean)
|
||||
check_x_sum_x_mean(x, x_sum, x_mean)
|
||||
|
|
|
@ -15,6 +15,9 @@
|
|||
|
||||
# helpers functions
|
||||
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..configuration_utils import ConfigMixin
|
||||
|
||||
|
||||
import functools
|
||||
import math
|
||||
|
@ -372,16 +375,16 @@ class NIN(nn.Module):
|
|||
return y.permute(0, 3, 1, 2)
|
||||
|
||||
|
||||
def get_act(config):
|
||||
def get_act(nonlinearity):
|
||||
"""Get activation functions from the config file."""
|
||||
|
||||
if config.model.nonlinearity.lower() == "elu":
|
||||
if nonlinearity.lower() == "elu":
|
||||
return nn.ELU()
|
||||
elif config.model.nonlinearity.lower() == "relu":
|
||||
elif nonlinearity.lower() == "relu":
|
||||
return nn.ReLU()
|
||||
elif config.model.nonlinearity.lower() == "lrelu":
|
||||
elif nonlinearity.lower() == "lrelu":
|
||||
return nn.LeakyReLU(negative_slope=0.2)
|
||||
elif config.model.nonlinearity.lower() == "swish":
|
||||
elif nonlinearity.lower() == "swish":
|
||||
return nn.SiLU()
|
||||
else:
|
||||
raise NotImplementedError("activation function does not exist!")
|
||||
|
@ -710,46 +713,93 @@ class ResnetBlockBigGANpp(nn.Module):
|
|||
return (x + h) / np.sqrt(2.0)
|
||||
|
||||
|
||||
class NCSNpp(nn.Module):
|
||||
class NCSNpp(ModelMixin, ConfigMixin):
|
||||
"""NCSN++ model"""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(
|
||||
self,
|
||||
centered=False,
|
||||
image_size=1024,
|
||||
num_channels=3,
|
||||
attention_type="ddpm",
|
||||
attn_resolutions=(16,),
|
||||
ch_mult=(1, 2, 4, 8, 16, 32, 32, 32),
|
||||
conditional=True,
|
||||
conv_size=3,
|
||||
dropout=0.0,
|
||||
embedding_type="fourier",
|
||||
fir=True,
|
||||
fir_kernel=(1, 3, 3, 1),
|
||||
fourier_scale=16,
|
||||
init_scale=0.0,
|
||||
nf=16,
|
||||
nonlinearity="swish",
|
||||
normalization="GroupNorm",
|
||||
num_res_blocks=1,
|
||||
progressive="output_skip",
|
||||
progressive_combine="sum",
|
||||
progressive_input="input_skip",
|
||||
resamp_with_conv=True,
|
||||
resblock_type="biggan",
|
||||
scale_by_sigma=True,
|
||||
skip_rescale=True,
|
||||
continuous=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.act = act = get_act(config)
|
||||
self.register_to_config(
|
||||
centered=centered,
|
||||
image_size=image_size,
|
||||
num_channels=num_channels,
|
||||
attention_type=attention_type,
|
||||
attn_resolutions=attn_resolutions,
|
||||
ch_mult=ch_mult,
|
||||
conditional=conditional,
|
||||
conv_size=conv_size,
|
||||
dropout=dropout,
|
||||
embedding_type=embedding_type,
|
||||
fir=fir,
|
||||
fir_kernel=fir_kernel,
|
||||
fourier_scale=fourier_scale,
|
||||
init_scale=init_scale,
|
||||
nf=nf,
|
||||
nonlinearity=nonlinearity,
|
||||
normalization=normalization,
|
||||
num_res_blocks=num_res_blocks,
|
||||
progressive=progressive,
|
||||
progressive_combine=progressive_combine,
|
||||
progressive_input=progressive_input,
|
||||
resamp_with_conv=resamp_with_conv,
|
||||
resblock_type=resblock_type,
|
||||
scale_by_sigma=scale_by_sigma,
|
||||
skip_rescale=skip_rescale,
|
||||
continuous=continuous,
|
||||
)
|
||||
self.act = act = get_act(nonlinearity)
|
||||
# self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config)))
|
||||
|
||||
self.nf = nf = config.model.nf
|
||||
ch_mult = config.model.ch_mult
|
||||
self.num_res_blocks = num_res_blocks = config.model.num_res_blocks
|
||||
self.attn_resolutions = attn_resolutions = config.model.attn_resolutions
|
||||
dropout = config.model.dropout
|
||||
resamp_with_conv = config.model.resamp_with_conv
|
||||
self.num_resolutions = num_resolutions = len(ch_mult)
|
||||
self.all_resolutions = all_resolutions = [config.data.image_size // (2**i) for i in range(num_resolutions)]
|
||||
self.nf = nf
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_resolutions = attn_resolutions
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.all_resolutions = all_resolutions = [image_size // (2**i) for i in range(self.num_resolutions)]
|
||||
|
||||
self.conditional = conditional = config.model.conditional # noise-conditional
|
||||
fir = config.model.fir
|
||||
fir_kernel = config.model.fir_kernel
|
||||
self.skip_rescale = skip_rescale = config.model.skip_rescale
|
||||
self.resblock_type = resblock_type = config.model.resblock_type.lower()
|
||||
self.progressive = progressive = config.model.progressive.lower()
|
||||
self.progressive_input = progressive_input = config.model.progressive_input.lower()
|
||||
self.embedding_type = embedding_type = config.model.embedding_type.lower()
|
||||
init_scale = config.model.init_scale
|
||||
self.conditional = conditional
|
||||
self.skip_rescale = skip_rescale
|
||||
self.resblock_type = resblock_type
|
||||
self.progressive = progressive
|
||||
self.progressive_input = progressive_input
|
||||
self.embedding_type = embedding_type
|
||||
assert progressive in ["none", "output_skip", "residual"]
|
||||
assert progressive_input in ["none", "input_skip", "residual"]
|
||||
assert embedding_type in ["fourier", "positional"]
|
||||
combine_method = config.model.progressive_combine.lower()
|
||||
combine_method = progressive_combine.lower()
|
||||
combiner = functools.partial(Combine, method=combine_method)
|
||||
|
||||
modules = []
|
||||
# timestep/noise_level embedding; only for continuous training
|
||||
if embedding_type == "fourier":
|
||||
# Gaussian Fourier features embeddings.
|
||||
assert config.training.continuous, "Fourier features are only used for continuous training."
|
||||
|
||||
modules.append(GaussianFourierProjection(embedding_size=nf, scale=config.model.fourier_scale))
|
||||
modules.append(GaussianFourierProjection(embedding_size=nf, scale=fourier_scale))
|
||||
embed_dim = 2 * nf
|
||||
|
||||
elif embedding_type == "positional":
|
||||
|
@ -809,7 +859,7 @@ class NCSNpp(nn.Module):
|
|||
|
||||
# Downsampling block
|
||||
|
||||
channels = config.data.num_channels
|
||||
channels = num_channels
|
||||
if progressive_input != "none":
|
||||
input_pyramid_ch = channels
|
||||
|
||||
|
@ -817,7 +867,7 @@ class NCSNpp(nn.Module):
|
|||
hs_c = [nf]
|
||||
|
||||
in_ch = nf
|
||||
for i_level in range(num_resolutions):
|
||||
for i_level in range(self.num_resolutions):
|
||||
# Residual blocks for this resolution
|
||||
for i_block in range(num_res_blocks):
|
||||
out_ch = nf * ch_mult[i_level]
|
||||
|
@ -828,7 +878,7 @@ class NCSNpp(nn.Module):
|
|||
modules.append(AttnBlock(channels=in_ch))
|
||||
hs_c.append(in_ch)
|
||||
|
||||
if i_level != num_resolutions - 1:
|
||||
if i_level != self.num_resolutions - 1:
|
||||
if resblock_type == "ddpm":
|
||||
modules.append(Downsample(in_ch=in_ch))
|
||||
else:
|
||||
|
@ -852,7 +902,7 @@ class NCSNpp(nn.Module):
|
|||
|
||||
pyramid_ch = 0
|
||||
# Upsampling block
|
||||
for i_level in reversed(range(num_resolutions)):
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(num_res_blocks + 1):
|
||||
out_ch = nf * ch_mult[i_level]
|
||||
modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
|
||||
|
@ -862,7 +912,7 @@ class NCSNpp(nn.Module):
|
|||
modules.append(AttnBlock(channels=in_ch))
|
||||
|
||||
if progressive != "none":
|
||||
if i_level == num_resolutions - 1:
|
||||
if i_level == self.num_resolutions - 1:
|
||||
if progressive == "output_skip":
|
||||
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
|
||||
modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
|
||||
|
@ -899,7 +949,6 @@ class NCSNpp(nn.Module):
|
|||
self.all_modules = nn.ModuleList(modules)
|
||||
|
||||
def forward(self, x, time_cond):
|
||||
# import ipdb; ipdb.set_trace()
|
||||
# timestep/noise_level embedding; only for continuous training
|
||||
modules = self.all_modules
|
||||
m_idx = 0
|
||||
|
@ -926,7 +975,7 @@ class NCSNpp(nn.Module):
|
|||
else:
|
||||
temb = None
|
||||
|
||||
if not self.config.data.centered:
|
||||
if not self.config.centered:
|
||||
# If input data is in [0, 1]
|
||||
x = 2 * x - 1.0
|
||||
|
||||
|
@ -1044,7 +1093,7 @@ class NCSNpp(nn.Module):
|
|||
m_idx += 1
|
||||
|
||||
assert m_idx == len(modules)
|
||||
if self.config.model.scale_by_sigma:
|
||||
if self.config.scale_by_sigma:
|
||||
used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
|
||||
h = h / used_sigmas
|
||||
|
||||
|
|
Loading…
Reference in New Issue