From a54cfe6828355e9875b380456ac32e92cb4bcf9f Mon Sep 17 00:00:00 2001 From: Sid Sahai Date: Fri, 16 Sep 2022 10:10:56 -0700 Subject: [PATCH] Add LMSDiscreteSchedulerTest (#467) * [WIP] add LMSDiscreteSchedulerTest * fixes for comments * add torch numpy test * rebase * Update tests/test_scheduler.py * Update tests/test_scheduler.py * style * return residuals Co-authored-by: Anton Lozhkov --- tests/test_scheduler.py | 82 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 81 insertions(+), 1 deletion(-) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 6306be28..e78f1004 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -19,7 +19,7 @@ from typing import Dict, List, Tuple import numpy as np import torch -from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler, ScoreSdeVeScheduler +from diffusers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler, ScoreSdeVeScheduler torch.backends.cuda.matmul.allow_tf32 = False @@ -853,3 +853,83 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase): self.assertEqual(output_0.shape, sample.shape) self.assertEqual(output_0.shape, output_1.shape) + + +class LMSDiscreteSchedulerTest(SchedulerCommonTest): + scheduler_classes = (LMSDiscreteScheduler,) + num_inference_steps = 10 + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1100, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + "trained_betas": None, + "tensor_format": "pt", + } + + config.update(**kwargs) + return config + + def test_timesteps(self): + for timesteps in [10, 50, 100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_betas(self): + for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "scaled_linear"]: + self.check_over_configs(beta_schedule=schedule) + + def test_time_indices(self): + for t in [0, 500, 800]: + self.check_over_forward(time_step=t) + + def test_pytorch_equal_numpy(self): + for scheduler_class in self.scheduler_classes: + sample_pt = self.dummy_sample + residual_pt = 0.1 * sample_pt + + sample = sample_pt.numpy() + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler_config["tensor_format"] = "np" + scheduler = scheduler_class(**scheduler_config) + + scheduler_config["tensor_format"] = "pt" + scheduler_pt = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps) + scheduler_pt.set_timesteps(self.num_inference_steps) + + output = scheduler.step(residual, 1, sample).prev_sample + output_pt = scheduler_pt.step(residual_pt, 1, sample_pt).prev_sample + assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" + + def test_full_loop_no_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.sigmas[0] + + for i, t in enumerate(scheduler.timesteps): + sample = sample / ((scheduler.sigmas[i] ** 2 + 1) ** 0.5) + + model_output = model(sample, t) + + output = scheduler.step(model_output, i, sample) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 1006.388) < 1e-2 + assert abs(result_mean.item() - 1.31) < 1e-3