Add LMSDiscreteSchedulerTest (#467)
* [WIP] add LMSDiscreteSchedulerTest * fixes for comments * add torch numpy test * rebase * Update tests/test_scheduler.py * Update tests/test_scheduler.py * style * return residuals Co-authored-by: Anton Lozhkov <anton@huggingface.co>
This commit is contained in:
parent
88972172d8
commit
a54cfe6828
|
@ -19,7 +19,7 @@ from typing import Dict, List, Tuple
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler, ScoreSdeVeScheduler
|
||||
from diffusers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler, ScoreSdeVeScheduler
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
@ -853,3 +853,83 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
|
|||
|
||||
self.assertEqual(output_0.shape, sample.shape)
|
||||
self.assertEqual(output_0.shape, output_1.shape)
|
||||
|
||||
|
||||
class LMSDiscreteSchedulerTest(SchedulerCommonTest):
|
||||
scheduler_classes = (LMSDiscreteScheduler,)
|
||||
num_inference_steps = 10
|
||||
|
||||
def get_scheduler_config(self, **kwargs):
|
||||
config = {
|
||||
"num_train_timesteps": 1100,
|
||||
"beta_start": 0.0001,
|
||||
"beta_end": 0.02,
|
||||
"beta_schedule": "linear",
|
||||
"trained_betas": None,
|
||||
"tensor_format": "pt",
|
||||
}
|
||||
|
||||
config.update(**kwargs)
|
||||
return config
|
||||
|
||||
def test_timesteps(self):
|
||||
for timesteps in [10, 50, 100, 1000]:
|
||||
self.check_over_configs(num_train_timesteps=timesteps)
|
||||
|
||||
def test_betas(self):
|
||||
for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
|
||||
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
|
||||
|
||||
def test_schedules(self):
|
||||
for schedule in ["linear", "scaled_linear"]:
|
||||
self.check_over_configs(beta_schedule=schedule)
|
||||
|
||||
def test_time_indices(self):
|
||||
for t in [0, 500, 800]:
|
||||
self.check_over_forward(time_step=t)
|
||||
|
||||
def test_pytorch_equal_numpy(self):
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
sample_pt = self.dummy_sample
|
||||
residual_pt = 0.1 * sample_pt
|
||||
|
||||
sample = sample_pt.numpy()
|
||||
residual = 0.1 * sample
|
||||
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler_config["tensor_format"] = "np"
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
scheduler_config["tensor_format"] = "pt"
|
||||
scheduler_pt = scheduler_class(**scheduler_config)
|
||||
|
||||
scheduler.set_timesteps(self.num_inference_steps)
|
||||
scheduler_pt.set_timesteps(self.num_inference_steps)
|
||||
|
||||
output = scheduler.step(residual, 1, sample).prev_sample
|
||||
output_pt = scheduler_pt.step(residual_pt, 1, sample_pt).prev_sample
|
||||
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
|
||||
|
||||
def test_full_loop_no_noise(self):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
scheduler.set_timesteps(self.num_inference_steps)
|
||||
|
||||
model = self.dummy_model()
|
||||
sample = self.dummy_sample_deter * scheduler.sigmas[0]
|
||||
|
||||
for i, t in enumerate(scheduler.timesteps):
|
||||
sample = sample / ((scheduler.sigmas[i] ** 2 + 1) ** 0.5)
|
||||
|
||||
model_output = model(sample, t)
|
||||
|
||||
output = scheduler.step(model_output, i, sample)
|
||||
sample = output.prev_sample
|
||||
|
||||
result_sum = torch.sum(torch.abs(sample))
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
assert abs(result_sum.item() - 1006.388) < 1e-2
|
||||
assert abs(result_mean.item() - 1.31) < 1e-3
|
||||
|
|
Loading…
Reference in New Issue