From da1f920ef124d00c5e81ba423e9d45e8783e9841 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Tue, 14 Jun 2022 10:50:05 +0000
Subject: [PATCH] finalize pndm

---
 src/diffusers/pipelines/pipeline_pndm.py    | 11 ++++-------
 src/diffusers/schedulers/scheduling_pndm.py | 12 ++++++++----
 2 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/src/diffusers/pipelines/pipeline_pndm.py b/src/diffusers/pipelines/pipeline_pndm.py
index 1116b604..93d735a8 100644
--- a/src/diffusers/pipelines/pipeline_pndm.py
+++ b/src/diffusers/pipelines/pipeline_pndm.py
@@ -28,7 +28,8 @@ class PNDM(DiffusionPipeline):
         self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
 
     def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50):
-        # eta corresponds to η in paper and should be between [0, 1]
+        # For more information on the sampling method you can take a look at Algorithm 2 of
+        # the official paper: https://arxiv.org/pdf/2202.09778.pdf
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -42,21 +43,17 @@ class PNDM(DiffusionPipeline):
         image = image.to(torch_device)
 
         warmup_time_steps = self.noise_scheduler.get_warmup_time_steps(num_inference_steps)
-        prev_image = image
         for t in tqdm.tqdm(range(len(warmup_time_steps))):
            t_orig = warmup_time_steps[t]
            residual = self.unet(image, t_orig)
 
-            if t % 4 == 0:
-                prev_image = image
-
-            image = self.noise_scheduler.step_warm_up(residual, prev_image, t, num_inference_steps)
+            image = self.noise_scheduler.step_prk(residual, image, t, num_inference_steps)
 
         timesteps = self.noise_scheduler.get_time_steps(num_inference_steps)
         for t in tqdm.tqdm(range(len(timesteps))):
             t_orig = timesteps[t]
             residual = self.unet(image, t_orig)
-            image = self.noise_scheduler.step(residual, image, t, num_inference_steps)
+            image = self.noise_scheduler.step_plms(residual, image, t, num_inference_steps)
 
         return image
 
diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py
index cc27b520..85fa6fb2 100644
--- a/src/diffusers/schedulers/scheduling_pndm.py
+++ b/src/diffusers/schedulers/scheduling_pndm.py
@@ -55,11 +55,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
 
         self.set_format(tensor_format=tensor_format)
 
-        # for now we only support F-PNDM, i.e. the runge-kutta method
+        # For now we only support F-PNDM, i.e. the runge-kutta method
+        # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
+        # mainly at equations (12) and (13) and the Algorithm 2.
         self.pndm_order = 4
 
         # running values
         self.cur_residual = 0
+        self.cur_image = None
         self.ets = []
         self.warmup_time_steps = {}
         self.time_steps = {}
@@ -95,7 +98,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
 
         return self.time_steps[num_inference_steps]
 
-    def step_warm_up(self, residual, image, t, num_inference_steps):
+    def step_prk(self, residual, image, t, num_inference_steps):
         # TODO(Patrick) - need to rethink whether the "warmup" way is the correct API design here
         warmup_time_steps = self.get_warmup_time_steps(num_inference_steps)
 
@@ -105,6 +108,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         if t % 4 == 0:
             self.cur_residual += 1 / 6 * residual
             self.ets.append(residual)
+            self.cur_image = image
         elif (t - 1) % 4 == 0:
             self.cur_residual += 1 / 3 * residual
         elif (t - 2) % 4 == 0:
@@ -113,9 +117,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             residual = self.cur_residual + 1 / 6 * residual
             self.cur_residual = 0
 
-        return self.transfer(image, t_prev, t_next, residual)
+        return self.transfer(self.cur_image, t_prev, t_next, residual)
 
-    def step(self, residual, image, t, num_inference_steps):
+    def step_plms(self, residual, image, t, num_inference_steps):
         timesteps = self.get_time_steps(num_inference_steps)
 
         t_prev = timesteps[t]
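
For orientation, here is a minimal usage sketch of the pipeline after this change. It assumes a trained `unet` module is already in scope and that PNDM and PNDMScheduler are importable from the top-level diffusers package; the import path, default scheduler arguments, and seed are assumptions for illustration, not anything this commit defines.

import torch

from diffusers import PNDM, PNDMScheduler  # top-level import path assumed

noise_scheduler = PNDMScheduler()  # default constructor arguments assumed
pipeline = PNDM(unet=unet, noise_scheduler=noise_scheduler)  # `unet`: pretrained model, assumed in scope

generator = torch.manual_seed(0)
image = pipeline(batch_size=1, generator=generator, num_inference_steps=50)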
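
The step_prk change keys each pseudo Runge-Kutta group off the image stored at its first sub-step (self.cur_image) rather than a prev_image tracked by the pipeline loop. The sketch below reproduces that 1/6, 1/3, 1/3, 1/6 weighting outside the scheduler; prk_group, model, and transfer are hypothetical names standing in for the UNet call and the scheduler's transfer rule, not the library API.

def prk_group(model, transfer, image, warmup_pairs):
    # warmup_pairs: the four (t_prev, t_next) pairs making up one pseudo Runge-Kutta group
    cur_residual = 0.0
    cur_image = image  # image at the start of the group, mirrors self.cur_image above
    for i, (t_prev, t_next) in enumerate(warmup_pairs):
        residual = model(image, t_prev)
        if i == 0:
            cur_residual += residual / 6
        elif i in (1, 2):
            cur_residual += residual / 3
        else:
            residual = cur_residual + residual / 6  # final sub-step combines all four evaluations
            cur_residual = 0.0
        # every intermediate update starts from the image the group began with
        image = transfer(cur_image, t_prev, t_next, residual)
    return image

# toy check with scalar "images" and stand-in callables
print(prk_group(
    model=lambda x, t: 0.1 * x,
    transfer=lambda x0, t_prev, t_next, res: x0 - res,
    image=1.0,
    warmup_pairs=[(40, 35), (35, 35), (35, 30), (30, 30)],
))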