From bc2d586dcbba9429f4b0d9600a559fff18f599b6 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 25 Jun 2022 00:53:55 +0000 Subject: [PATCH] remove more dependencies --- run.py | 146 +++--------------- .../models/unet_sde_score_estimation.py | 125 ++++++++++----- 2 files changed, 105 insertions(+), 166 deletions(-) diff --git a/run.py b/run.py index 0180c348..cae97139 100755 --- a/run.py +++ b/run.py @@ -2,105 +2,14 @@ import numpy as np import PIL import torch -import ml_collections #from configs.ve import ffhq_ncsnpp_continuous as configs # from configs.ve import cifar10_ncsnpp_continuous as configs -# ffhq_ncsnpp_continuous config -def get_config(): - config = ml_collections.ConfigDict() - # training - config.training = training = ml_collections.ConfigDict() - training.batch_size = 8 - training.n_iters = 2400001 - training.snapshot_freq = 50000 - training.log_freq = 50 - training.eval_freq = 100 - training.snapshot_freq_for_preemption = 5000 - training.snapshot_sampling = True - training.sde = 'vesde' - training.continuous = True - training.likelihood_weighting = False - training.reduce_mean = True - - # sampling - config.sampling = sampling = ml_collections.ConfigDict() - sampling.method = 'pc' - sampling.predictor = 'reverse_diffusion' - sampling.corrector = 'langevin' - sampling.probability_flow = False - sampling.snr = 0.15 - sampling.n_steps_each = 1 - sampling.noise_removal = True - - # eval - config.eval = evaluate = ml_collections.ConfigDict() - evaluate.batch_size = 1024 - evaluate.num_samples = 50000 - evaluate.begin_ckpt = 1 - evaluate.end_ckpt = 96 - - # data - config.data = data = ml_collections.ConfigDict() - data.dataset = 'FFHQ' - data.image_size = 1024 - data.centered = False - data.random_flip = True - data.uniform_dequantization = False - data.num_channels = 3 - # Plug in your own path to the tfrecords file. - data.tfrecords_path = '/raid/song/ffhq-dataset/ffhq/ffhq-r10.tfrecords' - - # model - config.model = model = ml_collections.ConfigDict() - model.name = 'ncsnpp' - model.scale_by_sigma = True - model.sigma_max = 1348 - model.num_scales = 2000 - model.ema_rate = 0.9999 - model.sigma_min = 0.01 - model.normalization = 'GroupNorm' - model.nonlinearity = 'swish' - model.nf = 16 - model.ch_mult = (1, 2, 4, 8, 16, 32, 32, 32) - model.num_res_blocks = 1 - model.attn_resolutions = (16,) - model.dropout = 0. - model.resamp_with_conv = True - model.conditional = True - model.fir = True - model.fir_kernel = [1, 3, 3, 1] - model.skip_rescale = True - model.resblock_type = 'biggan' - model.progressive = 'output_skip' - model.progressive_input = 'input_skip' - model.progressive_combine = 'sum' - model.attention_type = 'ddpm' - model.init_scale = 0. - model.fourier_scale = 16 - model.conv_size = 3 - model.embedding_type = 'fourier' - - # optim - config.optim = optim = ml_collections.ConfigDict() - optim.weight_decay = 0 - optim.optimizer = 'Adam' - optim.lr = 2e-4 - optim.beta1 = 0.9 - optim.amsgrad = False - optim.eps = 1e-8 - optim.warmup = 5000 - optim.grad_clip = 1. - - config.seed = 42 - config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') - - return config - +device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') torch.backends.cuda.matmul.allow_tf32 = False -torch.manual_seed(3) +torch.manual_seed(0) class NewReverseDiffusionPredictor: @@ -182,47 +91,26 @@ def save_image(x): # Note usually we need to restore ema etc... 
# ema restored checkpoint used from below - - -config = get_config() - -sigma_min, sigma_max = config.model.sigma_min, config.model.sigma_max -N = config.model.num_scales - +N = 2 +sigma_min = 0.01 +sigma_max = 1348 sampling_eps = 1e-5 - -batch_size = 1 #@param {"type":"integer"} -config.training.batch_size = batch_size -config.eval.batch_size = batch_size +batch_size = 1 +centered = False from diffusers import NCSNpp -model = NCSNpp(config).to(config.device) + +model = NCSNpp.from_pretrained("/home/patrick/ffhq_ncsnpp").to(device) model = torch.nn.DataParallel(model) -loaded_state = torch.load("../score_sde_pytorch/ffhq_1024_ncsnpp_continuous_ema.pt") -del loaded_state["module.sigmas"] -model.load_state_dict(loaded_state, strict=False) - -def get_data_inverse_scaler(config): - """Inverse data normalizer.""" - if config.data.centered: - # Rescale [-1, 1] to [0, 1] - return lambda x: (x + 1.) / 2. - else: - return lambda x: x - -inverse_scaler = get_data_inverse_scaler(config) - -img_size = config.data.image_size -channels = config.data.num_channels +img_size = model.module.config.image_size +channels = model.module.config.num_channels shape = (batch_size, channels, img_size, img_size) probability_flow = False -snr = 0.15 #@param {"type": "number"} -n_steps = 1#@param {"type": "integer"} +snr = 0.15 +n_steps = 1 -device = config.device - new_corrector = NewLangevinCorrector(score_fn=model, snr=snr, n_steps=n_steps, sigma_min=sigma_min, sigma_max=sigma_max) new_predictor = NewReverseDiffusionPredictor(score_fn=model, sigma_min=sigma_min, sigma_max=sigma_max, N=N) @@ -238,10 +126,12 @@ with torch.no_grad(): x, x_mean = new_corrector.update_fn(x, vec_t) x, x_mean = new_predictor.update_fn(x, vec_t) - x = inverse_scaler(x_mean) + x = x_mean + if centered: + x = (x + 1.) / 2. 
-save_image(x) +# save_image(x) # for 5 cifar10 x_sum = 106071.9922 @@ -260,4 +150,4 @@ def check_x_sum_x_mean(x, x_sum, x_mean): assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" -#check_x_sum_x_mean(x, x_sum, x_mean) +check_x_sum_x_mean(x, x_sum, x_mean) diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 26b4419e..30671ef2 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -15,6 +15,9 @@ # helpers functions +from ..modeling_utils import ModelMixin +from ..configuration_utils import ConfigMixin + import functools import math @@ -372,16 +375,16 @@ class NIN(nn.Module): return y.permute(0, 3, 1, 2) -def get_act(config): +def get_act(nonlinearity): """Get activation functions from the config file.""" - if config.model.nonlinearity.lower() == "elu": + if nonlinearity.lower() == "elu": return nn.ELU() - elif config.model.nonlinearity.lower() == "relu": + elif nonlinearity.lower() == "relu": return nn.ReLU() - elif config.model.nonlinearity.lower() == "lrelu": + elif nonlinearity.lower() == "lrelu": return nn.LeakyReLU(negative_slope=0.2) - elif config.model.nonlinearity.lower() == "swish": + elif nonlinearity.lower() == "swish": return nn.SiLU() else: raise NotImplementedError("activation function does not exist!") @@ -710,46 +713,93 @@ class ResnetBlockBigGANpp(nn.Module): return (x + h) / np.sqrt(2.0) -class NCSNpp(nn.Module): +class NCSNpp(ModelMixin, ConfigMixin): """NCSN++ model""" - def __init__(self, config): + def __init__( + self, + centered=False, + image_size=1024, + num_channels=3, + attention_type="ddpm", + attn_resolutions=(16,), + ch_mult=(1, 2, 4, 8, 16, 32, 32, 32), + conditional=True, + conv_size=3, + dropout=0.0, + embedding_type="fourier", + fir=True, + fir_kernel=(1, 3, 3, 1), + fourier_scale=16, + init_scale=0.0, + nf=16, + nonlinearity="swish", + normalization="GroupNorm", + num_res_blocks=1, + progressive="output_skip", + progressive_combine="sum", + progressive_input="input_skip", + resamp_with_conv=True, + resblock_type="biggan", + scale_by_sigma=True, + skip_rescale=True, + continuous=True, + ): super().__init__() - self.config = config - self.act = act = get_act(config) + self.register_to_config( + centered=centered, + image_size=image_size, + num_channels=num_channels, + attention_type=attention_type, + attn_resolutions=attn_resolutions, + ch_mult=ch_mult, + conditional=conditional, + conv_size=conv_size, + dropout=dropout, + embedding_type=embedding_type, + fir=fir, + fir_kernel=fir_kernel, + fourier_scale=fourier_scale, + init_scale=init_scale, + nf=nf, + nonlinearity=nonlinearity, + normalization=normalization, + num_res_blocks=num_res_blocks, + progressive=progressive, + progressive_combine=progressive_combine, + progressive_input=progressive_input, + resamp_with_conv=resamp_with_conv, + resblock_type=resblock_type, + scale_by_sigma=scale_by_sigma, + skip_rescale=skip_rescale, + continuous=continuous, + ) + self.act = act = get_act(nonlinearity) # self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config))) - self.nf = nf = config.model.nf - ch_mult = config.model.ch_mult - self.num_res_blocks = num_res_blocks = config.model.num_res_blocks - self.attn_resolutions = attn_resolutions = config.model.attn_resolutions - dropout = config.model.dropout - resamp_with_conv = config.model.resamp_with_conv - self.num_resolutions = num_resolutions = len(ch_mult) - self.all_resolutions 
= all_resolutions = [config.data.image_size // (2**i) for i in range(num_resolutions)] + self.nf = nf + self.num_res_blocks = num_res_blocks + self.attn_resolutions = attn_resolutions + self.num_resolutions = len(ch_mult) + self.all_resolutions = all_resolutions = [image_size // (2**i) for i in range(self.num_resolutions)] - self.conditional = conditional = config.model.conditional # noise-conditional - fir = config.model.fir - fir_kernel = config.model.fir_kernel - self.skip_rescale = skip_rescale = config.model.skip_rescale - self.resblock_type = resblock_type = config.model.resblock_type.lower() - self.progressive = progressive = config.model.progressive.lower() - self.progressive_input = progressive_input = config.model.progressive_input.lower() - self.embedding_type = embedding_type = config.model.embedding_type.lower() - init_scale = config.model.init_scale + self.conditional = conditional + self.skip_rescale = skip_rescale + self.resblock_type = resblock_type + self.progressive = progressive + self.progressive_input = progressive_input + self.embedding_type = embedding_type assert progressive in ["none", "output_skip", "residual"] assert progressive_input in ["none", "input_skip", "residual"] assert embedding_type in ["fourier", "positional"] - combine_method = config.model.progressive_combine.lower() + combine_method = progressive_combine.lower() combiner = functools.partial(Combine, method=combine_method) modules = [] # timestep/noise_level embedding; only for continuous training if embedding_type == "fourier": # Gaussian Fourier features embeddings. - assert config.training.continuous, "Fourier features are only used for continuous training." - - modules.append(GaussianFourierProjection(embedding_size=nf, scale=config.model.fourier_scale)) + modules.append(GaussianFourierProjection(embedding_size=nf, scale=fourier_scale)) embed_dim = 2 * nf elif embedding_type == "positional": @@ -809,7 +859,7 @@ class NCSNpp(nn.Module): # Downsampling block - channels = config.data.num_channels + channels = num_channels if progressive_input != "none": input_pyramid_ch = channels @@ -817,7 +867,7 @@ class NCSNpp(nn.Module): hs_c = [nf] in_ch = nf - for i_level in range(num_resolutions): + for i_level in range(self.num_resolutions): # Residual blocks for this resolution for i_block in range(num_res_blocks): out_ch = nf * ch_mult[i_level] @@ -828,7 +878,7 @@ class NCSNpp(nn.Module): modules.append(AttnBlock(channels=in_ch)) hs_c.append(in_ch) - if i_level != num_resolutions - 1: + if i_level != self.num_resolutions - 1: if resblock_type == "ddpm": modules.append(Downsample(in_ch=in_ch)) else: @@ -852,7 +902,7 @@ class NCSNpp(nn.Module): pyramid_ch = 0 # Upsampling block - for i_level in reversed(range(num_resolutions)): + for i_level in reversed(range(self.num_resolutions)): for i_block in range(num_res_blocks + 1): out_ch = nf * ch_mult[i_level] modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch)) @@ -862,7 +912,7 @@ class NCSNpp(nn.Module): modules.append(AttnBlock(channels=in_ch)) if progressive != "none": - if i_level == num_resolutions - 1: + if i_level == self.num_resolutions - 1: if progressive == "output_skip": modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) modules.append(conv3x3(in_ch, channels, init_scale=init_scale)) @@ -899,7 +949,6 @@ class NCSNpp(nn.Module): self.all_modules = nn.ModuleList(modules) def forward(self, x, time_cond): - # import ipdb; ipdb.set_trace() # timestep/noise_level embedding; only for continuous 
training modules = self.all_modules m_idx = 0 @@ -926,7 +975,7 @@ class NCSNpp(nn.Module): else: temb = None - if not self.config.data.centered: + if not self.config.centered: # If input data is in [0, 1] x = 2 * x - 1.0 @@ -1044,7 +1093,7 @@ class NCSNpp(nn.Module): m_idx += 1 assert m_idx == len(modules) - if self.config.model.scale_by_sigma: + if self.config.scale_by_sigma: used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:])))) h = h / used_sigmas
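
A note on the constants inlined into run.py above: N, sigma_min and sigma_max replace the values previously read from config.model.*, and together they define the geometric VE-SDE noise schedule that the predictor/corrector steps walk down. The sketch below only illustrates that schedule; it is not code from the patch, and the descending torch.linspace mirrors the sampling loop's direction under the assumption that sampling_eps = 1e-5 is the terminal time.

import torch

# Illustration only: the geometric sigma schedule implied by the inlined constants
# (sigma_min=0.01, sigma_max=1348 for FFHQ-1024; N=2 here, as in the script above).
sigma_min, sigma_max, N = 0.01, 1348.0, 2
sampling_eps = 1e-5

# The sampler iterates t from 1 down to sampling_eps; sigma(t) interpolates
# geometrically between sigma_min at t=0 and sigma_max at t=1.
timesteps = torch.linspace(1.0, sampling_eps, N)
sigmas = sigma_min * (sigma_max / sigma_min) ** timesteps
print(sigmas)  # per-step noise levels consumed by the predictor/corrector updates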
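
The second file's change is the core of the dependency removal: NCSNpp now inherits from ModelMixin and ConfigMixin, takes its hyper-parameters as explicit keyword arguments, and registers them via register_to_config. That is what lets run.py call NCSNpp.from_pretrained(...) and read model.module.config.image_size instead of carrying an ml_collections dict around. Below is a minimal usage sketch, assuming the ModelMixin save_pretrained/from_pretrained round-trip works as it does for other diffusers models; the local path is hypothetical.

from diffusers import NCSNpp

# Construct the model from explicit kwargs; every value that used to live under
# config.model.* / config.data.* is now a keyword argument with FFHQ-1024 defaults.
model = NCSNpp(image_size=1024, num_channels=3, nf=16, embedding_type="fourier")

# register_to_config stores these values on model.config, so they can be serialized
# next to the weights and restored without the original ml_collections config.
model.save_pretrained("./ffhq_ncsnpp")        # hypothetical local directory
restored = NCSNpp.from_pretrained("./ffhq_ncsnpp")
assert restored.config.image_size == 1024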