Fix VE SDE tests, clean API (#95)

* clean ddpm api to match ddim

* correct ve sde class

* update pipeline API for ve sde

* make style

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Author: Nathan Lambert, committed by GitHub on 2022-07-19 03:12:45 -07:00
Parent: 8b42c7cecc
Commit: 182b164f32
6 changed files with 155 additions and 142 deletions


@@ -40,8 +40,10 @@ class DDPMPipeline(DiffusionPipeline):
         )
         image = image.to(torch_device)
 
-        num_prediction_steps = len(self.scheduler)
-        for t in tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
+        # set step values
+        self.scheduler.set_timesteps(1000)
+
+        for t in tqdm(self.scheduler.timesteps):
             # 1. predict noise model_output
             with torch.no_grad():
                 model_output = self.unet(image, t)
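
With this change the DDPM loop matches the DDIM pipeline: the scheduler owns the timestep sequence. A minimal sketch of the resulting call pattern, assuming `unet`, `scheduler`, and a noise tensor `image` are already constructed, and assuming the scheduler's `step` returns a dict the way the VE scheduler below does:

    import torch
    from tqdm import tqdm

    scheduler.set_timesteps(1000)        # the scheduler now owns the step sequence
    for t in tqdm(scheduler.timesteps):  # runs from the largest timestep down to 0
        with torch.no_grad():
            model_output = unet(image, t)
        if isinstance(model_output, dict):
            model_output = model_output["sample"]
        # step to the previous timestep; dict return assumed, mirroring step_pred below
        image = scheduler.step(model_output, t, image)["prev_sample"]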


@@ -18,8 +18,8 @@ class ScoreSdeVePipeline(DiffusionPipeline):
         model = self.model.to(device)
 
-        x = torch.randn(*shape) * self.scheduler.config.sigma_max
-        x = x.to(device)
+        sample = torch.randn(*shape) * self.scheduler.config.sigma_max
+        sample = sample.to(device)
 
         self.scheduler.set_timesteps(num_inference_steps)
         self.scheduler.set_sigmas(num_inference_steps)
@@ -29,19 +29,20 @@ class ScoreSdeVePipeline(DiffusionPipeline):
             for _ in range(self.scheduler.correct_steps):
                 with torch.no_grad():
-                    result = self.model(x, sigma_t)
+                    model_output = self.model(sample, sigma_t)
 
-                if isinstance(result, dict):
-                    result = result["sample"]
+                if isinstance(model_output, dict):
+                    model_output = model_output["sample"]
 
-                x = self.scheduler.step_correct(result, x)
+                sample = self.scheduler.step_correct(model_output, sample)["prev_sample"]
 
             with torch.no_grad():
-                result = model(x, sigma_t)
+                model_output = model(sample, sigma_t)
 
-            if isinstance(result, dict):
-                result = result["sample"]
+            if isinstance(model_output, dict):
+                model_output = model_output["sample"]
 
-            x, x_mean = self.scheduler.step_pred(result, x, t)
+            output = self.scheduler.step_pred(model_output, t, sample)
+            sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
 
-        return x_mean
+        return sample_mean
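
The renames make the predictor-corrector structure easier to follow. A condensed sketch of the whole sampler under the new API, with `model`, `scheduler`, `shape`, `device`, and `num_inference_steps` assumed as in the pipeline, and the model assumed to return a dict:

    import torch

    sample = torch.randn(*shape).to(device) * scheduler.config.sigma_max
    scheduler.set_timesteps(num_inference_steps)
    scheduler.set_sigmas(num_inference_steps)

    for i, t in enumerate(scheduler.timesteps):
        # sigma_t broadcast over the batch, as in the surrounding pipeline code
        sigma_t = scheduler.sigmas[i] * torch.ones(shape[0], device=device)

        # corrector: Langevin-style refinements of the current sample
        for _ in range(scheduler.correct_steps):
            with torch.no_grad():
                model_output = model(sample, sigma_t)["sample"]
            sample = scheduler.step_correct(model_output, sample)["prev_sample"]

        # predictor: one reverse-SDE step toward the previous timestep
        with torch.no_grad():
            model_output = model(sample, sigma_t)["sample"]
        output = scheduler.step_pred(model_output, t, sample)
        sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]

    # the pipeline returns sample_mean: the denoised mean without the final noise injection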


@@ -86,8 +86,20 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
         self.one = np.array(1.0)
 
+        # setable values
+        self.num_inference_steps = None
+        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+
+        self.tensor_format = tensor_format
         self.set_format(tensor_format=tensor_format)
 
+    def set_timesteps(self, num_inference_steps):
+        self.num_inference_steps = num_inference_steps
+        self.timesteps = np.arange(
+            0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
+        )[::-1].copy()
+        self.set_format(tensor_format=self.tensor_format)
+
     def get_variance(self, t, variance_type=None):
         alpha_prod_t = self.alphas_cumprod[t]
         alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
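
The new `set_timesteps` subsamples the training schedule with a fixed stride; the values below follow directly from the arithmetic in the hunk above:

    import numpy as np

    num_train_timesteps = 1000
    num_inference_steps = 10

    # stride = 1000 // 10 = 100; reversed so sampling runs from high noise to low
    timesteps = np.arange(
        0, num_train_timesteps, num_train_timesteps // num_inference_steps
    )[::-1].copy()
    print(timesteps)  # [900 800 700 600 500 400 300 200 100   0]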


@@ -15,6 +15,8 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
 
 # TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit
+import pdb
+from typing import Union
 
 import numpy as np
 import torch
@@ -27,8 +29,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
     """
     The variance exploding stochastic differential equation (SDE) scheduler.
 
-    :param snr: coefficient weighting the step from the score sample (from the network) to the random noise. :param
-    sigma_min: initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
+    :param snr: coefficient weighting the step from the model_output sample (from the network) to the random noise.
+    :param sigma_min: initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
     distribution of the data.
     :param sigma_max: :param sampling_eps: the end value of sampling, where timesteps decrease progressively from 1 to
     epsilon. :param correct_steps: number of correction steps performed on a produced sample. :param tensor_format:
@@ -54,12 +56,16 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
             sampling_eps=sampling_eps,
             correct_steps=correct_steps,
         )
-        self.sigmas = None
-        self.discrete_sigmas = None
+        # self.sigmas = None
+        # self.discrete_sigmas = None
+        #
+        # # setable values
+        # self.num_inference_steps = None
 
         self.timesteps = None
+        # TODO - update step to be torch-independent
+        self.set_sigmas(self.num_train_timesteps)
 
+        self.tensor_format = tensor_format
         self.set_format(tensor_format=tensor_format)
 
     def set_timesteps(self, num_inference_steps):
@@ -104,52 +110,80 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
             raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
 
-    def step_pred(self, score, x, t):
+    def set_seed(self, seed):
+        tensor_format = getattr(self, "tensor_format", "pt")
+        if tensor_format == "np":
+            np.random.seed(seed)
+        elif tensor_format == "pt":
+            torch.manual_seed(seed)
+        else:
+            raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+    def step_pred(
+        self,
+        model_output: Union[torch.FloatTensor, np.ndarray],
+        timestep: int,
+        sample: Union[torch.FloatTensor, np.ndarray],
+        seed=None,
+    ):
         """
         Predict the sample at the previous timestep by reversing the SDE.
         """
-        # TODO(Patrick) better comments + non-PyTorch
-        t = self.repeat_scalar(t, x.shape[0]).to(x.device)
-        timesteps = self.long((t * (len(self.timesteps) - 1))).to(x.device)
+        if seed is not None:
+            self.set_seed(seed)
+        # TODO(Patrick) non-PyTorch
 
-        sigma = self.discrete_sigmas[timesteps].to(x.device)
-        adjacent_sigma = self.get_adjacent_sigma(timesteps, t)
-        drift = self.zeros_like(x)
+        timestep = timestep * torch.ones(
+            sample.shape[0], device=sample.device
+        )  # torch.repeat_interleave(timestep, sample.shape[0])
+        timesteps = (timestep * (len(self.timesteps) - 1)).long()
+
+        sigma = self.discrete_sigmas[timesteps].to(sample.device)
+        adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep)
+        drift = self.zeros_like(sample)
         diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
 
-        # equation 6 in the paper: the score modeled by the network is grad_x log pt(x)
+        # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
         # also equation 47 shows the analog from SDE models to ancestral sampling methods
-        drift = drift - diffusion[:, None, None, None] ** 2 * score
+        drift = drift - diffusion[:, None, None, None] ** 2 * model_output
 
         # equation 6: sample noise for the diffusion term of
-        noise = self.randn_like(x)
-        x_mean = x - drift  # subtract because `dt` is a small negative timestep
+        noise = self.randn_like(sample)
+        prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
         # TODO is the variable diffusion the correct scaling term for the noise?
-        x = x_mean + diffusion[:, None, None, None] * noise  # add impact of diffusion field g
-        return x, x_mean
+        prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise  # add impact of diffusion field g
+
+        return {"prev_sample": prev_sample, "prev_sample_mean": prev_sample_mean}
 
-    def step_correct(self, score, x):
+    def step_correct(
+        self,
+        model_output: Union[torch.FloatTensor, np.ndarray],
+        sample: Union[torch.FloatTensor, np.ndarray],
+        seed=None,
+    ):
         """
-        Correct the predicted sample based on the output score of the network. This is often run repeatedly after
-        making the prediction for the previous timestep.
+        Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
+        after making the prediction for the previous timestep.
         """
         # TODO(Patrick) non-PyTorch
+        if seed is not None:
+            self.set_seed(seed)
+
         # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
 
         # sample noise for correction
-        noise = self.randn_like(x)
+        noise = self.randn_like(sample)
 
-        # compute step size from the score, the noise, and the snr
-        grad_norm = self.norm(score)
+        # compute step size from the model_output, the noise, and the snr
+        grad_norm = self.norm(model_output)
         noise_norm = self.norm(noise)
         step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
-        step_size = self.repeat_scalar(step_size, x.shape[0])  # * self.ones(x.shape[0], device=x.device)
+        step_size = step_size * torch.ones(sample.shape[0]).to(sample.device)
+        # self.repeat_scalar(step_size, sample.shape[0])
 
-        # compute corrected sample: score term and noise term
-        x_mean = x + step_size[:, None, None, None] * score
-        x = x_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
+        # compute corrected sample: model_output term and noise term
+        prev_sample_mean = sample + step_size[:, None, None, None] * model_output
+        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
 
-        return x
+        return {"prev_sample": prev_sample}
 
     def __len__(self):
         return self.config.num_train_timesteps
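
Callers now unpack dicts instead of positional tuples, and the seed kwarg makes the stochastic steps reproducible. A minimal usage sketch, assuming the top-level ScoreSdeVeScheduler export used by the tests; the shapes, step counts, and seed are illustrative, not taken from this diff:

    import torch
    from diffusers import ScoreSdeVeScheduler

    scheduler = ScoreSdeVeScheduler(tensor_format="pt")
    scheduler.set_timesteps(10)
    scheduler.set_sigmas(10)

    sample = torch.randn(4, 3, 8, 8) * scheduler.config.sigma_max
    model_output = 0.1 * sample  # stand-in for a real score network output

    out = scheduler.step_pred(model_output, scheduler.timesteps[0], sample, seed=0)
    sample, sample_mean = out["prev_sample"], out["prev_sample_mean"]
    sample = scheduler.step_correct(model_output, sample, seed=0)["prev_sample"]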


@@ -53,16 +53,6 @@ class SchedulerMixin:
         raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
 
-    def long(self, tensor):
-        tensor_format = getattr(self, "tensor_format", "pt")
-        if tensor_format == "np":
-            return np.int64(tensor)
-        elif tensor_format == "pt":
-            return tensor.long()
-
-        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
-
     def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]):
         """
         Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.
@@ -103,15 +93,6 @@ class SchedulerMixin:
         raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
 
-    def repeat_scalar(self, tensor, count):
-        tensor_format = getattr(self, "tensor_format", "pt")
-        if tensor_format == "np":
-            return np.repeat(tensor, count)
-        elif tensor_format == "pt":
-            return torch.repeat_interleave(tensor, count)
-
-        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
-
     def zeros_like(self, tensor):
         tensor_format = getattr(self, "tensor_format", "pt")
         if tensor_format == "np":
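
With `long` and `repeat_scalar` gone from the mixin, the VE scheduler inlines their torch equivalents, as `step_pred` above shows. The correspondence:

    import torch

    t = torch.tensor(0.5)
    batch_size = 4

    # formerly self.repeat_scalar(t, batch_size)
    t_batched = t * torch.ones(batch_size)   # tensor([0.5000, 0.5000, 0.5000, 0.5000])

    # formerly self.long(t_batched * 999)
    indices = (t_batched * 999).long()       # tensor([499, 499, 499, 499])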


@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import pdb
 
 import tempfile
 import unittest
@@ -507,8 +508,42 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         assert abs(result_mean.item() - 0.2593) < 1e-3
 
 
-class ScoreSdeVeSchedulerTest(SchedulerCommonTest):
+class ScoreSdeVeSchedulerTest(unittest.TestCase):
+    # TODO adapt with class SchedulerCommonTest (scheduler needs Numpy Integration)
     scheduler_classes = (ScoreSdeVeScheduler,)
+    forward_default_kwargs = (("seed", 0),)
+
+    @property
+    def dummy_sample(self):
+        batch_size = 4
+        num_channels = 3
+        height = 8
+        width = 8
+
+        sample = torch.rand((batch_size, num_channels, height, width))
+
+        return sample
+
+    @property
+    def dummy_sample_deter(self):
+        batch_size = 4
+        num_channels = 3
+        height = 8
+        width = 8
+
+        num_elems = batch_size * num_channels * height * width
+        sample = torch.arange(num_elems)
+        sample = sample.reshape(num_channels, height, width, batch_size)
+        sample = sample / num_elems
+        sample = sample.permute(3, 0, 1, 2)
+
+        return sample
+
+    def dummy_model(self):
+        def model(sample, t, *args):
+            return sample * t / (t + 1)
+
+        return model
 
     def get_scheduler_config(self, **kwargs):
         config = {
@@ -517,7 +552,7 @@ class ScoreSdeVeSchedulerTest(SchedulerCommonTest):
             "sigma_min": 0.01,
             "sigma_max": 1348,
             "sampling_eps": 1e-5,
-            "tensor_format": "np",  # TODO add test for tensor formats
+            "tensor_format": "pt",  # TODO add test for tensor formats
         }
 
         config.update(**kwargs)
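
With the test default switched to torch tensors, the config above corresponds to a construction like the following sketch (remaining options left at their defaults):

    from diffusers import ScoreSdeVeScheduler

    scheduler = ScoreSdeVeScheduler(
        sigma_min=0.01,
        sigma_max=1348,
        sampling_eps=1e-5,
        tensor_format="pt",
    )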
@@ -538,15 +573,15 @@ class ScoreSdeVeSchedulerTest(SchedulerCommonTest):
             scheduler.save_config(tmpdirname)
             new_scheduler = scheduler_class.from_config(tmpdirname)
 
-            output = scheduler.step_pred(residual, sample, time_step, **kwargs)
-            new_output = new_scheduler.step_pred(residual, sample, time_step, **kwargs)
+            output = scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
 
-            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-            output = scheduler.step_correct(residual, sample, **kwargs)
-            new_output = new_scheduler.step_correct(residual, sample, **kwargs)
+            output = scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
 
-            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"
+            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"
 
     def check_over_forward(self, time_step=0, **forward_kwargs):
         kwargs = dict(self.forward_default_kwargs)
@@ -564,15 +599,15 @@ class ScoreSdeVeSchedulerTest(SchedulerCommonTest):
             scheduler.save_config(tmpdirname)
             new_scheduler = scheduler_class.from_config(tmpdirname)
 
-            output = scheduler.step_pred(residual, sample, time_step, **kwargs)
-            new_output = new_scheduler.step_pred(residual, sample, time_step, **kwargs)
+            output = scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
 
-            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-            output = scheduler.step_correct(residual, sample, **kwargs)
-            new_output = new_scheduler.step_correct(residual, sample, **kwargs)
+            output = scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
 
-            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"
+            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"
 
     def test_timesteps(self):
         for timesteps in [10, 100, 1000]:
@@ -583,11 +618,12 @@ class ScoreSdeVeSchedulerTest(SchedulerCommonTest):
             self.check_over_configs(sigma_min=sigma_min, sigma_max=sigma_max)
 
     def test_time_indices(self):
-        for t in [1, 5, 10]:
+        for t in [0.1, 0.5, 0.75]:
             self.check_over_forward(time_step=t)
 
     def test_full_loop_no_noise(self):
-        np.random.seed(0)
+        kwargs = dict(self.forward_default_kwargs)
+
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
         scheduler = scheduler_class(**scheduler_config)
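
The time indices become fractional because the VE scheduler treats time as continuous on (0, 1]; `step_pred` maps such a timestep onto a discrete sigma index. A worked instance, assuming 1000 stored timesteps:

    import torch

    t = torch.tensor(0.5)
    num_timesteps = 1000

    # as in step_pred: scale the continuous time into the discrete sigma table
    index = (t * (num_timesteps - 1)).long()  # tensor(499), indexes discrete_sigmas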
@@ -598,52 +634,27 @@ class ScoreSdeVeSchedulerTest(SchedulerCommonTest):
         sample = self.dummy_sample_deter
 
         scheduler.set_sigmas(num_inference_steps)
         scheduler.set_timesteps(num_inference_steps)
 
         for i, t in enumerate(scheduler.timesteps):
             sigma_t = scheduler.sigmas[i]
 
             for _ in range(scheduler.correct_steps):
                 with torch.no_grad():
-                    result = model(sample, sigma_t)
-                sample = scheduler.step_correct(result, sample)
+                    model_output = model(sample, sigma_t)
+                sample = scheduler.step_correct(model_output, sample, **kwargs)["prev_sample"]
 
             with torch.no_grad():
-                result = model(sample, sigma_t)
+                model_output = model(sample, sigma_t)
 
-            sample, sample_mean = scheduler.step_pred(result, sample, t)
+            output = scheduler.step_pred(model_output, t, sample, **kwargs)
+            sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
 
-        result_sum = np.sum(np.abs(sample))
-        result_mean = np.mean(np.abs(sample))
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
 
-        assert abs(result_sum.item() - 10629923278.7104) < 1e-2
-        assert abs(result_mean.item() - 13841045.9358) < 1e-3
-
-    def test_from_pretrained_save_pretrained(self):
-        kwargs = dict(self.forward_default_kwargs)
-
-        num_inference_steps = kwargs.pop("num_inference_steps", None)
-
-        for scheduler_class in self.scheduler_classes:
-            sample = self.dummy_sample
-            residual = 0.1 * sample
-
-            scheduler_config = self.get_scheduler_config()
-            scheduler = scheduler_class(**scheduler_config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                scheduler.save_config(tmpdirname)
-                new_scheduler = scheduler_class.from_config(tmpdirname)
-
-            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
-                scheduler.set_timesteps(num_inference_steps)
-                new_scheduler.set_timesteps(num_inference_steps)
-            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
-                kwargs["num_inference_steps"] = num_inference_steps
-
-            output = scheduler.step_pred(residual, 1, sample, **kwargs)["prev_sample"]
-            new_output = new_scheduler.step_pred(residual, 1, sample, **kwargs)["prev_sample"]
-
-            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+        assert abs(result_sum.item() - 14224664576.0) < 1e-2
+        assert abs(result_mean.item() - 18521698.0) < 1e-3
 
     def test_step_shape(self):
         kwargs = dict(self.forward_default_kwargs)
@@ -667,31 +678,3 @@ class ScoreSdeVeSchedulerTest(SchedulerCommonTest):
         self.assertEqual(output_0.shape, sample.shape)
         self.assertEqual(output_0.shape, output_1.shape)
-
-    def test_pytorch_equal_numpy(self):
-        kwargs = dict(self.forward_default_kwargs)
-
-        num_inference_steps = kwargs.pop("num_inference_steps", None)
-
-        for scheduler_class in self.scheduler_classes:
-            sample = self.dummy_sample
-            residual = 0.1 * sample
-
-            sample_pt = torch.tensor(sample)
-            residual_pt = 0.1 * sample_pt
-
-            scheduler_config = self.get_scheduler_config()
-            scheduler = scheduler_class(**scheduler_config)
-
-            scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
-
-            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
-                scheduler.set_timesteps(num_inference_steps)
-                scheduler_pt.set_timesteps(num_inference_steps)
-            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
-                kwargs["num_inference_steps"] = num_inference_steps
-
-            output = scheduler.step_pred(residual, 1, sample, **kwargs)["prev_sample"]
-            output_pt = scheduler_pt.step_pred(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
-
-            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"