hotfix for pdnm test (#220)
This commit is contained in:
parent
6a03060c45
commit
3f1861ee46
|
@ -426,16 +426,18 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
|||
scheduler = scheduler_class(**scheduler_config)
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# copy over dummy past residuals
|
||||
# copy over dummy past residuals (must be after setting timesteps)
|
||||
scheduler.ets = dummy_past_residuals[:]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
# copy over dummy past residuals
|
||||
new_scheduler.ets = dummy_past_residuals[:]
|
||||
new_scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# copy over dummy past residual (must be after setting timesteps)
|
||||
new_scheduler.ets = dummy_past_residuals[:]
|
||||
|
||||
output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
|
||||
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
|
||||
|
||||
|
@ -461,12 +463,8 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
|||
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(tensor_format="np", **scheduler_config)
|
||||
# copy over dummy past residuals
|
||||
scheduler.ets = dummy_past_residuals[:]
|
||||
|
||||
scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
|
||||
# copy over dummy past residuals
|
||||
scheduler_pt.ets = dummy_past_residuals_pt[:]
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
|
@ -474,6 +472,10 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
|||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
# copy over dummy past residuals (must be done after set_timesteps)
|
||||
scheduler.ets = dummy_past_residuals[:]
|
||||
scheduler_pt.ets = dummy_past_residuals_pt[:]
|
||||
|
||||
output = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"]
|
||||
output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
|
||||
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
|
||||
|
@ -494,15 +496,16 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
|||
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
# copy over dummy past residuals
|
||||
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
|
||||
scheduler.ets = dummy_past_residuals[:]
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
# copy over dummy past residuals (must be done after set_timesteps)
|
||||
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
|
||||
scheduler.ets = dummy_past_residuals[:]
|
||||
|
||||
output_0 = scheduler.step_prk(residual, 0, sample, **kwargs)["prev_sample"]
|
||||
output_1 = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"]
|
||||
|
||||
|
|
Loading…
Reference in New Issue