finalize pndm
This commit is contained in:
parent
9b7e6f495f
commit
da1f920ef1
|
@ -28,7 +28,8 @@ class PNDM(DiffusionPipeline):
|
|||
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
|
||||
|
||||
def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50):
|
||||
# eta corresponds to η in paper and should be between [0, 1]
|
||||
# For more information on the sampling method you can take a look at Algorithm 2 of
|
||||
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
|
||||
if torch_device is None:
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
@ -42,21 +43,17 @@ class PNDM(DiffusionPipeline):
|
|||
image = image.to(torch_device)
|
||||
|
||||
warmup_time_steps = self.noise_scheduler.get_warmup_time_steps(num_inference_steps)
|
||||
prev_image = image
|
||||
for t in tqdm.tqdm(range(len(warmup_time_steps))):
|
||||
t_orig = warmup_time_steps[t]
|
||||
residual = self.unet(image, t_orig)
|
||||
|
||||
if t % 4 == 0:
|
||||
prev_image = image
|
||||
|
||||
image = self.noise_scheduler.step_warm_up(residual, prev_image, t, num_inference_steps)
|
||||
image = self.noise_scheduler.step_prk(residual, image, t, num_inference_steps)
|
||||
|
||||
timesteps = self.noise_scheduler.get_time_steps(num_inference_steps)
|
||||
for t in tqdm.tqdm(range(len(timesteps))):
|
||||
t_orig = timesteps[t]
|
||||
residual = self.unet(image, t_orig)
|
||||
|
||||
image = self.noise_scheduler.step(residual, image, t, num_inference_steps)
|
||||
image = self.noise_scheduler.step_plms(residual, image, t, num_inference_steps)
|
||||
|
||||
return image
|
||||
|
|
|
@ -55,11 +55,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||
|
||||
self.set_format(tensor_format=tensor_format)
|
||||
|
||||
# for now we only support F-PNDM, i.e. the runge-kutta method
|
||||
# For now we only support F-PNDM, i.e. the runge-kutta method
|
||||
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
|
||||
# mainly at equations (12) and (13) and the Algorithm 2.
|
||||
self.pndm_order = 4
|
||||
|
||||
# running values
|
||||
self.cur_residual = 0
|
||||
self.cur_image = None
|
||||
self.ets = []
|
||||
self.warmup_time_steps = {}
|
||||
self.time_steps = {}
|
||||
|
@ -95,7 +98,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||
|
||||
return self.time_steps[num_inference_steps]
|
||||
|
||||
def step_warm_up(self, residual, image, t, num_inference_steps):
|
||||
def step_prk(self, residual, image, t, num_inference_steps):
|
||||
# TODO(Patrick) - need to rethink whether the "warmup" way is the correct API design here
|
||||
warmup_time_steps = self.get_warmup_time_steps(num_inference_steps)
|
||||
|
||||
|
@ -105,6 +108,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||
if t % 4 == 0:
|
||||
self.cur_residual += 1 / 6 * residual
|
||||
self.ets.append(residual)
|
||||
self.cur_image = image
|
||||
elif (t - 1) % 4 == 0:
|
||||
self.cur_residual += 1 / 3 * residual
|
||||
elif (t - 2) % 4 == 0:
|
||||
|
@ -113,9 +117,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||
residual = self.cur_residual + 1 / 6 * residual
|
||||
self.cur_residual = 0
|
||||
|
||||
return self.transfer(image, t_prev, t_next, residual)
|
||||
return self.transfer(self.cur_image, t_prev, t_next, residual)
|
||||
|
||||
def step(self, residual, image, t, num_inference_steps):
|
||||
def step_plms(self, residual, image, t, num_inference_steps):
|
||||
timesteps = self.get_time_steps(num_inference_steps)
|
||||
|
||||
t_prev = timesteps[t]
|
||||
|
|
Loading…
Reference in New Issue