finalize pndm
This commit is contained in:
parent
9b7e6f495f
commit
da1f920ef1
|
@ -28,7 +28,8 @@ class PNDM(DiffusionPipeline):
|
||||||
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
|
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
|
||||||
|
|
||||||
def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50):
|
def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50):
|
||||||
# eta corresponds to η in paper and should be between [0, 1]
|
# For more information on the sampling method you can take a look at Algorithm 2 of
|
||||||
|
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
|
||||||
if torch_device is None:
|
if torch_device is None:
|
||||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
@ -42,21 +43,17 @@ class PNDM(DiffusionPipeline):
|
||||||
image = image.to(torch_device)
|
image = image.to(torch_device)
|
||||||
|
|
||||||
warmup_time_steps = self.noise_scheduler.get_warmup_time_steps(num_inference_steps)
|
warmup_time_steps = self.noise_scheduler.get_warmup_time_steps(num_inference_steps)
|
||||||
prev_image = image
|
|
||||||
for t in tqdm.tqdm(range(len(warmup_time_steps))):
|
for t in tqdm.tqdm(range(len(warmup_time_steps))):
|
||||||
t_orig = warmup_time_steps[t]
|
t_orig = warmup_time_steps[t]
|
||||||
residual = self.unet(image, t_orig)
|
residual = self.unet(image, t_orig)
|
||||||
|
|
||||||
if t % 4 == 0:
|
image = self.noise_scheduler.step_prk(residual, image, t, num_inference_steps)
|
||||||
prev_image = image
|
|
||||||
|
|
||||||
image = self.noise_scheduler.step_warm_up(residual, prev_image, t, num_inference_steps)
|
|
||||||
|
|
||||||
timesteps = self.noise_scheduler.get_time_steps(num_inference_steps)
|
timesteps = self.noise_scheduler.get_time_steps(num_inference_steps)
|
||||||
for t in tqdm.tqdm(range(len(timesteps))):
|
for t in tqdm.tqdm(range(len(timesteps))):
|
||||||
t_orig = timesteps[t]
|
t_orig = timesteps[t]
|
||||||
residual = self.unet(image, t_orig)
|
residual = self.unet(image, t_orig)
|
||||||
|
|
||||||
image = self.noise_scheduler.step(residual, image, t, num_inference_steps)
|
image = self.noise_scheduler.step_plms(residual, image, t, num_inference_steps)
|
||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
|
@ -55,11 +55,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
|
|
||||||
self.set_format(tensor_format=tensor_format)
|
self.set_format(tensor_format=tensor_format)
|
||||||
|
|
||||||
# for now we only support F-PNDM, i.e. the runge-kutta method
|
# For now we only support F-PNDM, i.e. the runge-kutta method
|
||||||
|
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
|
||||||
|
# mainly at equations (12) and (13) and the Algorithm 2.
|
||||||
self.pndm_order = 4
|
self.pndm_order = 4
|
||||||
|
|
||||||
# running values
|
# running values
|
||||||
self.cur_residual = 0
|
self.cur_residual = 0
|
||||||
|
self.cur_image = None
|
||||||
self.ets = []
|
self.ets = []
|
||||||
self.warmup_time_steps = {}
|
self.warmup_time_steps = {}
|
||||||
self.time_steps = {}
|
self.time_steps = {}
|
||||||
|
@ -95,7 +98,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
|
|
||||||
return self.time_steps[num_inference_steps]
|
return self.time_steps[num_inference_steps]
|
||||||
|
|
||||||
def step_warm_up(self, residual, image, t, num_inference_steps):
|
def step_prk(self, residual, image, t, num_inference_steps):
|
||||||
# TODO(Patrick) - need to rethink whether the "warmup" way is the correct API design here
|
# TODO(Patrick) - need to rethink whether the "warmup" way is the correct API design here
|
||||||
warmup_time_steps = self.get_warmup_time_steps(num_inference_steps)
|
warmup_time_steps = self.get_warmup_time_steps(num_inference_steps)
|
||||||
|
|
||||||
|
@ -105,6 +108,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
if t % 4 == 0:
|
if t % 4 == 0:
|
||||||
self.cur_residual += 1 / 6 * residual
|
self.cur_residual += 1 / 6 * residual
|
||||||
self.ets.append(residual)
|
self.ets.append(residual)
|
||||||
|
self.cur_image = image
|
||||||
elif (t - 1) % 4 == 0:
|
elif (t - 1) % 4 == 0:
|
||||||
self.cur_residual += 1 / 3 * residual
|
self.cur_residual += 1 / 3 * residual
|
||||||
elif (t - 2) % 4 == 0:
|
elif (t - 2) % 4 == 0:
|
||||||
|
@ -113,9 +117,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
residual = self.cur_residual + 1 / 6 * residual
|
residual = self.cur_residual + 1 / 6 * residual
|
||||||
self.cur_residual = 0
|
self.cur_residual = 0
|
||||||
|
|
||||||
return self.transfer(image, t_prev, t_next, residual)
|
return self.transfer(self.cur_image, t_prev, t_next, residual)
|
||||||
|
|
||||||
def step(self, residual, image, t, num_inference_steps):
|
def step_plms(self, residual, image, t, num_inference_steps):
|
||||||
timesteps = self.get_time_steps(num_inference_steps)
|
timesteps = self.get_time_steps(num_inference_steps)
|
||||||
|
|
||||||
t_prev = timesteps[t]
|
t_prev = timesteps[t]
|
||||||
|
|
Loading…
Reference in New Issue