improve pndm

This commit is contained in:
Patrick von Platen 2022-06-13 16:29:22 +00:00
parent 11631e8154
commit 809591b7b6
3 changed files with 57 additions and 17 deletions

View File

@ -40,7 +40,7 @@ class PNDM(DiffusionPipeline):
# Sample gaussian noise to begin loop
image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
# generator=torch.manual_seed(0)
generator=generator,
)
image = image.to(torch_device)
@ -53,9 +53,30 @@ class PNDM(DiffusionPipeline):
t = (torch.ones(image.shape[0]) * i)
t_next = (torch.ones(image.shape[0]) * j)
with torch.no_grad():
t_start, t_end = t_next, t
img_next, ets = self.noise_scheduler.step(image, t_start, t_end, model, ets)
residual = model(image.to("cuda"), t.to("cuda"))
residual = residual.to("cpu")
t_list = [t, (t+t_next)/2, t_next]
if len(ets) <= 2:
ets.append(residual)
image = image.to("cpu")
x_2 = self.noise_scheduler.transfer(image, t_list[0], t_list[1], residual)
e_2 = model(x_2.to("cuda"), t_list[1].to("cuda")).to("cpu")
x_3 = self.noise_scheduler.transfer(image, t_list[0], t_list[1], e_2)
e_3 = model(x_3.to("cuda"), t_list[1].to("cuda")).to("cpu")
x_4 = self.noise_scheduler.transfer(image, t_list[0], t_list[2], e_3)
e_4 = model(x_4.to("cuda"), t_list[2].to("cuda")).to("cpu")
residual = (1 / 6) * (residual + 2 * e_2 + 2 * e_3 + e_4)
else:
ets.append(residual)
residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
img_next = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual)
# with torch.no_grad():
# t_start, t_end = t_next, t
# img_next, ets = self.noise_scheduler.step(image, t_start, t_end, model, ets)
image = img_next

View File

@ -88,35 +88,34 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
#def gen_order_4(img, t, t_next, model, alphas_cump, ets):
t_next, t = t_start, t_end
noise_ = model(img.to("cuda"), t.to("cuda"))
noise_ = noise_.to("cpu")
t_list = [t, (t+t_next)/2, t_next]
alphas_cump = self.alphas_cumprod
if len(ets) > 2:
noise_ = model(img.to("cuda"), t.to("cuda"))
noise_ = noise_.to("cpu")
ets.append(noise_)
noise = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
else:
noise = self.runge_kutta(img, t_list, model, alphas_cump, ets)
noise = self.runge_kutta(img, t_list, model, ets, noise_)
img_next = self.transfer(img.to("cpu"), t, t_next, noise, alphas_cump)
img_next = self.transfer(img.to("cpu"), t, t_next, noise)
return img_next, ets
def runge_kutta(self, x, t_list, model, alphas_cump, ets):
def runge_kutta(self, x, t_list, model, ets, noise_):
model = model.to("cuda")
x = x.to("cpu")
e_1 = model(x.to("cuda"), t_list[0].to("cuda"))
e_1 = e_1.to("cpu")
e_1 = noise_
ets.append(e_1)
x_2 = self.transfer(x, t_list[0], t_list[1], e_1, alphas_cump)
x_2 = self.transfer(x, t_list[0], t_list[1], e_1)
e_2 = model(x_2.to("cuda"), t_list[1].to("cuda"))
e_2 = e_2.to("cpu")
x_3 = self.transfer(x, t_list[0], t_list[1], e_2, alphas_cump)
x_3 = self.transfer(x, t_list[0], t_list[1], e_2)
e_3 = model(x_3.to("cuda"), t_list[1].to("cuda"))
e_3 = e_3.to("cpu")
x_4 = self.transfer(x, t_list[0], t_list[2], e_3, alphas_cump)
x_4 = self.transfer(x, t_list[0], t_list[2], e_3)
e_4 = model(x_4.to("cuda"), t_list[2].to("cuda"))
e_4 = e_4.to("cpu")
@ -125,7 +124,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return et
def transfer(self, x, t, t_next, et, alphas_cump):
def transfer(self, x, t, t_next, et):
alphas_cump = self.alphas_cumprod
at = alphas_cump[t.long() + 1].view(-1, 1, 1, 1)
at_next = alphas_cump[t_next.long() + 1].view(-1, 1, 1, 1)

View File

@ -19,7 +19,7 @@ import unittest
import torch
from diffusers import DDIM, DDPM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel
from diffusers import DDIM, DDPM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel, PNDM, PNDMScheduler
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.testing_utils import floats_tensor, slow, torch_device
@ -178,6 +178,25 @@ class PipelineTesterMixin(unittest.TestCase):
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
@slow
def test_pndm_cifar10(self):
generator = torch.manual_seed(0)
model_id = "fusing/ddpm-cifar10"
unet = UNetModel.from_pretrained(model_id)
noise_scheduler = PNDMScheduler(tensor_format="pt")
pndm = PNDM(unet=unet, noise_scheduler=noise_scheduler)
image = pndm(generator=generator)
image_slice = image[0, -1, -3:, -3:].cpu()
assert image.shape == (1, 3, 32, 32)
expected_slice = torch.tensor(
[-0.7888, -0.7870, -0.7759, -0.7823, -0.8014, -0.7608, -0.6818, -0.7130, -0.7471]
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
@slow
def test_ldm_text2img(self):
model_id = "fusing/latent-diffusion-text2im-large"