add LatentDiffusion pipeline

This commit is contained in:
patil-suraj 2022-06-09 16:40:20 +02:00
parent 4229101ea2
commit 302ac73b74
2 changed files with 216 additions and 0 deletions

View File

@ -2,9 +2,11 @@
import math
import numpy as np
import tqdm
import torch
import torch.nn as nn
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin
from diffusers.modeling_utils import ModelMixin
@ -719,3 +721,215 @@ class VQModel(ModelMixin, ConfigMixin):
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.])
else:
if other is None:
return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
def nll(self, sample, dims=[1,2,3]):
if self.deterministic:
return torch.Tensor([0.])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
def mode(self):
return self.mean
class AutoencoderKL(ModelMixin, ConfigMixin):
def __init__(
self,
ch,
out_ch,
num_res_blocks,
attn_resolutions,
in_channels,
resolution,
z_channels,
n_embed,
embed_dim,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
ch_mult=(1, 2, 4, 8),
dropout=0.0,
double_z=True,
resamp_with_conv=True,
give_pre_end=False,
):
super().__init__()
# register all __init__ params with self.register
self.register(
ch=ch,
out_ch=out_ch,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolutions,
in_channels=in_channels,
resolution=resolution,
z_channels=z_channels,
n_embed=n_embed,
embed_dim=embed_dim,
remap=remap,
sane_index_shape=sane_index_shape,
ch_mult=ch_mult,
dropout=dropout,
double_z=double_z,
resamp_with_conv=resamp_with_conv,
give_pre_end=give_pre_end,
)
# pass init params to Encoder
self.encoder = Encoder(
ch=ch,
out_ch=out_ch,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolutions,
in_channels=in_channels,
resolution=resolution,
z_channels=z_channels,
ch_mult=ch_mult,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
double_z=double_z,
give_pre_end=give_pre_end,
)
# pass init params to Decoder
self.decoder = Decoder(
ch=ch,
out_ch=out_ch,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolutions,
in_channels=in_channels,
resolution=resolution,
z_channels=z_channels,
ch_mult=ch_mult,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
give_pre_end=give_pre_end,
)
self.quant_conv = torch.nn.Conv2d(2*z_channels, 2*embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
def encode(self, x):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z):
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, input, sample_posterior=True):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
class LatentDiffusion(DiffusionPipeline):
def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler):
super().__init__()
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler)
def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50):
# eta corresponds to η in paper and should be between [0, 1]
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
self.unet.to(torch_device)
self.vqvae.to(torch_device)
self.bert.to(torch_device)
# get text embedding
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
text_embedding = self.bert(**text_input)[0]
num_trained_timesteps = self.noise_scheduler.num_timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
image = self.noise_scheduler.sample_noise(
(batch_size, self.unet.in_channels, self.unet.resolution // 8, self.unet.resolution // 8),
device=torch_device,
generator=generator,
)
for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
# get actual t and t-1
train_step = inference_step_times[t]
prev_train_step = inference_step_times[t - 1] if t > 0 else -1
# compute alphas
alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step)
alpha_prod_t_rsqrt = 1 / alpha_prod_t.sqrt()
alpha_prod_t_prev_rsqrt = 1 / alpha_prod_t_prev.sqrt()
beta_prod_t_sqrt = (1 - alpha_prod_t).sqrt()
beta_prod_t_prev_sqrt = (1 - alpha_prod_t_prev).sqrt()
# compute relevant coefficients
coeff_1 = (
(alpha_prod_t_prev - alpha_prod_t).sqrt()
* alpha_prod_t_prev_rsqrt
* beta_prod_t_prev_sqrt
/ beta_prod_t_sqrt
* eta
)
coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1**2).sqrt()
# model forward
with torch.no_grad():
train_step = torch.tensor([train_step] * image.shape[0], device=torch_device)
noise_residual = self.unet(image, train_step, context=text_embedding)
# predict mean of prev image
pred_mean = alpha_prod_t_rsqrt * (image - beta_prod_t_sqrt * noise_residual)
pred_mean = torch.clamp(pred_mean, -1, 1)
pred_mean = (1 / alpha_prod_t_prev_rsqrt) * pred_mean + coeff_2 * noise_residual
# if eta > 0.0 add noise. Note eta = 1.0 essentially corresponds to DDPM
if eta > 0.0:
noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
image = pred_mean + coeff_1 * noise
else:
image = pred_mean
image = 1 / image
image = self.vqvae(image)
image = torch.clamp((image+1.0)/2.0, min=0.0, max=1.0)
return image

View File

@ -1024,6 +1024,8 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)