finish pndm scheduler

This commit is contained in:
Patrick von Platen 2022-06-17 15:51:03 +02:00
parent de22d4cd5d
commit b2274ece73
3 changed files with 230 additions and 38 deletions

View File

@ -42,9 +42,9 @@ class PNDM(DiffusionPipeline):
)
image = image.to(torch_device)
warmup_time_steps = self.noise_scheduler.get_warmup_time_steps(num_inference_steps)
for t in tqdm.tqdm(range(len(warmup_time_steps))):
t_orig = warmup_time_steps[t]
prk_time_steps = self.noise_scheduler.get_prk_time_steps(num_inference_steps)
for t in tqdm.tqdm(range(len(prk_time_steps))):
t_orig = prk_time_steps[t]
residual = self.unet(image, t_orig)
image = self.noise_scheduler.step_prk(residual, image, t, num_inference_steps)

View File

@ -56,15 +56,16 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# For now we only support F-PNDM, i.e. the runge-kutta method
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
# mainly at equations (12) and (13) and the Algorithm 2.
# mainly at formulas (9), (12), (13) and Algorithm 2.
self.pndm_order = 4
# running values
self.cur_residual = 0
self.cur_sample = None
self.ets = []
self.warmup_time_steps = {}
self.prk_time_steps = {}
self.time_steps = {}
self.set_prk_mode()
def get_alpha(self, time_step):
    """Look up ``self.alphas`` at index ``time_step`` and return that entry."""
    alpha = self.alphas[time_step]
    return alpha
@ -77,18 +78,18 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return self.one
return self.alphas_cumprod[time_step]
def get_warmup_time_steps(self, num_inference_steps):
if num_inference_steps in self.warmup_time_steps:
return self.warmup_time_steps[num_inference_steps]
def get_prk_time_steps(self, num_inference_steps):
if num_inference_steps in self.prk_time_steps:
return self.prk_time_steps[num_inference_steps]
inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps))
warmup_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
prk_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
np.array([0, self.config.timesteps // num_inference_steps // 2]), self.pndm_order
)
self.warmup_time_steps[num_inference_steps] = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1]))
self.prk_time_steps[num_inference_steps] = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
return self.warmup_time_steps[num_inference_steps]
return self.prk_time_steps[num_inference_steps]
def get_time_steps(self, num_inference_steps):
if num_inference_steps in self.time_steps:
@ -99,12 +100,25 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return self.time_steps[num_inference_steps]
def step_prk(self, residual, sample, t, num_inference_steps):
# TODO(Patrick) - need to rethink whether the "warmup" way is the correct API design here
warmup_time_steps = self.get_warmup_time_steps(num_inference_steps)
def set_prk_mode(self):
    """Select the "prk" mode, so that ``step`` forwards its arguments to ``step_prk``."""
    self.mode = "prk"
