improve pndm
This commit is contained in:
parent
11631e8154
commit
809591b7b6
|
@ -40,7 +40,7 @@ class PNDM(DiffusionPipeline):
|
||||||
# Sample gaussian noise to begin loop
|
# Sample gaussian noise to begin loop
|
||||||
image = torch.randn(
|
image = torch.randn(
|
||||||
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
|
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
|
||||||
# generator=torch.manual_seed(0)
|
generator=generator,
|
||||||
)
|
)
|
||||||
image = image.to(torch_device)
|
image = image.to(torch_device)
|
||||||
|
|
||||||
|
@ -53,9 +53,30 @@ class PNDM(DiffusionPipeline):
|
||||||
t = (torch.ones(image.shape[0]) * i)
|
t = (torch.ones(image.shape[0]) * i)
|
||||||
t_next = (torch.ones(image.shape[0]) * j)
|
t_next = (torch.ones(image.shape[0]) * j)
|
||||||
|
|
||||||
with torch.no_grad():
|
residual = model(image.to("cuda"), t.to("cuda"))
|
||||||
t_start, t_end = t_next, t
|
residual = residual.to("cpu")
|
||||||
img_next, ets = self.noise_scheduler.step(image, t_start, t_end, model, ets)
|
|
||||||
|
t_list = [t, (t+t_next)/2, t_next]
|
||||||
|
|
||||||
|
if len(ets) <= 2:
|
||||||
|
ets.append(residual)
|
||||||
|
image = image.to("cpu")
|
||||||
|
x_2 = self.noise_scheduler.transfer(image, t_list[0], t_list[1], residual)
|
||||||
|
e_2 = model(x_2.to("cuda"), t_list[1].to("cuda")).to("cpu")
|
||||||
|
x_3 = self.noise_scheduler.transfer(image, t_list[0], t_list[1], e_2)
|
||||||
|
e_3 = model(x_3.to("cuda"), t_list[1].to("cuda")).to("cpu")
|
||||||
|
x_4 = self.noise_scheduler.transfer(image, t_list[0], t_list[2], e_3)
|
||||||
|
e_4 = model(x_4.to("cuda"), t_list[2].to("cuda")).to("cpu")
|
||||||
|
residual = (1 / 6) * (residual + 2 * e_2 + 2 * e_3 + e_4)
|
||||||
|
else:
|
||||||
|
ets.append(residual)
|
||||||
|
residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
|
||||||
|
|
||||||
|
img_next = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual)
|
||||||
|
|
||||||
|
# with torch.no_grad():
|
||||||
|
# t_start, t_end = t_next, t
|
||||||
|
# img_next, ets = self.noise_scheduler.step(image, t_start, t_end, model, ets)
|
||||||
|
|
||||||
image = img_next
|
image = img_next
|
||||||
|
|
||||||
|
|
|
@ -88,35 +88,34 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
#def gen_order_4(img, t, t_next, model, alphas_cump, ets):
|
#def gen_order_4(img, t, t_next, model, alphas_cump, ets):
|
||||||
t_next, t = t_start, t_end
|
t_next, t = t_start, t_end
|
||||||
|
|
||||||
t_list = [t, (t+t_next)/2, t_next]
|
|
||||||
alphas_cump = self.alphas_cumprod
|
|
||||||
if len(ets) > 2:
|
|
||||||
noise_ = model(img.to("cuda"), t.to("cuda"))
|
noise_ = model(img.to("cuda"), t.to("cuda"))
|
||||||
noise_ = noise_.to("cpu")
|
noise_ = noise_.to("cpu")
|
||||||
|
|
||||||
|
t_list = [t, (t+t_next)/2, t_next]
|
||||||
|
if len(ets) > 2:
|
||||||
ets.append(noise_)
|
ets.append(noise_)
|
||||||
noise = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
|
noise = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
|
||||||
else:
|
else:
|
||||||
noise = self.runge_kutta(img, t_list, model, alphas_cump, ets)
|
noise = self.runge_kutta(img, t_list, model, ets, noise_)
|
||||||
|
|
||||||
img_next = self.transfer(img.to("cpu"), t, t_next, noise, alphas_cump)
|
img_next = self.transfer(img.to("cpu"), t, t_next, noise)
|
||||||
return img_next, ets
|
return img_next, ets
|
||||||
|
|
||||||
def runge_kutta(self, x, t_list, model, alphas_cump, ets):
|
def runge_kutta(self, x, t_list, model, ets, noise_):
|
||||||
model = model.to("cuda")
|
model = model.to("cuda")
|
||||||
x = x.to("cpu")
|
x = x.to("cpu")
|
||||||
|
|
||||||
e_1 = model(x.to("cuda"), t_list[0].to("cuda"))
|
e_1 = noise_
|
||||||
e_1 = e_1.to("cpu")
|
|
||||||
ets.append(e_1)
|
ets.append(e_1)
|
||||||
x_2 = self.transfer(x, t_list[0], t_list[1], e_1, alphas_cump)
|
x_2 = self.transfer(x, t_list[0], t_list[1], e_1)
|
||||||
|
|
||||||
e_2 = model(x_2.to("cuda"), t_list[1].to("cuda"))
|
e_2 = model(x_2.to("cuda"), t_list[1].to("cuda"))
|
||||||
e_2 = e_2.to("cpu")
|
e_2 = e_2.to("cpu")
|
||||||
x_3 = self.transfer(x, t_list[0], t_list[1], e_2, alphas_cump)
|
x_3 = self.transfer(x, t_list[0], t_list[1], e_2)
|
||||||
|
|
||||||
e_3 = model(x_3.to("cuda"), t_list[1].to("cuda"))
|
e_3 = model(x_3.to("cuda"), t_list[1].to("cuda"))
|
||||||
e_3 = e_3.to("cpu")
|
e_3 = e_3.to("cpu")
|
||||||
x_4 = self.transfer(x, t_list[0], t_list[2], e_3, alphas_cump)
|
x_4 = self.transfer(x, t_list[0], t_list[2], e_3)
|
||||||
|
|
||||||
e_4 = model(x_4.to("cuda"), t_list[2].to("cuda"))
|
e_4 = model(x_4.to("cuda"), t_list[2].to("cuda"))
|
||||||
e_4 = e_4.to("cpu")
|
e_4 = e_4.to("cpu")
|
||||||
|
@ -125,7 +124,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
|
|
||||||
return et
|
return et
|
||||||
|
|
||||||
def transfer(self, x, t, t_next, et, alphas_cump):
|
def transfer(self, x, t, t_next, et):
|
||||||
|
alphas_cump = self.alphas_cumprod
|
||||||
at = alphas_cump[t.long() + 1].view(-1, 1, 1, 1)
|
at = alphas_cump[t.long() + 1].view(-1, 1, 1, 1)
|
||||||
at_next = alphas_cump[t_next.long() + 1].view(-1, 1, 1, 1)
|
at_next = alphas_cump[t_next.long() + 1].view(-1, 1, 1, 1)
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ import unittest
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from diffusers import DDIM, DDPM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel
|
from diffusers import DDIM, DDPM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel, PNDM, PNDMScheduler
|
||||||
from diffusers.configuration_utils import ConfigMixin
|
from diffusers.configuration_utils import ConfigMixin
|
||||||
from diffusers.pipeline_utils import DiffusionPipeline
|
from diffusers.pipeline_utils import DiffusionPipeline
|
||||||
from diffusers.testing_utils import floats_tensor, slow, torch_device
|
from diffusers.testing_utils import floats_tensor, slow, torch_device
|
||||||
|
@ -178,6 +178,25 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||||
)
|
)
|
||||||
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
|
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_pndm_cifar10(self):
|
||||||
|
generator = torch.manual_seed(0)
|
||||||
|
model_id = "fusing/ddpm-cifar10"
|
||||||
|
|
||||||
|
unet = UNetModel.from_pretrained(model_id)
|
||||||
|
noise_scheduler = PNDMScheduler(tensor_format="pt")
|
||||||
|
|
||||||
|
pndm = PNDM(unet=unet, noise_scheduler=noise_scheduler)
|
||||||
|
image = pndm(generator=generator)
|
||||||
|
|
||||||
|
image_slice = image[0, -1, -3:, -3:].cpu()
|
||||||
|
|
||||||
|
assert image.shape == (1, 3, 32, 32)
|
||||||
|
expected_slice = torch.tensor(
|
||||||
|
[-0.7888, -0.7870, -0.7759, -0.7823, -0.8014, -0.7608, -0.6818, -0.7130, -0.7471]
|
||||||
|
)
|
||||||
|
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_ldm_text2img(self):
|
def test_ldm_text2img(self):
|
||||||
model_id = "fusing/latent-diffusion-text2im-large"
|
model_id = "fusing/latent-diffusion-text2im-large"
|
||||||
|
|
Loading…
Reference in New Issue