From 7bb3dcd18e70873cf1a0a27c407388cdd6b796a6 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 10 Jun 2022 14:58:33 +0200 Subject: [PATCH] update ldm --- .../modeling_latent_diffusion.py | 41 ++++--------------- 1 file changed, 8 insertions(+), 33 deletions(-) diff --git a/models/vision/latent_diffusion/modeling_latent_diffusion.py b/models/vision/latent_diffusion/modeling_latent_diffusion.py index 14928076..bd4a8d8b 100644 --- a/models/vision/latent_diffusion/modeling_latent_diffusion.py +++ b/models/vision/latent_diffusion/modeling_latent_diffusion.py @@ -924,42 +924,17 @@ class LatentDiffusion(DiffusionPipeline): pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2) pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond) - # 2. get actual t and t-1 - train_step = inference_step_times[t] - prev_train_step = inference_step_times[t - 1] if t > 0 else -1 + # 2. predict previous mean of image x_t-1 + pred_prev_image = self.noise_scheduler.compute_prev_image_step(pred_noise_t, image, t, num_inference_steps, eta) - # 3. compute alphas, betas - alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step) - alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step) - beta_prod_t = 1 - alpha_prod_t - beta_prod_t_prev = 1 - alpha_prod_t_prev - - # 4. Compute predicted previous image from predicted noise - # First: compute predicted original image from predicted noise also called - # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt() - - # Second: Compute variance: "sigma_t(η)" -> see formula (16) - # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) - std_dev_t = (beta_prod_t_prev / beta_prod_t).sqrt() * (1 - alpha_prod_t / alpha_prod_t_prev).sqrt() - std_dev_t = eta * std_dev_t - - # Third: Compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * pred_noise_t - - # Forth: Compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction - - # 5. Sample x_t-1 image optionally if η > 0.0 by adding noise to pred_prev_image - # Note: eta = 1.0 essentially corresponds to DDPM - if eta > 0.0: + # 3. optionally sample variance + variance = 0 + if eta > 0: noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator) - prev_image = pred_prev_image + std_dev_t * noise - else: - prev_image = pred_prev_image + variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise - # 6. Set current image to prev_image: x_t -> x_t-1 - image = prev_image + # 4. set current image to prev_image: x_t -> x_t-1 + image = pred_prev_image + variance # scale and decode image with vae image = 1 / 0.18215 * image