finalize pndm

This commit is contained in:
Patrick von Platen 2022-06-14 10:50:05 +00:00
parent 9b7e6f495f
commit da1f920ef1
2 changed files with 12 additions and 11 deletions

View File

@ -28,7 +28,8 @@ class PNDM(DiffusionPipeline):
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50):
# eta corresponds to η in paper and should be between [0, 1]
# For more information on the sampling method you can take a look at Algorithm 2 of
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
@ -42,21 +43,17 @@ class PNDM(DiffusionPipeline):
image = image.to(torch_device)
warmup_time_steps = self.noise_scheduler.get_warmup_time_steps(num_inference_steps)
prev_image = image
for t in tqdm.tqdm(range(len(warmup_time_steps))):
t_orig = warmup_time_steps[t]
residual = self.unet(image, t_orig)
if t % 4 == 0:
prev_image = image
image = self.noise_scheduler.step_warm_up(residual, prev_image, t, num_inference_steps)
image = self.noise_scheduler.step_prk(residual, image, t, num_inference_steps)
timesteps = self.noise_scheduler.get_time_steps(num_inference_steps)
for t in tqdm.tqdm(range(len(timesteps))):
t_orig = timesteps[t]
residual = self.unet(image, t_orig)
image = self.noise_scheduler.step(residual, image, t, num_inference_steps)
image = self.noise_scheduler.step_plms(residual, image, t, num_inference_steps)
return image

View File

@ -55,11 +55,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.set_format(tensor_format=tensor_format)
# for now we only support F-PNDM, i.e. the runge-kutta method
# For now we only support F-PNDM, i.e. the runge-kutta method
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
# mainly at equations (12) and (13) and the Algorithm 2.
self.pndm_order = 4
# running values
self.cur_residual = 0
self.cur_image = None
self.ets = []
self.warmup_time_steps = {}
self.time_steps = {}
@ -95,7 +98,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return self.time_steps[num_inference_steps]
def step_warm_up(self, residual, image, t, num_inference_steps):
def step_prk(self, residual, image, t, num_inference_steps):
# TODO(Patrick) - need to rethink whether the "warmup" way is the correct API design here
warmup_time_steps = self.get_warmup_time_steps(num_inference_steps)
@ -105,6 +108,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
if t % 4 == 0:
self.cur_residual += 1 / 6 * residual
self.ets.append(residual)
self.cur_image = image
elif (t - 1) % 4 == 0:
self.cur_residual += 1 / 3 * residual
elif (t - 2) % 4 == 0:
@ -113,9 +117,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
residual = self.cur_residual + 1 / 6 * residual
self.cur_residual = 0
return self.transfer(image, t_prev, t_next, residual)
return self.transfer(self.cur_image, t_prev, t_next, residual)
def step(self, residual, image, t, num_inference_steps):
def step_plms(self, residual, image, t, num_inference_steps):
timesteps = self.get_time_steps(num_inference_steps)
t_prev = timesteps[t]