diff --git a/models/vision/latent_diffusion/modeling_latent_diffusion.py b/models/vision/latent_diffusion/modeling_latent_diffusion.py
index 4d160053..49cbef23 100644
--- a/models/vision/latent_diffusion/modeling_latent_diffusion.py
+++ b/models/vision/latent_diffusion/modeling_latent_diffusion.py
@@ -2,9 +2,11 @@
 import math

 import numpy as np
+import tqdm
 import torch
 import torch.nn as nn

+from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.modeling_utils import ModelMixin

@@ -719,3 +721,215 @@ class VQModel(ModelMixin, ConfigMixin):
         quant = self.post_quant_conv(quant)
         dec = self.decoder(quant)
         return dec
+
+
+class DiagonalGaussianDistribution(object):
+    def __init__(self, parameters, deterministic=False):
+        self.parameters = parameters
+        # split channel-wise into mean and log-variance; clamp logvar for numerical stability
+        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+        self.deterministic = deterministic
+        self.std = torch.exp(0.5 * self.logvar)
+        self.var = torch.exp(self.logvar)
+        if self.deterministic:
+            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+    def sample(self):
+        # reparameterization trick: z = mean + std * eps
+        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+        return x
+
+    def kl(self, other=None):
+        if self.deterministic:
+            return torch.Tensor([0.0])
+        if other is None:
+            # KL divergence against a standard normal prior
+            return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
+        # KL divergence between two diagonal Gaussians
+        return 0.5 * torch.sum(
+            torch.pow(self.mean - other.mean, 2) / other.var
+            + self.var / other.var
+            - 1.0
+            - self.logvar
+            + other.logvar,
+            dim=[1, 2, 3],
+        )
+
+    def nll(self, sample, dims=[1, 2, 3]):
+        if self.deterministic:
+            return torch.Tensor([0.0])
+        logtwopi = np.log(2.0 * np.pi)
+        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
+
+    def mode(self):
+        return self.mean
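+
+
+# Usage sketch (illustrative only; shapes assume embed_dim = 4): `moments` plays
+# the role of the `quant_conv` output consumed by `AutoencoderKL.encode` below.
+#
+#     moments = torch.randn(1, 2 * 4, 32, 32)  # (batch, 2 * embed_dim, h, w)
+#     dist = DiagonalGaussianDistribution(moments)
+#     z = dist.sample()                         # (1, 4, 32, 32) reparameterized draw
+#     kl = dist.kl()                            # per-sample KL against N(0, I)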
+
+
+class AutoencoderKL(ModelMixin, ConfigMixin):
+    def __init__(
+        self,
+        ch,
+        out_ch,
+        num_res_blocks,
+        attn_resolutions,
+        in_channels,
+        resolution,
+        z_channels,
+        n_embed,
+        embed_dim,
+        remap=None,
+        sane_index_shape=False,  # tell vector quantizer to return indices as bhw
+        ch_mult=(1, 2, 4, 8),
+        dropout=0.0,
+        double_z=True,
+        resamp_with_conv=True,
+        give_pre_end=False,
+    ):
+        super().__init__()
+
+        # register all __init__ params with self.register
+        self.register(
+            ch=ch,
+            out_ch=out_ch,
+            num_res_blocks=num_res_blocks,
+            attn_resolutions=attn_resolutions,
+            in_channels=in_channels,
+            resolution=resolution,
+            z_channels=z_channels,
+            n_embed=n_embed,
+            embed_dim=embed_dim,
+            remap=remap,
+            sane_index_shape=sane_index_shape,
+            ch_mult=ch_mult,
+            dropout=dropout,
+            double_z=double_z,
+            resamp_with_conv=resamp_with_conv,
+            give_pre_end=give_pre_end,
+        )
+
+        # pass init params to Encoder
+        self.encoder = Encoder(
+            ch=ch,
+            out_ch=out_ch,
+            num_res_blocks=num_res_blocks,
+            attn_resolutions=attn_resolutions,
+            in_channels=in_channels,
+            resolution=resolution,
+            z_channels=z_channels,
+            ch_mult=ch_mult,
+            dropout=dropout,
+            resamp_with_conv=resamp_with_conv,
+            double_z=double_z,
+            give_pre_end=give_pre_end,
+        )
+
+        # pass init params to Decoder
+        self.decoder = Decoder(
+            ch=ch,
+            out_ch=out_ch,
+            num_res_blocks=num_res_blocks,
+            attn_resolutions=attn_resolutions,
+            in_channels=in_channels,
+            resolution=resolution,
+            z_channels=z_channels,
+            ch_mult=ch_mult,
+            dropout=dropout,
+            resamp_with_conv=resamp_with_conv,
+            give_pre_end=give_pre_end,
+        )
+
+        # 1x1 convs: encoder output -> (mean, logvar) moments, and latents -> decoder input
+        self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
+        self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
+
+    def encode(self, x):
+        h = self.encoder(x)
+        moments = self.quant_conv(h)
+        posterior = DiagonalGaussianDistribution(moments)
+        return posterior
+
+    def decode(self, z):
+        z = self.post_quant_conv(z)
+        dec = self.decoder(z)
+        return dec
+
+    def forward(self, input, sample_posterior=True):
+        posterior = self.encode(input)
+        if sample_posterior:
+            z = posterior.sample()
+        else:
+            z = posterior.mode()
+        dec = self.decode(z)
+        return dec, posterior
+
+
+class LatentDiffusion(DiffusionPipeline):
+    def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler):
+        super().__init__()
+        self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler)
+
+    def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50):
+        # eta corresponds to η in the DDIM paper and should be in [0, 1]
+        if torch_device is None:
+            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        self.unet.to(torch_device)
+        self.vqvae.to(torch_device)
+        self.bert.to(torch_device)
+
+        # get text embedding
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
+        text_embedding = self.bert(**text_input)[0]
+
+        num_trained_timesteps = self.noise_scheduler.num_timesteps
+        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
+
+        # start from pure noise in latent space (spatial size is resolution / 8)
+        image = self.noise_scheduler.sample_noise(
+            (batch_size, self.unet.in_channels, self.unet.resolution // 8, self.unet.resolution // 8),
+            device=torch_device,
+            generator=generator,
+        )
+
+        for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
+            # map the inference step to the actual training timesteps t and t-1
+            train_step = inference_step_times[t]
+            prev_train_step = inference_step_times[t - 1] if t > 0 else -1
+
+            # compute alphas
+            alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
+            alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step)
+            alpha_prod_t_rsqrt = 1 / alpha_prod_t.sqrt()
+            alpha_prod_t_prev_rsqrt = 1 / alpha_prod_t_prev.sqrt()
+            beta_prod_t_sqrt = (1 - alpha_prod_t).sqrt()
+            beta_prod_t_prev_sqrt = (1 - alpha_prod_t_prev).sqrt()
+
+            # compute relevant coefficients: coeff_1 is the DDIM noise scale sigma_t
+            # (zero when eta = 0), coeff_2 weights the predicted noise direction
+            coeff_1 = (
+                (alpha_prod_t_prev - alpha_prod_t).sqrt()
+                * alpha_prod_t_prev_rsqrt
+                * beta_prod_t_prev_sqrt
+                / beta_prod_t_sqrt
+                * eta
+            )
+            coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1**2).sqrt()
+
+            # model forward
+            with torch.no_grad():
+                train_step = torch.tensor([train_step] * image.shape[0], device=torch_device)
+                noise_residual = self.unet(image, train_step, context=text_embedding)
+
+            # predict x_0 from the noise residual, clip it, then form the mean of x_{t-1}
+            pred_mean = alpha_prod_t_rsqrt * (image - beta_prod_t_sqrt * noise_residual)
+            pred_mean = torch.clamp(pred_mean, -1, 1)
+            pred_mean = (1 / alpha_prod_t_prev_rsqrt) * pred_mean + coeff_2 * noise_residual
+
+            # if eta > 0.0 add noise; note that eta = 1.0 essentially corresponds to DDPM
+            if eta > 0.0:
+                noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
+                image = pred_mean + coeff_1 * noise
+            else:
+                image = pred_mean
+
+        # rescale latents (0.18215 is the LDM scale_factor) and decode with the VAE decoder
+        image = 1 / 0.18215 * image
+        image = self.vqvae.decode(image)
+        # map from [-1, 1] to [0, 1]
+        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
+
+        return image
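+
+
+# Usage sketch (illustrative only; the model id and the
+# DiffusionPipeline.from_pretrained entry point are assumptions, not part of
+# this change):
+#
+#     pipeline = LatentDiffusion.from_pretrained("CompVis/ldm-text2im-large-256")
+#     image = pipeline("a painting of a squirrel eating a burger", eta=0.3, num_inference_steps=50)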
diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py
index dcef3909..e85075a5 100644
--- a/src/diffusers/models/unet_ldm.py
+++ b/src/diffusers/models/unet_ldm.py
@@ -1024,6 +1024,8 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
             self.num_classes is not None
         ), "must specify y if and only if the model is class-conditional"
         hs = []
+        # accept plain Python ints for `timesteps` by promoting them to a 1-element tensor
+        if not torch.is_tensor(timesteps):
+            timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)
         t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
         emb = self.time_embed(t_emb)