Fix VE SDE tests, clean API (#95)
* clean ddpm api to match ddim
* correct ve sde class
* update pipeline API for ve sde
* make style
* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 8b42c7cecc
commit 182b164f32
@@ -40,8 +40,10 @@ class DDPMPipeline(DiffusionPipeline):
         )
         image = image.to(torch_device)
 
-        num_prediction_steps = len(self.scheduler)
-        for t in tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
+        # set step values
+        self.scheduler.set_timesteps(1000)
+
+        for t in tqdm(self.scheduler.timesteps):
             # 1. predict noise model_output
             with torch.no_grad():
                 model_output = self.unet(image, t)
@@ -18,8 +18,8 @@ class ScoreSdeVePipeline(DiffusionPipeline):
 
         model = self.model.to(device)
 
-        x = torch.randn(*shape) * self.scheduler.config.sigma_max
-        x = x.to(device)
+        sample = torch.randn(*shape) * self.scheduler.config.sigma_max
+        sample = sample.to(device)
 
         self.scheduler.set_timesteps(num_inference_steps)
         self.scheduler.set_sigmas(num_inference_steps)
@@ -29,19 +29,20 @@ class ScoreSdeVePipeline(DiffusionPipeline):
 
             for _ in range(self.scheduler.correct_steps):
                 with torch.no_grad():
-                    result = self.model(x, sigma_t)
+                    model_output = self.model(sample, sigma_t)
 
-                if isinstance(result, dict):
-                    result = result["sample"]
+                if isinstance(model_output, dict):
+                    model_output = model_output["sample"]
 
-                x = self.scheduler.step_correct(result, x)
+                sample = self.scheduler.step_correct(model_output, sample)["prev_sample"]
 
             with torch.no_grad():
-                result = model(x, sigma_t)
+                model_output = model(sample, sigma_t)
 
-            if isinstance(result, dict):
-                result = result["sample"]
+            if isinstance(model_output, dict):
+                model_output = model_output["sample"]
 
-            x, x_mean = self.scheduler.step_pred(result, x, t)
+            output = self.scheduler.step_pred(model_output, t, sample)
+            sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
 
-        return x_mean
+        return sample_mean
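Read together, the added lines of the two pipeline hunks above give the VE SDE predictor-corrector loop roughly the following shape. This is only a sketch: the free function ve_sde_sample, its argument list, and the sigma_t/timestep bookkeeping of the outer loop are assumptions filled in from the test changes further down, not part of this diff.

import torch

# Sketch of the reworked ScoreSdeVePipeline sampling loop, assembled from the
# new-side lines above. Function name, arguments, and the outer-loop indexing
# are illustrative assumptions; the scheduler calls mirror the diff.
def ve_sde_sample(model, scheduler, shape, device, num_inference_steps):
    model = model.to(device)

    sample = torch.randn(*shape) * scheduler.config.sigma_max
    sample = sample.to(device)

    scheduler.set_timesteps(num_inference_steps)
    scheduler.set_sigmas(num_inference_steps)

    sample_mean = sample
    for i, t in enumerate(scheduler.timesteps):
        sigma_t = scheduler.sigmas[i]  # per-step noise level (mirrors the test loop below)

        # corrector: a few Langevin-style updates at the current noise level
        for _ in range(scheduler.correct_steps):
            with torch.no_grad():
                model_output = model(sample, sigma_t)
            if isinstance(model_output, dict):
                model_output = model_output["sample"]
            sample = scheduler.step_correct(model_output, sample)["prev_sample"]

        # predictor: one reverse-SDE step; the scheduler now returns a dict
        with torch.no_grad():
            model_output = model(sample, sigma_t)
        if isinstance(model_output, dict):
            model_output = model_output["sample"]
        output = scheduler.step_pred(model_output, t, sample)
        sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]

    # like the pipeline, return the noise-free mean of the final step
    return sample_mean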
@@ -86,8 +86,20 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
         self.one = np.array(1.0)
 
+        # setable values
+        self.num_inference_steps = None
+        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+
+        self.tensor_format = tensor_format
         self.set_format(tensor_format=tensor_format)
 
+    def set_timesteps(self, num_inference_steps):
+        self.num_inference_steps = num_inference_steps
+        self.timesteps = np.arange(
+            0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
+        )[::-1].copy()
+        self.set_format(tensor_format=self.tensor_format)
+
     def get_variance(self, t, variance_type=None):
         alpha_prod_t = self.alphas_cumprod[t]
         alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
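For a concrete sense of what the new DDPMScheduler.set_timesteps builds: it subsamples the training schedule with a fixed stride and reverses it. A small worked sketch (the 1000/50 values are illustrative, not taken from this diff):

import numpy as np

# Same arithmetic as set_timesteps above, with made-up schedule sizes.
num_train_timesteps = 1000   # illustrative value
num_inference_steps = 50     # illustrative value

timesteps = np.arange(
    0, num_train_timesteps, num_train_timesteps // num_inference_steps
)[::-1].copy()

print(timesteps[:5])   # [980 960 940 920 900]
print(len(timesteps))  # 50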
@@ -15,6 +15,8 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
 
+# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit
+import pdb
 from typing import Union
 
 import numpy as np
 import torch
@@ -27,8 +29,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
     """
     The variance exploding stochastic differential equation (SDE) scheduler.
 
-    :param snr: coefficient weighting the step from the score sample (from the network) to the random noise. :param
-    sigma_min: initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
+    :param snr: coefficient weighting the step from the model_output sample (from the network) to the random noise.
+    :param sigma_min: initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
     distribution of the data.
     :param sigma_max: :param sampling_eps: the end value of sampling, where timesteps decrease progessively from 1 to
     epsilon. :param correct_steps: number of correction steps performed on a produced sample. :param tensor_format:
@@ -54,12 +56,16 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
             sampling_eps=sampling_eps,
             correct_steps=correct_steps,
         )
 
-        self.sigmas = None
-        self.discrete_sigmas = None
+        # self.sigmas = None
+        # self.discrete_sigmas = None
+        #
+        # # setable values
+        # self.num_inference_steps = None
         self.timesteps = None
 
+        # TODO - update step to be torch-independant
+        self.set_sigmas(self.num_train_timesteps)
+
         self.tensor_format = tensor_format
         self.set_format(tensor_format=tensor_format)
 
     def set_timesteps(self, num_inference_steps):
@@ -104,52 +110,80 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
 
         raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
 
-    def step_pred(self, score, x, t):
+    def set_seed(self, seed):
+        tensor_format = getattr(self, "tensor_format", "pt")
+        if tensor_format == "np":
+            np.random.seed(seed)
+        elif tensor_format == "pt":
+            torch.manual_seed(seed)
+        else:
+            raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+    def step_pred(
+        self,
+        model_output: Union[torch.FloatTensor, np.ndarray],
+        timestep: int,
+        sample: Union[torch.FloatTensor, np.ndarray],
+        seed=None,
+    ):
         """
         Predict the sample at the previous timestep by reversing the SDE.
         """
-        # TODO(Patrick) better comments + non-PyTorch
-        t = self.repeat_scalar(t, x.shape[0]).to(x.device)
-        timesteps = self.long((t * (len(self.timesteps) - 1))).to(x.device)
+        if seed is not None:
+            self.set_seed(seed)
+        # TODO(Patrick) non-PyTorch
 
-        sigma = self.discrete_sigmas[timesteps].to(x.device)
-        adjacent_sigma = self.get_adjacent_sigma(timesteps, t)
-        drift = self.zeros_like(x)
+        timestep = timestep * torch.ones(
+            sample.shape[0], device=sample.device
+        )  # torch.repeat_interleave(timestep, sample.shape[0])
+        timesteps = (timestep * (len(self.timesteps) - 1)).long()
+
+        sigma = self.discrete_sigmas[timesteps].to(sample.device)
+        adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep)
+        drift = self.zeros_like(sample)
         diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
 
-        # equation 6 in the paper: the score modeled by the network is grad_x log pt(x)
+        # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
         # also equation 47 shows the analog from SDE models to ancestral sampling methods
-        drift = drift - diffusion[:, None, None, None] ** 2 * score
+        drift = drift - diffusion[:, None, None, None] ** 2 * model_output
 
         # equation 6: sample noise for the diffusion term of
-        noise = self.randn_like(x)
-        x_mean = x - drift  # subtract because `dt` is a small negative timestep
+        noise = self.randn_like(sample)
+        prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
         # TODO is the variable diffusion the correct scaling term for the noise?
-        x = x_mean + diffusion[:, None, None, None] * noise  # add impact of diffusion field g
-        return x, x_mean
+        prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise  # add impact of diffusion field g
 
-    def step_correct(self, score, x):
+        return {"prev_sample": prev_sample, "prev_sample_mean": prev_sample_mean}
+
+    def step_correct(
+        self,
+        model_output: Union[torch.FloatTensor, np.ndarray],
+        sample: Union[torch.FloatTensor, np.ndarray],
+        seed=None,
+    ):
         """
-        Correct the predicted sample based on the output score of the network. This is often run repeatedly after
-        making the prediction for the previous timestep.
+        Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
+        after making the prediction for the previous timestep.
         """
         # TODO(Patrick) non-PyTorch
+        if seed is not None:
+            self.set_seed(seed)
+
         # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
         # sample noise for correction
-        noise = self.randn_like(x)
+        noise = self.randn_like(sample)
 
-        # compute step size from the score, the noise, and the snr
-        grad_norm = self.norm(score)
+        # compute step size from the model_output, the noise, and the snr
+        grad_norm = self.norm(model_output)
         noise_norm = self.norm(noise)
         step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
-        step_size = self.repeat_scalar(step_size, x.shape[0])  # * self.ones(x.shape[0], device=x.device)
+        step_size = step_size * torch.ones(sample.shape[0]).to(sample.device)
+        # self.repeat_scalar(step_size, sample.shape[0])
 
-        # compute corrected sample: score term and noise term
-        x_mean = x + step_size[:, None, None, None] * score
-        x = x_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
+        # compute corrected sample: model_output term and noise term
+        prev_sample_mean = sample + step_size[:, None, None, None] * model_output
+        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
 
-        return x
+        return {"prev_sample": prev_sample}
 
     def __len__(self):
         return self.config.num_train_timesteps
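A minimal sketch of the new call signatures in isolation. The import path, the constructor defaults, and the random sample/model_output tensors are assumptions for illustration; the config values and the seed kwarg mirror get_scheduler_config() and forward_default_kwargs in the tests below.

import torch
from diffusers import ScoreSdeVeScheduler  # import path assumed

# Config values borrowed from the test config further down; other constructor
# arguments are left at their (assumed) defaults.
scheduler = ScoreSdeVeScheduler(sigma_min=0.01, sigma_max=1348, sampling_eps=1e-5, tensor_format="pt")
scheduler.set_timesteps(10)
scheduler.set_sigmas(10)

sample = torch.rand(4, 3, 8, 8)   # made-up batch
model_output = 0.1 * sample       # stand-in for a score network's output

# predictor step: note the (model_output, timestep, sample) argument order and the dict return
out = scheduler.step_pred(model_output, 0.5, sample, seed=0)
prev_sample, prev_sample_mean = out["prev_sample"], out["prev_sample_mean"]

# corrector step: returns only the corrected sample
corrected = scheduler.step_correct(model_output, sample, seed=0)["prev_sample"]

assert prev_sample.shape == sample.shape
assert corrected.shape == sample.shape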
@@ -53,16 +53,6 @@ class SchedulerMixin:
 
         raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
 
-    def long(self, tensor):
-        tensor_format = getattr(self, "tensor_format", "pt")
-
-        if tensor_format == "np":
-            return np.int64(tensor)
-        elif tensor_format == "pt":
-            return tensor.long()
-
-        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
-
     def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]):
         """
         Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.
@@ -103,15 +93,6 @@ class SchedulerMixin:
 
         raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
 
-    def repeat_scalar(self, tensor, count):
-        tensor_format = getattr(self, "tensor_format", "pt")
-        if tensor_format == "np":
-            return np.repeat(tensor, count)
-        elif tensor_format == "pt":
-            return torch.repeat_interleave(tensor, count)
-
-        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
-
     def zeros_like(self, tensor):
         tensor_format = getattr(self, "tensor_format", "pt")
         if tensor_format == "np":
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import pdb
 import tempfile
 import unittest
@@ -507,8 +508,42 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         assert abs(result_mean.item() - 0.2593) < 1e-3
 
 
-class ScoreSdeVeSchedulerTest(SchedulerCommonTest):
+class ScoreSdeVeSchedulerTest(unittest.TestCase):
+    # TODO adapt with class SchedulerCommonTest (scheduler needs Numpy Integration)
     scheduler_classes = (ScoreSdeVeScheduler,)
+    forward_default_kwargs = (("seed", 0),)
+
+    @property
+    def dummy_sample(self):
+        batch_size = 4
+        num_channels = 3
+        height = 8
+        width = 8
+
+        sample = torch.rand((batch_size, num_channels, height, width))
+
+        return sample
+
+    @property
+    def dummy_sample_deter(self):
+        batch_size = 4
+        num_channels = 3
+        height = 8
+        width = 8
+
+        num_elems = batch_size * num_channels * height * width
+        sample = torch.arange(num_elems)
+        sample = sample.reshape(num_channels, height, width, batch_size)
+        sample = sample / num_elems
+        sample = sample.permute(3, 0, 1, 2)
+
+        return sample
+
+    def dummy_model(self):
+        def model(sample, t, *args):
+            return sample * t / (t + 1)
+
+        return model
 
     def get_scheduler_config(self, **kwargs):
         config = {
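The dummy model added above is a deterministic stand-in for the score network: together with the deterministic sample fixture and the seeded noise, it is what makes the hard-coded sums in the full-loop test below reproducible. A quick sketch of its behaviour:

import torch

# The test's dummy model just scales its input by t / (t + 1); no parameters, no randomness.
def dummy_model():
    def model(sample, t, *args):
        return sample * t / (t + 1)

    return model

model = dummy_model()
sample = torch.ones(1, 3, 8, 8)
print(model(sample, torch.tensor(1.0)).mean())     # tensor(0.5000)
print(model(sample, torch.tensor(1000.0)).mean())  # tensor(0.9990)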
@@ -517,7 +552,7 @@ class ScoreSdeVeSchedulerTest(SchedulerCommonTest):
             "sigma_min": 0.01,
             "sigma_max": 1348,
             "sampling_eps": 1e-5,
-            "tensor_format": "np",  # TODO add test for tensor formats
+            "tensor_format": "pt",  # TODO add test for tensor formats
         }
 
         config.update(**kwargs)
@@ -538,15 +573,15 @@ class ScoreSdeVeSchedulerTest(SchedulerCommonTest):
                 scheduler.save_config(tmpdirname)
                 new_scheduler = scheduler_class.from_config(tmpdirname)
 
-            output = scheduler.step_pred(residual, sample, time_step, **kwargs)
-            new_output = new_scheduler.step_pred(residual, sample, time_step, **kwargs)
+            output = scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
 
-            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-            output = scheduler.step_correct(residual, sample, **kwargs)
-            new_output = new_scheduler.step_correct(residual, sample, **kwargs)
+            output = scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
 
-            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"
+            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"
 
     def check_over_forward(self, time_step=0, **forward_kwargs):
         kwargs = dict(self.forward_default_kwargs)
@@ -564,15 +599,15 @@ class ScoreSdeVeSchedulerTest(SchedulerCommonTest):
                 scheduler.save_config(tmpdirname)
                 new_scheduler = scheduler_class.from_config(tmpdirname)
 
-            output = scheduler.step_pred(residual, sample, time_step, **kwargs)
-            new_output = new_scheduler.step_pred(residual, sample, time_step, **kwargs)
+            output = scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
 
-            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-            output = scheduler.step_correct(residual, sample, **kwargs)
-            new_output = new_scheduler.step_correct(residual, sample, **kwargs)
+            output = scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
 
-            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"
+            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"
 
     def test_timesteps(self):
         for timesteps in [10, 100, 1000]:
@@ -583,11 +618,12 @@ class ScoreSdeVeSchedulerTest(SchedulerCommonTest):
             self.check_over_configs(sigma_min=sigma_min, sigma_max=sigma_max)
 
     def test_time_indices(self):
-        for t in [1, 5, 10]:
+        for t in [0.1, 0.5, 0.75]:
             self.check_over_forward(time_step=t)
 
     def test_full_loop_no_noise(self):
-        np.random.seed(0)
+        kwargs = dict(self.forward_default_kwargs)
+
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
         scheduler = scheduler_class(**scheduler_config)
@@ -598,52 +634,27 @@ class ScoreSdeVeSchedulerTest(SchedulerCommonTest):
         sample = self.dummy_sample_deter
 
         scheduler.set_sigmas(num_inference_steps)
+        scheduler.set_timesteps(num_inference_steps)
 
         for i, t in enumerate(scheduler.timesteps):
             sigma_t = scheduler.sigmas[i]
 
             for _ in range(scheduler.correct_steps):
                 with torch.no_grad():
-                    result = model(sample, sigma_t)
-                sample = scheduler.step_correct(result, sample)
+                    model_output = model(sample, sigma_t)
+                sample = scheduler.step_correct(model_output, sample, **kwargs)["prev_sample"]
 
             with torch.no_grad():
-                result = model(sample, sigma_t)
+                model_output = model(sample, sigma_t)
 
-            sample, sample_mean = scheduler.step_pred(result, sample, t)
+            output = scheduler.step_pred(model_output, t, sample, **kwargs)
+            sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
 
-        result_sum = np.sum(np.abs(sample))
-        result_mean = np.mean(np.abs(sample))
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
 
-        assert abs(result_sum.item() - 10629923278.7104) < 1e-2
-        assert abs(result_mean.item() - 13841045.9358) < 1e-3
-
-    def test_from_pretrained_save_pretrained(self):
-        kwargs = dict(self.forward_default_kwargs)
-
-        num_inference_steps = kwargs.pop("num_inference_steps", None)
-
-        for scheduler_class in self.scheduler_classes:
-            sample = self.dummy_sample
-            residual = 0.1 * sample
-
-            scheduler_config = self.get_scheduler_config()
-            scheduler = scheduler_class(**scheduler_config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                scheduler.save_config(tmpdirname)
-                new_scheduler = scheduler_class.from_config(tmpdirname)
-
-                if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
-                    scheduler.set_timesteps(num_inference_steps)
-                    new_scheduler.set_timesteps(num_inference_steps)
-                elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
-                    kwargs["num_inference_steps"] = num_inference_steps
-
-                output = scheduler.step_pred(residual, 1, sample, **kwargs)["prev_sample"]
-                new_output = new_scheduler.step_pred(residual, 1, sample, **kwargs)["prev_sample"]
-
-                assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+        assert abs(result_sum.item() - 14224664576.0) < 1e-2
+        assert abs(result_mean.item() - 18521698.0) < 1e-3
 
     def test_step_shape(self):
         kwargs = dict(self.forward_default_kwargs)
@@ -667,31 +678,3 @@ class ScoreSdeVeSchedulerTest(SchedulerCommonTest):
 
         self.assertEqual(output_0.shape, sample.shape)
         self.assertEqual(output_0.shape, output_1.shape)
-
-    def test_pytorch_equal_numpy(self):
-        kwargs = dict(self.forward_default_kwargs)
-
-        num_inference_steps = kwargs.pop("num_inference_steps", None)
-
-        for scheduler_class in self.scheduler_classes:
-            sample = self.dummy_sample
-            residual = 0.1 * sample
-
-            sample_pt = torch.tensor(sample)
-            residual_pt = 0.1 * sample_pt
-
-            scheduler_config = self.get_scheduler_config()
-            scheduler = scheduler_class(**scheduler_config)
-
-            scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
-
-            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
-                scheduler.set_timesteps(num_inference_steps)
-                scheduler_pt.set_timesteps(num_inference_steps)
-            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
-                kwargs["num_inference_steps"] = num_inference_steps
-
-            output = scheduler.step_pred(residual, 1, sample, **kwargs)["prev_sample"]
-            output_pt = scheduler_pt.step_pred(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
-
-            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"