upload cleaner scripts

Patrick von Platen 2022-06-09 18:18:18 +02:00
parent 8d6494439c
commit 97226d97d4
5 changed files with 33 additions and 38 deletions

View File

@@ -64,8 +64,8 @@ class DDIM(DiffusionPipeline):
         # 3. compute alphas, betas
         alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
         alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step)
-        beta_prod_t = (1 - alpha_prod_t)
-        beta_prod_t_prev = (1 - alpha_prod_t_prev)
+        beta_prod_t = 1 - alpha_prod_t
+        beta_prod_t_prev = 1 - alpha_prod_t_prev
         # 4. Compute predicted previous image from predicted noise
         # First: compute predicted original image from predicted noise also called
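
Aside: `get_alpha_prod(t)` above returns the cumulative product alpha_bar_t = prod_{s<=t} alpha_s of the noise schedule. As a rough standalone sketch of how the step-3 quantities are typically derived (hypothetical names and a linear beta schedule, not the actual scheduler API):

import torch

# Minimal sketch of the step-3 quantities; names mirror the diff,
# but this is an illustration, not the diffusers implementation.
num_train_timesteps = 1000
betas = torch.linspace(1e-4, 0.02, num_train_timesteps)  # linear schedule
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)  # alpha_bar_t

def get_alpha_prod(t):
    # alpha_bar before step 0 is defined as 1.0 (fully clean image)
    return torch.tensor(1.0) if t < 0 else alphas_cumprod[t]

train_step, prev_train_step = 500, 490
alpha_prod_t = get_alpha_prod(train_step)
alpha_prod_t_prev = get_alpha_prod(prev_train_step)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev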

View File

@@ -46,8 +46,8 @@ class DDPM(DiffusionPipeline):
         # 2. compute alphas, betas
         alpha_prod_t = self.noise_scheduler.get_alpha_prod(t)
         alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(t - 1)
-        beta_prod_t = (1 - alpha_prod_t)
-        beta_prod_t_prev = (1 - alpha_prod_t_prev)
+        beta_prod_t = 1 - alpha_prod_t
+        beta_prod_t_prev = 1 - alpha_prod_t_prev
         # 3. compute predicted image from residual
         # First: compute predicted original image from predicted noise also called
@@ -69,12 +69,14 @@ class DDPM(DiffusionPipeline):
         # and sample from it to get previous image
         # x_{t-1} ~ N(pred_prev_image, variance) == add variance to pred_image
         if t > 0:
-            variance = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.noise_scheduler.get_beta(t)).sqrt()
+            variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.noise_scheduler.get_beta(t).sqrt()
+            # TODO(PVP):
+            # This variance seems to be good enough for inference - check if those `fix_small`, `fix_large`
+            # are really only needed for training or also for inference
+            # Also note LDM only uses "fixed_small";
+            # glide seems to use a weird mix of the two: https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
             noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
             sampled_variance = variance * noise
-            # sampled_variance = self.noise_scheduler.sample_variance(
-            #     t, pred_prev_image.shape, device=torch_device, generator=generator
-            # )
             prev_image = pred_prev_image + sampled_variance
         else:
             prev_image = pred_prev_image
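
The quantity sampled here is based on the DDPM posterior variance from Ho et al. (2020): beta_tilde_t = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t, whose square root is the standard deviation of x_{t-1}. Note that in the added line the root binds to `get_beta(t)` alone, while in the removed line (and in the textbook formula) it spans the whole product. A minimal sketch of the step under the textbook formula, with illustrative names only:

import torch

def sample_prev_image(pred_prev_image, alpha_prod_t, alpha_prod_t_prev, beta_t, t, generator=None):
    # x_{t-1} ~ N(pred_prev_image, beta_tilde_t); no noise is added at the final step
    if t == 0:
        return pred_prev_image
    beta_tilde = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * beta_t
    std = beta_tilde.sqrt()  # root over the full product ("fixed_small")
    noise = torch.randn(pred_prev_image.shape, generator=generator)
    return pred_prev_image + std * noise

# example usage with a toy schedule
betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
x = sample_prev_image(torch.zeros(1, 3, 8, 8), alphas_cumprod[10], alphas_cumprod[9], betas[10], t=10)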

View File

@@ -2,10 +2,10 @@
 import math

 import numpy as np
-import tqdm
 import torch
 import torch.nn as nn
+import tqdm

 from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.modeling_utils import ModelMixin
@@ -740,29 +740,30 @@ class DiagonalGaussianDistribution(object):
     def kl(self, other=None):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         else:
             if other is None:
-                return 0.5 * torch.sum(torch.pow(self.mean, 2)
-                                       + self.var - 1.0 - self.logvar,
-                                       dim=[1, 2, 3])
+                return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
             else:
                 return 0.5 * torch.sum(
                     torch.pow(self.mean - other.mean, 2) / other.var
-                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
-                    dim=[1, 2, 3])
+                    + self.var / other.var
+                    - 1.0
+                    - self.logvar
+                    + other.logvar,
+                    dim=[1, 2, 3],
+                )

-    def nll(self, sample, dims=[1,2,3]):
+    def nll(self, sample, dims=[1, 2, 3]):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         logtwopi = np.log(2.0 * np.pi)
-        return 0.5 * torch.sum(
-            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
-            dim=dims)
+        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)

     def mode(self):
         return self.mean

+
 class AutoencoderKL(ModelMixin, ConfigMixin):
     def __init__(
         self,
@@ -836,7 +837,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
             give_pre_end=give_pre_end,
         )
-        self.quant_conv = torch.nn.Conv2d(2*z_channels, 2*embed_dim, 1)
+        self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
         self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)

     def encode(self, x):
@@ -874,11 +875,11 @@ class LatentDiffusion(DiffusionPipeline):
         self.unet.to(torch_device)
         self.vqvae.to(torch_device)
         self.bert.to(torch_device)

         # get text embedding
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
         text_embedding = self.bert(**text_input)[0]

         num_trained_timesteps = self.noise_scheduler.num_timesteps
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
@@ -927,9 +928,9 @@ class LatentDiffusion(DiffusionPipeline):
                 image = pred_mean + coeff_1 * noise
             else:
                 image = pred_mean

         image = 1 / image
         image = self.vqvae(image)
-        image = torch.clamp((image+1.0)/2.0, min=0.0, max=1.0)
+        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)

         return image
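
For reference, the reformatted `kl` method above implements the closed form KL(N(m1, v1) || N(m2, v2)) = 0.5 * sum((m1 - m2)^2 / v2 + v1 / v2 - 1 - log v1 + log v2) over the non-batch dimensions. A self-contained sanity check of that formula (a re-implementation for illustration, not the library class):

import torch

def kl_diag_gaussian(mean1, logvar1, mean2, logvar2):
    # Closed-form KL between two diagonal Gaussians, summed over dims 1..3
    var1, var2 = logvar1.exp(), logvar2.exp()
    return 0.5 * torch.sum(
        torch.pow(mean1 - mean2, 2) / var2 + var1 / var2 - 1.0 - logvar1 + logvar2,
        dim=[1, 2, 3],
    )

mean = torch.zeros(2, 3, 4, 4)
logvar = torch.zeros(2, 3, 4, 4)
# KL of a distribution against itself is zero
assert torch.allclose(kl_diag_gaussian(mean, logvar, mean, logvar), torch.zeros(2))
# The other=None branch is the special case against the standard normal N(0, I):
# 0.5 * sum(mean^2 + var - 1 - logvar)
kl_to_std = 0.5 * torch.sum(torch.pow(mean, 2) + logvar.exp() - 1.0 - logvar, dim=[1, 2, 3])
assert torch.allclose(kl_to_std, torch.zeros(2))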

View File

@@ -113,7 +113,7 @@ class DiffusionPipeline(ConfigMixin):
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
         use_auth_token = kwargs.pop("use_auth_token", None)

         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
         if not os.path.isdir(pretrained_model_name_or_path):
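
The hunk above only pops the standard Hub keyword arguments; the comment describes the pattern of using a local directory as-is and otherwise fetching a full repo snapshot. A rough sketch of that pattern, assuming `huggingface_hub.snapshot_download` with the keyword names of that era (the actual call site is outside this hunk):

import os
from huggingface_hub import snapshot_download

def resolve_checkpoint_dir(pretrained_model_name_or_path, cache_dir=None, proxies=None,
                           local_files_only=False, use_auth_token=None):
    # Local directories are used directly; Hub repo ids are snapshot-downloaded
    # so that all configs and weights end up on disk before loading.
    if os.path.isdir(pretrained_model_name_or_path):
        return pretrained_model_name_or_path
    return snapshot_download(
        pretrained_model_name_or_path,
        cache_dir=cache_dir,
        proxies=proxies,
        local_files_only=local_files_only,
        use_auth_token=use_auth_token,
    )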

View File

@@ -62,6 +62,9 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
         variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

+        # TODO(PVP) - check how much of these is actually necessary!
+        # LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
+        # https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
         if variance_type == "fixed_small":
             log_variance = torch.log(variance.clamp(min=1e-20))
         elif variance_type == "fixed_large":
@@ -84,17 +87,6 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
             return torch.tensor(1.0)
         return self.alphas_cumprod[time_step]

-    def sample_variance(self, time_step, shape, device, generator=None):
-        variance = self.log_variance[time_step]
-        nonzero_mask = torch.tensor([1 - (time_step == 0)], device=device).float()[None, :]
-        noise = self.sample_noise(shape, device=device, generator=generator)
-        sampled_variance = nonzero_mask * (0.5 * variance).exp()
-        sampled_variance = sampled_variance * noise
-        return sampled_variance
-
     def sample_noise(self, shape, device, generator=None):
         # always sample on CPU to be deterministic
         return torch.randn(shape, generator=generator).to(device)
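
`sample_noise` deliberately samples on CPU and only then moves the tensor, so a seeded generator yields bit-identical noise whether inference later runs on CPU or GPU. A small sketch of the property this buys (illustrative, mirroring `sample_noise` above):

import torch

def sample_noise(shape, device, seed=0):
    # CPU generator: the random stream does not depend on the target device
    generator = torch.Generator().manual_seed(seed)
    return torch.randn(shape, generator=generator).to(device)

noise_a = sample_noise((2, 3), "cpu")
noise_b = sample_noise((2, 3), "cpu")
assert torch.equal(noise_a, noise_b)  # same seed, same noise
# On a CUDA machine, sample_noise((2, 3), "cuda").cpu() equals noise_a as well,
# which would not hold if torch.randn sampled directly on the GPU.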