improve pndm
This commit is contained in:
parent
11631e8154
commit
809591b7b6
|
@ -40,7 +40,7 @@ class PNDM(DiffusionPipeline):
|
|||
# Sample gaussian noise to begin loop
|
||||
image = torch.randn(
|
||||
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
|
||||
# generator=torch.manual_seed(0)
|
||||
generator=generator,
|
||||
)
|
||||
image = image.to(torch_device)
|
||||
|
||||
|
@ -53,9 +53,30 @@ class PNDM(DiffusionPipeline):
|
|||
t = (torch.ones(image.shape[0]) * i)
|
||||
t_next = (torch.ones(image.shape[0]) * j)
|
||||
|
||||
with torch.no_grad():
|
||||
t_start, t_end = t_next, t
|
||||
img_next, ets = self.noise_scheduler.step(image, t_start, t_end, model, ets)
|
||||
residual = model(image.to("cuda"), t.to("cuda"))
|
||||
residual = residual.to("cpu")
|
||||
|
||||
t_list = [t, (t+t_next)/2, t_next]
|
||||
|
||||
if len(ets) <= 2:
|
||||
ets.append(residual)
|
||||
image = image.to("cpu")
|
||||
x_2 = self.noise_scheduler.transfer(image, t_list[0], t_list[1], residual)
|
||||
e_2 = model(x_2.to("cuda"), t_list[1].to("cuda")).to("cpu")
|
||||
x_3 = self.noise_scheduler.transfer(image, t_list[0], t_list[1], e_2)
|
||||
e_3 = model(x_3.to("cuda"), t_list[1].to("cuda")).to("cpu")
|
||||
x_4 = self.noise_scheduler.transfer(image, t_list[0], t_list[2], e_3)
|
||||
e_4 = model(x_4.to("cuda"), t_list[2].to("cuda")).to("cpu")
|
||||
residual = (1 / 6) * (residual + 2 * e_2 + 2 * e_3 + e_4)
|
||||
else:
|
||||
ets.append(residual)
|
||||
residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
|
||||
|
||||
img_next = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual)
|
||||
|
||||
# with torch.no_grad():
|
||||
# t_start, t_end = t_next, t
|
||||
# img_next, ets = self.noise_scheduler.step(image, t_start, t_end, model, ets)
|
||||
|
||||
image = img_next
|
||||
|
||||
|
|
|
@ -88,35 +88,34 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||
#def gen_order_4(img, t, t_next, model, alphas_cump, ets):
|
||||
t_next, t = t_start, t_end
|
||||
|
||||
noise_ = model(img.to("cuda"), t.to("cuda"))
|
||||
noise_ = noise_.to("cpu")
|
||||
|
||||
t_list = [t, (t+t_next)/2, t_next]
|
||||
alphas_cump = self.alphas_cumprod
|
||||
if len(ets) > 2:
|
||||
noise_ = model(img.to("cuda"), t.to("cuda"))
|
||||
noise_ = noise_.to("cpu")
|
||||
ets.append(noise_)
|
||||
noise = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
|
||||
else:
|
||||
noise = self.runge_kutta(img, t_list, model, alphas_cump, ets)
|
||||
noise = self.runge_kutta(img, t_list, model, ets, noise_)
|
||||
|
||||
img_next = self.transfer(img.to("cpu"), t, t_next, noise, alphas_cump)
|
||||
img_next = self.transfer(img.to("cpu"), t, t_next, noise)
|
||||
return img_next, ets
|
||||
|
||||
def runge_kutta(self, x, t_list, model, alphas_cump, ets):
|
||||
def runge_kutta(self, x, t_list, model, ets, noise_):
|
||||
model = model.to("cuda")
|
||||
x = x.to("cpu")
|
||||
|
||||
e_1 = model(x.to("cuda"), t_list[0].to("cuda"))
|
||||
e_1 = e_1.to("cpu")
|
||||
e_1 = noise_
|
||||
ets.append(e_1)
|
||||
x_2 = self.transfer(x, t_list[0], t_list[1], e_1, alphas_cump)
|
||||
x_2 = self.transfer(x, t_list[0], t_list[1], e_1)
|
||||
|
||||
e_2 = model(x_2.to("cuda"), t_list[1].to("cuda"))
|
||||
e_2 = e_2.to("cpu")
|
||||
x_3 = self.transfer(x, t_list[0], t_list[1], e_2, alphas_cump)
|
||||
x_3 = self.transfer(x, t_list[0], t_list[1], e_2)
|
||||
|
||||
e_3 = model(x_3.to("cuda"), t_list[1].to("cuda"))
|
||||
e_3 = e_3.to("cpu")
|
||||
x_4 = self.transfer(x, t_list[0], t_list[2], e_3, alphas_cump)
|
||||
x_4 = self.transfer(x, t_list[0], t_list[2], e_3)
|
||||
|
||||
e_4 = model(x_4.to("cuda"), t_list[2].to("cuda"))
|
||||
e_4 = e_4.to("cpu")
|
||||
|
@ -125,7 +124,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||
|
||||
return et
|
||||
|
||||
def transfer(self, x, t, t_next, et, alphas_cump):
|
||||
def transfer(self, x, t, t_next, et):
|
||||
alphas_cump = self.alphas_cumprod
|
||||
at = alphas_cump[t.long() + 1].view(-1, 1, 1, 1)
|
||||
at_next = alphas_cump[t_next.long() + 1].view(-1, 1, 1, 1)
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ import unittest
|
|||
|
||||
import torch
|
||||
|
||||
from diffusers import DDIM, DDPM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel
|
||||
from diffusers import DDIM, DDPM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel, PNDM, PNDMScheduler
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.testing_utils import floats_tensor, slow, torch_device
|
||||
|
@ -178,6 +178,25 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
)
|
||||
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
|
||||
|
||||
@slow
def test_pndm_cifar10(self):
    """Run the PNDM pipeline once with a fixed seed and check a reference patch.

    The pipeline is assembled from the pretrained ``fusing/ddpm-cifar10``
    UNet and a ``PNDMScheduler`` in ``"pt"`` tensor format. The generated
    batch must have shape ``(1, 3, 32, 32)``, and the trailing 3x3 patch at
    the last index of dim 1 must match the stored values within ``1e-2``.
    """
    # Seed first so the sampled starting noise is reproducible.
    generator = torch.manual_seed(0)
    model_id = "fusing/ddpm-cifar10"

    pipeline = PNDM(
        unet=UNetModel.from_pretrained(model_id),
        noise_scheduler=PNDMScheduler(tensor_format="pt"),
    )
    image = pipeline(generator=generator)

    # First batch item, last index of dim 1, bottom-right 3x3 corner.
    image_slice = image[0, -1, -3:, -3:].cpu()

    assert image.shape == (1, 3, 32, 32)
    expected_slice = torch.tensor(
        [-0.7888, -0.7870, -0.7759, -0.7823, -0.8014, -0.7608, -0.6818, -0.7130, -0.7471]
    )
    max_abs_diff = (image_slice.flatten() - expected_slice).abs().max()
    assert max_abs_diff < 1e-2
|
||||
|
||||
@slow
|
||||
def test_ldm_text2img(self):
|
||||
model_id = "fusing/latent-diffusion-text2im-large"
|
||||
|
|
Loading…
Reference in New Issue