finish pndm scheduler
This commit is contained in:
parent
de22d4cd5d
commit
b2274ece73
|
@ -42,9 +42,9 @@ class PNDM(DiffusionPipeline):
|
|||
)
|
||||
image = image.to(torch_device)
|
||||
|
||||
warmup_time_steps = self.noise_scheduler.get_warmup_time_steps(num_inference_steps)
|
||||
for t in tqdm.tqdm(range(len(warmup_time_steps))):
|
||||
t_orig = warmup_time_steps[t]
|
||||
prk_time_steps = self.noise_scheduler.get_prk_time_steps(num_inference_steps)
|
||||
for t in tqdm.tqdm(range(len(prk_time_steps))):
|
||||
t_orig = prk_time_steps[t]
|
||||
residual = self.unet(image, t_orig)
|
||||
|
||||
image = self.noise_scheduler.step_prk(residual, image, t, num_inference_steps)
|
||||
|
|
|
@ -56,15 +56,16 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||
|
||||
# For now we only support F-PNDM, i.e. the runge-kutta method
|
||||
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
|
||||
# mainly at equations (12) and (13) and the Algorithm 2.
|
||||
# mainly at formula (9), (12), (13) and the Algorithm 2.
|
||||
self.pndm_order = 4
|
||||
|
||||
# running values
|
||||
self.cur_residual = 0
|
||||
self.cur_sample = None
|
||||
self.ets = []
|
||||
self.warmup_time_steps = {}
|
||||
self.prk_time_steps = {}
|
||||
self.time_steps = {}
|
||||
self.set_prk_mode()
|
||||
|
||||
def get_alpha(self, time_step):
|
||||
return self.alphas[time_step]
|
||||
|
@ -77,18 +78,18 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||
return self.one
|
||||
return self.alphas_cumprod[time_step]
|
||||
|
||||
def get_warmup_time_steps(self, num_inference_steps):
|
||||
if num_inference_steps in self.warmup_time_steps:
|
||||
return self.warmup_time_steps[num_inference_steps]
|
||||
def get_prk_time_steps(self, num_inference_steps):
|
||||
if num_inference_steps in self.prk_time_steps:
|
||||
return self.prk_time_steps[num_inference_steps]
|
||||
|
||||
inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps))
|
||||
|
||||
warmup_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
|
||||
prk_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
|
||||
np.array([0, self.config.timesteps // num_inference_steps // 2]), self.pndm_order
|
||||
)
|
||||
self.warmup_time_steps[num_inference_steps] = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1]))
|
||||
self.prk_time_steps[num_inference_steps] = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
|
||||
|
||||
return self.warmup_time_steps[num_inference_steps]
|
||||
return self.prk_time_steps[num_inference_steps]
|
||||
|
||||
def get_time_steps(self, num_inference_steps):
|
||||
if num_inference_steps in self.time_steps:
|
||||
|
@ -99,12 +100,25 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||
|
||||
return self.time_steps[num_inference_steps]
|
||||
|
||||
def step_prk(self, residual, sample, t, num_inference_steps):
|
||||
# TODO(Patrick) - need to rethink whether the "warmup" way is the correct API design here
|
||||
warmup_time_steps = self.get_warmup_time_steps(num_inference_steps)
|
||||
def set_prk_mode(self):
|
||||
self.mode = "prk"
|
||||
|
||||
t_prev = warmup_time_steps[t // 4 * 4]
|
||||
t_next = warmup_time_steps[min(t + 1, len(warmup_time_steps) - 1)]
|
||||
def set_plms_mode(self):
|
||||
self.mode = "plms"
|
||||
|
||||
def step(self, *args, **kwargs):
|
||||
if self.mode == "prk":
|
||||
return self.step_prk(*args, **kwargs)
|
||||
if self.mode == "plms":
|
||||
return self.step_plms(*args, **kwargs)
|
||||
|
||||
raise ValueError(f"mode {self.mode} does not exist.")
|
||||
|
||||
def step_prk(self, residual, sample, t, num_inference_steps):
|
||||
prk_time_steps = self.get_prk_time_steps(num_inference_steps)
|
||||
|
||||
t_orig = prk_time_steps[t // 4 * 4]
|
||||
t_orig_prev = prk_time_steps[min(t + 1, len(prk_time_steps) - 1)]
|
||||
|
||||
if t % 4 == 0:
|
||||
self.cur_residual += 1 / 6 * residual
|
||||
|
@ -118,33 +132,63 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||
residual = self.cur_residual + 1 / 6 * residual
|
||||
self.cur_residual = 0
|
||||
|
||||
return self.transfer(self.cur_sample, t_prev, t_next, residual)
|
||||
# cur_sample should not be `None`
|
||||
cur_sample = self.cur_sample if self.cur_sample is not None else sample
|
||||
|
||||
return self.get_prev_sample(cur_sample, t_orig, t_orig_prev, residual)
|
||||
|
||||
def step_plms(self, residual, sample, t, num_inference_steps):
|
||||
if len(self.ets) < 3:
|
||||
raise ValueError(
|
||||
f"{self.__class__} can only be run AFTER scheduler has been run "
|
||||
"in 'prk' mode for at least 12 iterations "
|
||||
"See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
|
||||
"for more information."
|
||||
)
|
||||
|
||||
timesteps = self.get_time_steps(num_inference_steps)
|
||||
|
||||
t_prev = timesteps[t]
|
||||
t_next = timesteps[min(t + 1, len(timesteps) - 1)]
|
||||
t_orig = timesteps[t]
|
||||
t_orig_prev = timesteps[min(t + 1, len(timesteps) - 1)]
|
||||
self.ets.append(residual)
|
||||
|
||||
residual = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
|
||||
|
||||
return self.transfer(sample, t_prev, t_next, residual)
|
||||
return self.get_prev_sample(sample, t_orig, t_orig_prev, residual)
|
||||
|
||||
def transfer(self, x, t, t_next, et):
|
||||
# TODO(Patrick): clean up to be compatible with numpy and give better names
|
||||
def get_prev_sample(self, sample, t_orig, t_orig_prev, residual):
|
||||
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
|
||||
# this function computes x_(t−δ) using the formula of (9)
|
||||
# Note that x_t needs to be added to both sides of the equation
|
||||
|
||||
alphas_cump = self.alphas_cumprod.to(x.device)
|
||||
at = alphas_cump[t + 1].view(-1, 1, 1, 1)
|
||||
at_next = alphas_cump[t_next + 1].view(-1, 1, 1, 1)
|
||||
# Notation (<variable name> -> <name in paper>)
|
||||
# alpha_prod_t -> α_t
|
||||
# alpha_prod_t_prev -> α_(t−δ)
|
||||
# beta_prod_t -> (1 - α_t)
|
||||
# beta_prod_t_prev -> (1 - α_(t−δ))
|
||||
# sample -> x_t
|
||||
# residual -> e_θ(x_t, t)
|
||||
# prev_sample -> x_(t−δ)
|
||||
alpha_prod_t = self.get_alpha_prod(t_orig + 1)
|
||||
alpha_prod_t_prev = self.get_alpha_prod(t_orig_prev + 1)
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||
|
||||
x_delta = (at_next - at) * (
|
||||
(1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x
|
||||
- 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et
|
||||
)
|
||||
# corresponds to (α_(t−δ) - α_t) divided by
|
||||
# denominator of x_t in formula (9) and plus 1
|
||||
# Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) =
|
||||
# sqrt(α_(t−δ)) / sqrt(α_t) - 1, so adding 1 leaves sqrt(α_(t−δ)) / sqrt(α_t)
|
||||
sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)
|
||||
|
||||
x_next = x + x_delta
|
||||
return x_next
|
||||
# corresponds to denominator of e_θ(x_t, t) in formula (9)
|
||||
residual_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
|
||||
alpha_prod_t * beta_prod_t * alpha_prod_t_prev
|
||||
) ** (0.5)
|
||||
|
||||
# full formula (9)
|
||||
prev_sample = sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * residual / residual_denom_coeff
|
||||
|
||||
return prev_sample
|
||||
|
||||
def __len__(self):
|
||||
return self.config.timesteps
|
||||
|
|
|
@ -20,7 +20,7 @@ import unittest
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import DDIMScheduler, DDPMScheduler
|
||||
from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
@ -90,10 +90,10 @@ class SchedulerCommonTest(unittest.TestCase):
|
|||
kwargs.update(forward_kwargs)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
image = self.dummy_image
|
||||
residual = 0.1 * image
|
||||
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
|
@ -159,7 +159,7 @@ class SchedulerCommonTest(unittest.TestCase):
|
|||
output = scheduler.step(residual, image, 1, **kwargs)
|
||||
output_pt = scheduler_pt.step(residual_pt, image_pt, 1, **kwargs)
|
||||
|
||||
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-5, "Scheduler outputs are not identical"
|
||||
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
|
||||
|
||||
|
||||
class DDPMSchedulerTest(SchedulerCommonTest):
|
||||
|
@ -237,8 +237,8 @@ class DDPMSchedulerTest(SchedulerCommonTest):
|
|||
result_sum = np.sum(np.abs(image))
|
||||
result_mean = np.mean(np.abs(image))
|
||||
|
||||
assert result_sum.item() - 732.9947 < 1e-3
|
||||
assert result_mean.item() - 0.9544 < 1e-3
|
||||
assert abs(result_sum.item() - 732.9947) < 1e-2
|
||||
assert abs(result_mean.item() - 0.9544) < 1e-3
|
||||
|
||||
|
||||
class DDIMSchedulerTest(SchedulerCommonTest):
|
||||
|
@ -325,5 +325,153 @@ class DDIMSchedulerTest(SchedulerCommonTest):
|
|||
result_sum = np.sum(np.abs(image))
|
||||
result_mean = np.mean(np.abs(image))
|
||||
|
||||
assert result_sum.item() - 270.6214 < 1e-3
|
||||
assert result_mean.item() - 0.3524 < 1e-3
|
||||
assert abs(result_sum.item() - 270.6214) < 1e-2
|
||||
assert abs(result_mean.item() - 0.3524) < 1e-3
|
||||
|
||||
|
||||
class PNDMSchedulerTest(SchedulerCommonTest):
|
||||
scheduler_classes = (PNDMScheduler,)
|
||||
forward_default_kwargs = (("num_inference_steps", 50),)
|
||||
|
||||
def get_scheduler_config(self, **kwargs):
|
||||
config = {
|
||||
"timesteps": 1000,
|
||||
"beta_start": 0.0001,
|
||||
"beta_end": 0.02,
|
||||
"beta_schedule": "linear",
|
||||
}
|
||||
|
||||
config.update(**kwargs)
|
||||
return config
|
||||
|
||||
def check_over_configs_pmls(self, time_step=0, **config):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
image = self.dummy_image
|
||||
residual = 0.1 * image
|
||||
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config(**config)
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
# copy over dummy past residuals
|
||||
scheduler.ets = dummy_past_residuals[:]
|
||||
scheduler.set_plms_mode()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
# copy over dummy past residuals
|
||||
new_scheduler.ets = dummy_past_residuals[:]
|
||||
new_scheduler.set_plms_mode()
|
||||
|
||||
output = scheduler.step(residual, image, time_step, **kwargs)
|
||||
new_output = new_scheduler.step(residual, image, time_step, **kwargs)
|
||||
|
||||
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def check_over_forward_pmls(self, time_step=0, **forward_kwargs):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
kwargs.update(forward_kwargs)
|
||||
image = self.dummy_image
|
||||
residual = 0.1 * image
|
||||
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
# copy over dummy past residuals
|
||||
scheduler.ets = dummy_past_residuals[:]
|
||||
scheduler.set_plms_mode()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
# copy over dummy past residuals
|
||||
new_scheduler.ets = dummy_past_residuals[:]
|
||||
new_scheduler.set_plms_mode()
|
||||
|
||||
output = scheduler.step(residual, image, time_step, **kwargs)
|
||||
new_output = new_scheduler.step(residual, image, time_step, **kwargs)
|
||||
|
||||
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_timesteps(self):
|
||||
for timesteps in [100, 1000]:
|
||||
self.check_over_configs(timesteps=timesteps)
|
||||
|
||||
def test_timesteps_pmls(self):
|
||||
for timesteps in [100, 1000]:
|
||||
self.check_over_configs_pmls(timesteps=timesteps)
|
||||
|
||||
def test_betas(self):
|
||||
for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
|
||||
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
|
||||
|
||||
def test_betas_pmls(self):
|
||||
for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
|
||||
self.check_over_configs_pmls(beta_start=beta_start, beta_end=beta_end)
|
||||
|
||||
def test_schedules(self):
|
||||
for schedule in ["linear", "squaredcos_cap_v2"]:
|
||||
self.check_over_configs(beta_schedule=schedule)
|
||||
|
||||
def test_schedules_pmls(self):
|
||||
for schedule in ["linear", "squaredcos_cap_v2"]:
|
||||
self.check_over_configs(beta_schedule=schedule)
|
||||
|
||||
def test_time_indices(self):
|
||||
for t in [1, 5, 10]:
|
||||
self.check_over_forward(time_step=t)
|
||||
|
||||
def test_time_indices_pmls(self):
|
||||
for t in [1, 5, 10]:
|
||||
self.check_over_forward_pmls(time_step=t)
|
||||
|
||||
def test_inference_steps(self):
|
||||
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
|
||||
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
|
||||
|
||||
def test_inference_steps_pmls(self):
|
||||
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
|
||||
self.check_over_forward_pmls(time_step=t, num_inference_steps=num_inference_steps)
|
||||
|
||||
def test_inference_pmls_no_past_residuals(self):
|
||||
with self.assertRaises(ValueError):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
scheduler.set_plms_mode()
|
||||
|
||||
scheduler.step(self.dummy_image, self.dummy_image, 1, 50)
|
||||
|
||||
def test_full_loop_no_noise(self):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
num_inference_steps = 10
|
||||
model = self.dummy_model()
|
||||
image = self.dummy_image_deter
|
||||
|
||||
prk_time_steps = scheduler.get_prk_time_steps(num_inference_steps)
|
||||
for t in range(len(prk_time_steps)):
|
||||
t_orig = prk_time_steps[t]
|
||||
residual = model(image, t_orig)
|
||||
|
||||
image = scheduler.step_prk(residual, image, t, num_inference_steps)
|
||||
|
||||
timesteps = scheduler.get_time_steps(num_inference_steps)
|
||||
for t in range(len(timesteps)):
|
||||
t_orig = timesteps[t]
|
||||
residual = model(image, t_orig)
|
||||
|
||||
image = scheduler.step_plms(residual, image, t, num_inference_steps)
|
||||
|
||||
result_sum = np.sum(np.abs(image))
|
||||
result_mean = np.mean(np.abs(image))
|
||||
|
||||
assert abs(result_sum.item() - 199.1169) < 1e-2
|
||||
assert abs(result_mean.item() - 0.2593) < 1e-3
|
||||
|
|
Loading…
Reference in New Issue