save clean-up
This commit is contained in:
parent
e8977e957c
commit
ef044a7231
|
@ -40,39 +40,34 @@ class DDIM(DiffusionPipeline):
|
||||||
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
|
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
|
||||||
|
|
||||||
self.unet.to(torch_device)
|
self.unet.to(torch_device)
|
||||||
x = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator)
|
image = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator)
|
||||||
|
|
||||||
b = self.noise_scheduler.betas.to(torch_device)
|
for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
|
||||||
|
# get actual t and t-1
|
||||||
|
train_step = inference_step_times[t]
|
||||||
|
prev_train_step = inference_step_times[t - 1] if t > 0 else - 1
|
||||||
|
|
||||||
seq = inference_step_times
|
# compute alphas
|
||||||
seq_next = [-1] + list(seq[:-1])
|
alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
|
||||||
# for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
|
alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step)
|
||||||
# train_step = inference_step_times[t]
|
alpha_prod_t_rsqrt = 1 / alpha_prod_t.sqrt()
|
||||||
for i, j in zip(reversed(seq), reversed(seq_next)):
|
alpha_prod_t_prev_rsqrt = 1 / alpha_prod_t_prev.sqrt()
|
||||||
|
beta_prod_t_sqrt = (1 - alpha_prod_t).sqrt()
|
||||||
|
beta_prod_t_prev_sqrt = (1 - alpha_prod_t_prev).sqrt()
|
||||||
|
|
||||||
n = batch_size
|
# compute relevant coefficients
|
||||||
x0_preds = []
|
coeff_1 = (alpha_prod_t_prev - alpha_prod_t).sqrt() * alpha_prod_t_prev_rsqrt * beta_prod_t_prev_sqrt / beta_prod_t_sqrt * eta
|
||||||
xs = [x]
|
coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1 ** 2).sqrt()
|
||||||
|
|
||||||
# i = train_step
|
with torch.no_grad():
|
||||||
# j = inference_step_times[t-1] if t > 0 else -1
|
noise_residual = self.unet(image, train_step)
|
||||||
if True:
|
|
||||||
print(i)
|
|
||||||
t = (torch.ones(n) * i).to(x.device)
|
|
||||||
next_t = (torch.ones(n) * j).to(x.device)
|
|
||||||
at = compute_alpha(b, t.long())
|
|
||||||
at_next = compute_alpha(b, next_t.long())
|
|
||||||
xt = xs[-1].to('cuda')
|
|
||||||
with torch.no_grad():
|
|
||||||
et = self.unet(xt, t)
|
|
||||||
x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
|
|
||||||
x0_preds.append(x0_t.to('cpu'))
|
|
||||||
# eta
|
|
||||||
c1 = (
|
|
||||||
eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
|
|
||||||
)
|
|
||||||
c2 = ((1 - at_next) - c1 ** 2).sqrt()
|
|
||||||
xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
|
|
||||||
xs.append(xt_next.to('cpu'))
|
|
||||||
|
|
||||||
return xt_next
|
print(train_step)
|
||||||
|
|
||||||
|
pred_mean = (image - noise_residual * beta_prod_t_sqrt) * alpha_prod_t_rsqrt
|
||||||
|
xt_next = alpha_prod_t_prev.sqrt() * pred_mean + coeff_1 * torch.randn_like(image) + coeff_2 * noise_residual
|
||||||
|
# xt_next = 1 / alpha_prod_t_rsqrt * pred_mean + coeff_1 * torch.randn_like(image) + coeff_2 * noise_residual
|
||||||
|
# eta
|
||||||
|
image = xt_next
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
Loading…
Reference in New Issue