add LatentDiffusion pipeline
This commit is contained in:
parent
4229101ea2
commit
302ac73b74
|
@ -2,9 +2,11 @@
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import tqdm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
from diffusers.configuration_utils import ConfigMixin
|
from diffusers.configuration_utils import ConfigMixin
|
||||||
from diffusers.modeling_utils import ModelMixin
|
from diffusers.modeling_utils import ModelMixin
|
||||||
|
|
||||||
|
@ -719,3 +721,215 @@ class VQModel(ModelMixin, ConfigMixin):
|
||||||
quant = self.post_quant_conv(quant)
|
quant = self.post_quant_conv(quant)
|
||||||
dec = self.decoder(quant)
|
dec = self.decoder(quant)
|
||||||
return dec
|
return dec
|
||||||
|
|
||||||
|
|
||||||
|
class DiagonalGaussianDistribution(object):
    """Gaussian distribution with diagonal covariance.

    The distribution is described by a single ``parameters`` tensor that is
    split into two equal halves along dim 1: the first half is the mean, the
    second half is the log-variance. The log-variance is clamped to
    ``[-30.0, 20.0]`` before the standard deviation and variance are derived.
    """

    def __init__(self, parameters, deterministic=False):
        """Split ``parameters`` into mean / log-variance and derive std and var.

        Args:
            parameters: tensor whose dim 1 holds the mean followed by the
                log-variance (it is chunked into two halves along dim 1).
            deterministic: when True, std and var are forced to zero, so
                ``sample()`` returns the mean and ``kl()`` / ``nll()`` return
                a single zero.
        """
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # zeros_like already matches the device and dtype of the mean.
            self.var = self.std = torch.zeros_like(self.mean)

    def _zero(self):
        """Return a one-element zero tensor on the parameters' device/dtype."""
        return torch.zeros(1, device=self.parameters.device, dtype=self.parameters.dtype)

    def sample(self):
        """Draw ``mean + std * eps`` with ``eps`` standard normal noise."""
        # Noise is drawn with the default (CPU) generator and then moved, so
        # the random stream does not depend on the target device.
        noise = torch.randn(self.mean.shape).to(device=self.parameters.device)
        return self.mean + self.std * noise

    def kl(self, other=None):
        """Return the KL divergence, summed over dims 1, 2 and 3.

        Args:
            other: another ``DiagonalGaussianDistribution`` to compare with.
                When None the divergence is taken against a standard normal.

        Returns:
            One value per batch element, or a single zero when the
            distribution is deterministic.
        """
        if self.deterministic:
            return self._zero()
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[1, 2, 3],
            )
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=[1, 2, 3],
        )

    def nll(self, sample, dims=(1, 2, 3)):
        """Return the negative log-likelihood of ``sample``, summed over ``dims``.

        Args:
            sample: tensor broadcastable against the mean.
            dims: dimensions to reduce over (immutable default).

        Returns:
            The reduced negative log-likelihood, or a single zero when the
            distribution is deterministic.
        """
        if self.deterministic:
            return self._zero()
        # Plain Python float keeps the result in the tensors' own dtype.
        logtwopi = float(np.log(2.0 * np.pi))
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        """Return the mode of the distribution, which is its mean."""
        return self.mean
|
||||||
|
|
||||||
|
class AutoencoderKL(ModelMixin, ConfigMixin):
    """Autoencoder whose encoder output parameterizes a diagonal Gaussian.

    ``encode`` maps an input to a ``DiagonalGaussianDistribution`` over the
    latent, ``decode`` maps a latent back through the decoder, and ``forward``
    chains the two.
    """

    def __init__(
        self,
        ch,
        out_ch,
        num_res_blocks,
        attn_resolutions,
        in_channels,
        resolution,
        z_channels,
        n_embed,
        embed_dim,
        remap=None,
        sane_index_shape=False,  # tell vector quantizer to return indices as bhw
        ch_mult=(1, 2, 4, 8),
        dropout=0.0,
        double_z=True,
        resamp_with_conv=True,
        give_pre_end=False,
    ):
        super().__init__()

        # Record every constructor argument through self.register.
        self.register(
            ch=ch,
            out_ch=out_ch,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            in_channels=in_channels,
            resolution=resolution,
            z_channels=z_channels,
            n_embed=n_embed,
            embed_dim=embed_dim,
            remap=remap,
            sane_index_shape=sane_index_shape,
            ch_mult=ch_mult,
            dropout=dropout,
            double_z=double_z,
            resamp_with_conv=resamp_with_conv,
            give_pre_end=give_pre_end,
        )

        # Keyword arguments that the encoder and the decoder both receive.
        common_kwargs = dict(
            ch=ch,
            out_ch=out_ch,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            in_channels=in_channels,
            resolution=resolution,
            z_channels=z_channels,
            ch_mult=ch_mult,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            give_pre_end=give_pre_end,
        )

        # Only the encoder additionally takes double_z.
        self.encoder = Encoder(double_z=double_z, **common_kwargs)
        self.decoder = Decoder(**common_kwargs)

        # 1x1 convolutions between the encoder/decoder and the latent space.
        self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)

    def encode(self, x):
        """Encode ``x`` and return the posterior as a DiagonalGaussianDistribution."""
        moments = self.quant_conv(self.encoder(x))
        return DiagonalGaussianDistribution(moments)

    def decode(self, z):
        """Decode latent ``z`` through ``post_quant_conv`` and the decoder."""
        return self.decoder(self.post_quant_conv(z))

    def forward(self, input, sample_posterior=True):
        """Encode ``input``, pick a latent, and decode it.

        Args:
            input: tensor passed to the encoder.
            sample_posterior: when True the latent is sampled from the
                posterior, otherwise the posterior mode is used.

        Returns:
            A ``(decoded, posterior)`` tuple.
        """
        posterior = self.encode(input)
        latent = posterior.sample() if sample_posterior else posterior.mode()
        return self.decode(latent), posterior