t_prev = warmup_time_steps[t // 4 * 4]
t_next = warmup_time_steps[min(t + 1, len(warmup_time_steps) - 1)]
def set_plms_mode(self):
    """Select the "plms" mode, so that ``step`` forwards its arguments to ``step_plms``."""
    self.mode = "plms"
def step(self, *args, **kwargs):
    """Run one scheduler step with the routine selected by ``self.mode``.

    All positional and keyword arguments are forwarded unchanged to
    ``step_prk`` (mode ``"prk"``) or ``step_plms`` (mode ``"plms"``), and the
    result of that call is returned.

    Raises:
        ValueError: if ``self.mode`` is neither ``"prk"`` nor ``"plms"``.
    """
    mode = self.mode
    if mode == "prk":
        step_fn = self.step_prk
    elif mode == "plms":
        step_fn = self.step_plms
    else:
        raise ValueError(f"mode {self.mode} does not exist.")

    return step_fn(*args, **kwargs)
def step_prk(self, residual, sample, t, num_inference_steps):
prk_time_steps = self.get_prk_time_steps(num_inference_steps)
t_orig = prk_time_steps[t // 4 * 4]
t_orig_prev = prk_time_steps[min(t + 1, len(prk_time_steps) - 1)]
if t % 4 == 0:
self.cur_residual += 1 / 6 * residual
@ -118,33 +132,63 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
residual = self.cur_residual + 1 / 6 * residual
self.cur_residual = 0
return self.transfer(self.cur_sample, t_prev, t_next, residual)
# cur_sample should not be `None`
cur_sample = self.cur_sample if self.cur_sample is not None else sample
return self.get_prev_sample(cur_sample, t_orig, t_orig_prev, residual)
def step_plms(self, residual, sample, t, num_inference_steps):
if len(self.ets) < 3:
raise ValueError(
f"{self.__class__} can only be run AFTER scheduler has been run "
"in 'prk' mode for at least 12 iterations "
"See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
"for more information."
)
timesteps = self.get_time_steps(num_inference_steps)
t_prev = timesteps[t]
t_next = timesteps[min(t + 1, len(timesteps) - 1)]
t_orig = timesteps[t]
t_orig_prev = timesteps[min(t + 1, len(timesteps) - 1)]
self.ets.append(residual)
residual = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
return self.transfer(sample, t_prev, t_next, residual)
return self.get_prev_sample(sample, t_orig, t_orig_prev, residual)
def transfer(self, x, t, t_next, et):
# TODO(Patrick): clean up to be compatible with numpy and give better names
def get_prev_sample(self, sample, t_orig, t_orig_prev, residual):
# See formula (9) of the PNDM paper: https://arxiv.org/pdf/2202.09778.pdf
# This function computes x_(t−δ) using formula (9).
# Note that x_t needs to be added to both sides of the equation.
alphas_cump = self.alphas_cumprod.to(x.device)
at = alphas_cump[t + 1].view(-1, 1, 1, 1)
at_next = alphas_cump[t_next + 1].view(-1, 1, 1, 1)
# Notation (<variable name> -> <name in paper>):
# alpha_prod_t -> α_t
# alpha_prod_t_prev -> α_(t−δ)
# beta_prod_t -> (1 - α_t)
# beta_prod_t_prev -> (1 - α_(t−δ))
# sample -> x_t
# residual -> e_θ(x_t, t)
# prev_sample -> x_(t−δ)
alpha_prod_t = self.get_alpha_prod(t_orig + 1)
alpha_prod_t_prev = self.get_alpha_prod(t_orig_prev + 1)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
x_delta = (at_next - at) * (
(1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x
- 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et
)
# corresponds to (α_(t−δ) - α_t) divided by the
# denominator of x_t in formula (9), plus 1
# Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) + 1 =
# sqrt(α_(t−δ)) / sqrt(α_t)
sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)
x_next = x + x_delta
return x_next
# corresponds to denominator of e_θ(x_t, t) in formula (9)
residual_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
alpha_prod_t * beta_prod_t * alpha_prod_t_prev
) ** (0.5)
# full formula (9)
prev_sample = sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * residual / residual_denom_coeff
return prev_sample
def __len__(self):
    """Return ``self.config.timesteps`` as the length of the scheduler."""
    num_timesteps = self.config.timesteps
    return num_timesteps

View File

@ -20,7 +20,7 @@ import unittest
import numpy as np
import torch
from diffusers import DDIMScheduler, DDPMScheduler
from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler
torch.backends.cuda.matmul.allow_tf32 = False
@ -90,10 +90,10 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs.update(forward_kwargs)
for scheduler_class in self.scheduler_classes:
scheduler_class = self.scheduler_classes[0]
image = self.dummy_image
residual = 0.1 * image
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
@ -159,7 +159,7 @@ class SchedulerCommonTest(unittest.TestCase):
output = scheduler.step(residual, image, 1, **kwargs)
output_pt = scheduler_pt.step(residual_pt, image_pt, 1, **kwargs)
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-5, "Scheduler outputs are not identical"
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
class DDPMSchedulerTest(SchedulerCommonTest):
@ -237,8 +237,8 @@ class DDPMSchedulerTest(SchedulerCommonTest):
result_sum = np.sum(np.abs(image))
result_mean = np.mean(np.abs(image))
assert result_sum.item() - 732.9947 < 1e-3
assert result_mean.item() - 0.9544 < 1e-3
assert abs(result_sum.item() - 732.9947) < 1e-2
assert abs(result_mean.item() - 0.9544) < 1e-3
class DDIMSchedulerTest(SchedulerCommonTest):
@ -325,5 +325,153 @@ class DDIMSchedulerTest(SchedulerCommonTest):
result_sum = np.sum(np.abs(image))
result_mean = np.mean(np.abs(image))
assert result_sum.item() - 270.6214 < 1e-3
assert result_mean.item() - 0.3524 < 1e-3
assert abs(result_sum.item() - 270.6214) < 1e-2
assert abs(result_mean.item() - 0.3524) < 1e-3
class PNDMSchedulerTest(SchedulerCommonTest):
    """Tests for ``PNDMScheduler`` in its "prk" and "plms" stepping modes.

    The checks inherited from ``SchedulerCommonTest`` (``check_over_configs``,
    ``check_over_forward``) call ``scheduler.step`` on a freshly constructed
    scheduler. The ``*_pmls`` variants below switch the scheduler to "plms"
    mode first (via ``set_plms_mode``) and pre-fill ``scheduler.ets`` with
    dummy residuals.

    NOTE: "pmls" in the method names refers to the scheduler's "plms" mode;
    the names are kept as they are so test IDs stay stable.
    """

    scheduler_classes = (PNDMScheduler,)
    forward_default_kwargs = (("num_inference_steps", 50),)

    def get_scheduler_config(self, **kwargs):
        """Return the default scheduler config, with ``kwargs`` overriding individual entries."""
        config = {
            "timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
        }

        config.update(**kwargs)
        return config

    def _check_plms_save_load(self, time_step, config_overrides, step_kwargs):
        """Assert that a "plms" step is unchanged by a config save/load round trip.

        For every class in ``scheduler_classes`` a scheduler is built from
        ``get_scheduler_config(**config_overrides)``, saved with ``save_config``
        and reloaded with ``from_config``. Both instances are given the same
        dummy residual history, switched to "plms" mode, and stepped once with
        ``step_kwargs``; their outputs must agree.
        """
        image = self.dummy_image
        residual = 0.1 * image
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config(**config_overrides)
            scheduler = scheduler_class(**scheduler_config)
            # each scheduler gets its own copy of the history list
            scheduler.ets = dummy_past_residuals[:]
            scheduler.set_plms_mode()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_config(tmpdirname)

            # the history list is assigned by hand on the reloaded scheduler as well
            new_scheduler.ets = dummy_past_residuals[:]
            new_scheduler.set_plms_mode()

            output = scheduler.step(residual, image, time_step, **step_kwargs)
            new_output = new_scheduler.step(residual, image, time_step, **step_kwargs)

            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def check_over_configs_pmls(self, time_step=0, **config):
        """Run the "plms" save/load check with ``config`` overriding the default scheduler config."""
        self._check_plms_save_load(time_step, config, dict(self.forward_default_kwargs))

    def check_over_forward_pmls(self, time_step=0, **forward_kwargs):
        """Run the "plms" save/load check with ``forward_kwargs`` overriding the default step kwargs."""
        kwargs = dict(self.forward_default_kwargs)
        kwargs.update(forward_kwargs)
        self._check_plms_save_load(time_step, {}, kwargs)

    def test_timesteps(self):
        for timesteps in [100, 1000]:
            self.check_over_configs(timesteps=timesteps)

    def test_timesteps_pmls(self):
        for timesteps in [100, 1000]:
            self.check_over_configs_pmls(timesteps=timesteps)

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_betas_pmls(self):
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
            self.check_over_configs_pmls(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_schedules_pmls(self):
        # Must use the "plms" helper; calling check_over_configs here would
        # only repeat test_schedules.
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs_pmls(beta_schedule=schedule)

    def test_time_indices(self):
        for t in [1, 5, 10]:
            self.check_over_forward(time_step=t)

    def test_time_indices_pmls(self):
        for t in [1, 5, 10]:
            self.check_over_forward_pmls(time_step=t)

    def test_inference_steps(self):
        for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
            self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)

    def test_inference_steps_pmls(self):
        for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
            self.check_over_forward_pmls(time_step=t, num_inference_steps=num_inference_steps)

    def test_inference_pmls_no_past_residuals(self):
        """A "plms" step without any stored residuals must raise ``ValueError``."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)
        scheduler.set_plms_mode()

        # Only the step call is guarded, so a ValueError raised during setup
        # cannot make this test pass by accident.
        with self.assertRaises(ValueError):
            scheduler.step(self.dummy_image, self.dummy_image, 1, 50)

    def test_full_loop_no_noise(self):
        """Run all "prk" steps followed by all "plms" steps and compare against reference statistics."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 10
        model = self.dummy_model()
        image = self.dummy_image_deter

        # The model is called with the timestep value, the scheduler with the loop index.
        for t, t_orig in enumerate(scheduler.get_prk_time_steps(num_inference_steps)):
            residual = model(image, t_orig)
            image = scheduler.step_prk(residual, image, t, num_inference_steps)

        for t, t_orig in enumerate(scheduler.get_time_steps(num_inference_steps)):
            residual = model(image, t_orig)
            image = scheduler.step_plms(residual, image, t, num_inference_steps)

        result_sum = np.sum(np.abs(image))
        result_mean = np.mean(np.abs(image))

        assert abs(result_sum.item() - 199.1169) < 1e-2
        assert abs(result_mean.item() - 0.2593) < 1e-3