From ef044a7231a3b74d42057e46438ae7cea456109c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 8 Jun 2022 10:52:07 +0000 Subject: [PATCH] save clean-up --- models/vision/ddim/modeling_ddim.py | 57 +++++++++++++---------------- 1 file changed, 26 insertions(+), 31 deletions(-) diff --git a/models/vision/ddim/modeling_ddim.py b/models/vision/ddim/modeling_ddim.py index 513f6b58..0b618605 100644 --- a/models/vision/ddim/modeling_ddim.py +++ b/models/vision/ddim/modeling_ddim.py @@ -40,39 +40,34 @@ class DDIM(DiffusionPipeline): inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps) self.unet.to(torch_device) - x = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator) + image = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator) - b = self.noise_scheduler.betas.to(torch_device) + for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps): + # get actual t and t-1 + train_step = inference_step_times[t] + prev_train_step = inference_step_times[t - 1] if t > 0 else - 1 - seq = inference_step_times - seq_next = [-1] + list(seq[:-1]) -# for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps): -# train_step = inference_step_times[t] - for i, j in zip(reversed(seq), reversed(seq_next)): + # compute alphas + alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step) + alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step) + alpha_prod_t_rsqrt = 1 / alpha_prod_t.sqrt() + alpha_prod_t_prev_rsqrt = 1 / alpha_prod_t_prev.sqrt() + beta_prod_t_sqrt = (1 - alpha_prod_t).sqrt() + beta_prod_t_prev_sqrt = (1 - alpha_prod_t_prev).sqrt() - n = batch_size - x0_preds = [] - xs = [x] + # compute relevant coefficients + coeff_1 = (alpha_prod_t_prev - alpha_prod_t).sqrt() * alpha_prod_t_prev_rsqrt * beta_prod_t_prev_sqrt / beta_prod_t_sqrt * eta + coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1 ** 2).sqrt() -# i = train_step -# j = inference_step_times[t-1] if t > 0 else -1 - if True: - print(i) - t = (torch.ones(n) * i).to(x.device) - next_t = (torch.ones(n) * j).to(x.device) - at = compute_alpha(b, t.long()) - at_next = compute_alpha(b, next_t.long()) - xt = xs[-1].to('cuda') - with torch.no_grad(): - et = self.unet(xt, t) - x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt() - x0_preds.append(x0_t.to('cpu')) - # eta - c1 = ( - eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt() - ) - c2 = ((1 - at_next) - c1 ** 2).sqrt() - xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et - xs.append(xt_next.to('cpu')) + with torch.no_grad(): + noise_residual = self.unet(image, train_step) - return xt_next + print(train_step) + + pred_mean = (image - noise_residual * beta_prod_t_sqrt) * alpha_prod_t_rsqrt + xt_next = alpha_prod_t_prev.sqrt() * pred_mean + coeff_1 * torch.randn_like(image) + coeff_2 * noise_residual +# xt_next = 1 / alpha_prod_t_rsqrt * pred_mean + coeff_1 * torch.randn_like(image) + coeff_2 * noise_residual + # eta + image = xt_next + + return image