|
||||||
|
|
||||||
|
|
||||||
|
class LatentDiffusion(DiffusionPipeline):
    """Text-conditioned sampling pipeline that denoises in latent space.

    The prompt is tokenized and embedded with ``bert``, a latent is denoised
    step by step with ``unet`` using coefficients derived from
    ``noise_scheduler``, and the final latent is decoded with ``vqvae``.
    """

    def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler):
        """Register the five components with the pipeline."""
        super().__init__()
        self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler)

    def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50):
        """Generate images for ``prompt``.

        Args:
            prompt: text passed to the tokenizer.
            batch_size: number of latents to sample.
            generator: forwarded to ``noise_scheduler.sample_noise``.
            torch_device: device to run on; defaults to "cuda" when available,
                otherwise "cpu".
            eta: corresponds to η in the paper and should be between [0, 1];
                extra noise is only added when it is greater than 0.
            num_inference_steps: number of denoising iterations.

        Returns:
            The decoded images, rescaled from [-1, 1] to [0, 1] and clamped.
        """
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.unet.to(torch_device)
        self.vqvae.to(torch_device)
        self.bert.to(torch_device)

        # get text embedding
        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
        text_embedding = self.bert(**text_input)[0]

        # Evenly spaced subset of the training timesteps used for inference.
        num_trained_timesteps = self.noise_scheduler.num_timesteps
        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

        # NOTE(review): the latent spatial size is taken as unet.resolution // 8;
        # confirm this matches the attribute the UNet actually exposes.
        image = self.noise_scheduler.sample_noise(
            (batch_size, self.unet.in_channels, self.unet.resolution // 8, self.unet.resolution // 8),
            device=torch_device,
            generator=generator,
        )

        for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
            # get actual t and t-1
            train_step = inference_step_times[t]
            # -1 marks "before the first step"; presumably the scheduler maps
            # it to an alpha product of 1 -- verify against get_alpha_prod.
            prev_train_step = inference_step_times[t - 1] if t > 0 else -1

            # compute alphas
            alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
            alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step)
            alpha_prod_t_rsqrt = 1 / alpha_prod_t.sqrt()
            alpha_prod_t_prev_rsqrt = 1 / alpha_prod_t_prev.sqrt()
            beta_prod_t_sqrt = (1 - alpha_prod_t).sqrt()
            beta_prod_t_prev_sqrt = (1 - alpha_prod_t_prev).sqrt()

            # compute relevant coefficients
            coeff_1 = (
                (alpha_prod_t_prev - alpha_prod_t).sqrt()
                * alpha_prod_t_prev_rsqrt
                * beta_prod_t_prev_sqrt
                / beta_prod_t_sqrt
                * eta
            )
            coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1**2).sqrt()

            # model forward; one timestep entry per batch element
            with torch.no_grad():
                timesteps = torch.tensor([train_step] * image.shape[0], device=torch_device)
                noise_residual = self.unet(image, timesteps, context=text_embedding)

            # predict mean of prev image
            pred_mean = alpha_prod_t_rsqrt * (image - beta_prod_t_sqrt * noise_residual)
            pred_mean = torch.clamp(pred_mean, -1, 1)
            pred_mean = (1 / alpha_prod_t_prev_rsqrt) * pred_mean + coeff_2 * noise_residual

            # if eta > 0.0 add noise. Note eta = 1.0 essentially corresponds to DDPM
            if eta > 0.0:
                noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
                image = pred_mean + coeff_1 * noise
            else:
                image = pred_mean

        # Undo the latent scaling before decoding. An element-wise reciprocal
        # of the latent (1 / image) is not a rescale and is unbounded near 0.
        # NOTE(review): 0.18215 is the scale factor of the published latent
        # diffusion text-to-image models; confirm it matches this checkpoint.
        latent_scale = 0.18215
        image = 1 / latent_scale * image

        # Only decode: calling the autoencoder itself would run its forward
        # pass, which encodes its input first and returns a tuple.
        image = self.vqvae.decode(image)
        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)

        return image
|
||||||
|
|
|
@ -1024,6 +1024,8 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
||||||
self.num_classes is not None
|
self.num_classes is not None
|
||||||
), "must specify y if and only if the model is class-conditional"
|
), "must specify y if and only if the model is class-conditional"
|
||||||
hs = []
|
hs = []
|
||||||
|
if not torch.is_tensor(timesteps):
|
||||||
|
timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)
|
||||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
||||||
emb = self.time_embed(t_emb)
|
emb = self.time_embed(t_emb)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue