This commit is contained in:
parent
809591b7b6
commit
059a6e9d82
|
@ -44,12 +44,17 @@ class PNDM(DiffusionPipeline):
|
||||||
)
|
)
|
||||||
image = image.to(torch_device)
|
image = image.to(torch_device)
|
||||||
|
|
||||||
seq = inference_step_times
|
seq = list(inference_step_times)
|
||||||
seq_next = [-1] + list(seq[:-1])
|
seq_next = [-1] + list(seq[:-1])
|
||||||
model = self.unet
|
model = self.unet
|
||||||
|
|
||||||
ets = []
|
ets = []
|
||||||
for i, j in zip(reversed(seq), reversed(seq_next)):
|
prev_noises = []
|
||||||
|
step_idx = len(seq) - 1
|
||||||
|
while step_idx >= 0:
|
||||||
|
i = seq[step_idx]
|
||||||
|
j = seq_next[step_idx]
|
||||||
|
|
||||||
t = (torch.ones(image.shape[0]) * i)
|
t = (torch.ones(image.shape[0]) * i)
|
||||||
t_next = (torch.ones(image.shape[0]) * j)
|
t_next = (torch.ones(image.shape[0]) * j)
|
||||||
|
|
||||||
|
@ -58,10 +63,11 @@ class PNDM(DiffusionPipeline):
|
||||||
|
|
||||||
t_list = [t, (t+t_next)/2, t_next]
|
t_list = [t, (t+t_next)/2, t_next]
|
||||||
|
|
||||||
if len(ets) <= 2:
|
ets.append(residual)
|
||||||
ets.append(residual)
|
if len(ets) <= 3:
|
||||||
image = image.to("cpu")
|
image = image.to("cpu")
|
||||||
x_2 = self.noise_scheduler.transfer(image, t_list[0], t_list[1], residual)
|
x_2 = self.noise_scheduler.transfer(image.to("cpu"), t_list[0], t_list[1], residual)
|
||||||
|
|
||||||
e_2 = model(x_2.to("cuda"), t_list[1].to("cuda")).to("cpu")
|
e_2 = model(x_2.to("cuda"), t_list[1].to("cuda")).to("cpu")
|
||||||
x_3 = self.noise_scheduler.transfer(image, t_list[0], t_list[1], e_2)
|
x_3 = self.noise_scheduler.transfer(image, t_list[0], t_list[1], e_2)
|
||||||
e_3 = model(x_3.to("cuda"), t_list[1].to("cuda")).to("cpu")
|
e_3 = model(x_3.to("cuda"), t_list[1].to("cuda")).to("cpu")
|
||||||
|
@ -69,17 +75,35 @@ class PNDM(DiffusionPipeline):
|
||||||
e_4 = model(x_4.to("cuda"), t_list[2].to("cuda")).to("cpu")
|
e_4 = model(x_4.to("cuda"), t_list[2].to("cuda")).to("cpu")
|
||||||
residual = (1 / 6) * (residual + 2 * e_2 + 2 * e_3 + e_4)
|
residual = (1 / 6) * (residual + 2 * e_2 + 2 * e_3 + e_4)
|
||||||
else:
|
else:
|
||||||
ets.append(residual)
|
|
||||||
residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
|
residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
|
||||||
|
|
||||||
img_next = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual)
|
img_next = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual)
|
||||||
|
|
||||||
# with torch.no_grad():
|
|
||||||
# t_start, t_end = t_next, t
|
|
||||||
# img_next, ets = self.noise_scheduler.step(image, t_start, t_end, model, ets)
|
|
||||||
|
|
||||||
image = img_next
|
image = img_next
|
||||||
|
|
||||||
|
step_idx = step_idx - 1
|
||||||
|
|
||||||
|
# if len(prev_noises) in [1, 2]:
|
||||||
|
# t = (t + t_next) / 2
|
||||||
|
# elif len(prev_noises) == 3:
|
||||||
|
# t = t_next / 2
|
||||||
|
|
||||||
|
# if len(prev_noises) == 0:
|
||||||
|
# ets.append(residual)
|
||||||
|
#
|
||||||
|
# if len(ets) > 3:
|
||||||
|
# residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
|
||||||
|
# step_idx = step_idx - 1
|
||||||
|
# elif len(ets) <= 3 and len(prev_noises) == 3:
|
||||||
|
# residual = (1 / 6) * (prev_noises[-3] + 2 * prev_noises[-2] + 2 * prev_noises[-1] + residual)
|
||||||
|
# prev_noises = []
|
||||||
|
# step_idx = step_idx - 1
|
||||||
|
# elif len(ets) <= 3 and len(prev_noises) < 3:
|
||||||
|
# prev_noises.append(residual)
|
||||||
|
# if len(prev_noises) < 2:
|
||||||
|
# t_next = (t + t_next) / 2
|
||||||
|
#
|
||||||
|
# image = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual)
|
||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
||||||
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
|
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
|
||||||
|
|
Loading…
Reference in New Issue