Add LMSDiscreteSchedulerTest (#467)

* [WIP] add LMSDiscreteSchedulerTest

* fixes for comments

* add torch numpy test

* rebase

* Update tests/test_scheduler.py

* Update tests/test_scheduler.py

* style

* return residuals

Co-authored-by: Anton Lozhkov <anton@huggingface.co>
Sid Sahai authored on 2022-09-16 10:10:56 -07:00, committed by GitHub
parent 88972172d8
commit a54cfe6828
1 changed file with 81 additions and 1 deletion

@@ -19,7 +19,7 @@ from typing import Dict, List, Tuple
 import numpy as np
 import torch
-from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler, ScoreSdeVeScheduler
+from diffusers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler, ScoreSdeVeScheduler
 
 torch.backends.cuda.matmul.allow_tf32 = False
@@ -853,3 +853,83 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
         self.assertEqual(output_0.shape, sample.shape)
         self.assertEqual(output_0.shape, output_1.shape)
+
+
+class LMSDiscreteSchedulerTest(SchedulerCommonTest):
+    scheduler_classes = (LMSDiscreteScheduler,)
+    num_inference_steps = 10
+
+    def get_scheduler_config(self, **kwargs):
+        config = {
+            "num_train_timesteps": 1100,
+            "beta_start": 0.0001,
+            "beta_end": 0.02,
+            "beta_schedule": "linear",
+            "trained_betas": None,
+            "tensor_format": "pt",
+        }
+
+        config.update(**kwargs)
+        return config
+
+    def test_timesteps(self):
+        for timesteps in [10, 50, 100, 1000]:
+            self.check_over_configs(num_train_timesteps=timesteps)
+
+    def test_betas(self):
+        for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
+            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
+
+    def test_schedules(self):
+        for schedule in ["linear", "scaled_linear"]:
+            self.check_over_configs(beta_schedule=schedule)
+
+    def test_time_indices(self):
+        for t in [0, 500, 800]:
+            self.check_over_forward(time_step=t)
+
+    def test_pytorch_equal_numpy(self):
+        for scheduler_class in self.scheduler_classes:
+            sample_pt = self.dummy_sample
+            residual_pt = 0.1 * sample_pt
+
+            sample = sample_pt.numpy()
+            residual = 0.1 * sample
+
+            scheduler_config = self.get_scheduler_config()
+            scheduler_config["tensor_format"] = "np"
+            scheduler = scheduler_class(**scheduler_config)
+
+            scheduler_config["tensor_format"] = "pt"
+            scheduler_pt = scheduler_class(**scheduler_config)
+
+            scheduler.set_timesteps(self.num_inference_steps)
+            scheduler_pt.set_timesteps(self.num_inference_steps)
+
+            output = scheduler.step(residual, 1, sample).prev_sample
+            output_pt = scheduler_pt.step(residual_pt, 1, sample_pt).prev_sample
+            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
+
+    def test_full_loop_no_noise(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        scheduler.set_timesteps(self.num_inference_steps)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter * scheduler.sigmas[0]
+
+        for i, t in enumerate(scheduler.timesteps):
+            sample = sample / ((scheduler.sigmas[i] ** 2 + 1) ** 0.5)
+
+            model_output = model(sample, t)
+
+            output = scheduler.step(model_output, i, sample)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_sum.item() - 1006.388) < 1e-2
+        assert abs(result_mean.item() - 1.31) < 1e-3
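For reference, the loop in test_full_loop_no_noise mirrors how LMSDiscreteScheduler was driven at inference time in this version of diffusers: the initial sample is scaled by sigmas[0], the model input is rescaled by the current sigma at every step, and step() takes the loop index rather than the raw timestep value. A minimal standalone sketch of that loop, assuming the same tensor_format-era API exercised by the test, with a zero tensor as a stand-in for a real denoising model:

import torch
from diffusers import LMSDiscreteScheduler

# Assumes the 0.3.x-era API used by the test above:
# tensor_format="pt" and step(model_output, i, sample) indexed by the loop counter i.
scheduler = LMSDiscreteScheduler(num_train_timesteps=1000, tensor_format="pt")
scheduler.set_timesteps(10)

sample = torch.randn(1, 3, 8, 8) * scheduler.sigmas[0]  # start at the largest noise level
for i, t in enumerate(scheduler.timesteps):
    scaled = sample / ((scheduler.sigmas[i] ** 2 + 1) ** 0.5)  # same input scaling as the test
    model_output = torch.zeros_like(scaled)  # stand-in for a real denoising model
    sample = scheduler.step(model_output, i, sample).prev_sample

The new tests can be run in isolation with python -m pytest tests/test_scheduler.py -k LMSDiscreteScheduler.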