diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index b2513899..136b0eb3 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -11,9 +11,9 @@ from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
 from .optimization import (
     get_constant_schedule,
     get_constant_schedule_with_warmup,
-    get_linear_schedule_with_warmup,
     get_cosine_schedule_with_warmup,
     get_cosine_with_hard_restarts_schedule_with_warmup,
+    get_linear_schedule_with_warmup,
     get_polynomial_decay_schedule_with_warmup,
     get_scheduler,
 )
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 45be9acc..022be41a 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -1,8 +1,44 @@
 import copy
+import os
+import random
 
+import numpy as np
 import torch
 
 
+def enable_full_determinism(seed: int):
+    """
+    Helper function for reproducible behavior during distributed training. See
+    - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
+    """
+    # set seed first
+    set_seed(seed)
+
+    # Enable PyTorch deterministic mode. This potentially requires either the environment
+    # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
+    # depending on the CUDA version, so we set them both here
+    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+    torch.use_deterministic_algorithms(True)
+
+    # Enable CUDNN deterministic mode
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+
+def set_seed(seed: int):
+    """
+    Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
+    Args:
+        seed (`int`): The seed to set.
+    """
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    # ^^ safe to call this function even if cuda is not available
+
+
 class EMAModel:
     """
     Exponential Moving Average of models weights
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index ce4f9958..c47a787c 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -876,3 +876,45 @@ class PipelineTesterMixin(unittest.TestCase):
         assert image.shape == (1, 256, 256, 3)
         expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+    @slow
+    def test_ddpm_ddim_equality(self):
+        model_id = "google/ddpm-cifar10-32"
+
+        unet = UNet2DModel.from_pretrained(model_id)
+        ddpm_scheduler = DDPMScheduler(tensor_format="pt")
+        ddim_scheduler = DDIMScheduler(tensor_format="pt")
+
+        ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
+        ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
+
+        generator = torch.manual_seed(0)
+        ddpm_image = ddpm(generator=generator, output_type="numpy")["sample"]
+
+        generator = torch.manual_seed(0)
+        ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")["sample"]
+
+        # the values aren't exactly equal, but the images look the same upon visual inspection
+        assert np.abs(ddpm_image - ddim_image).max() < 1e-1
+
+    @slow
+    def test_ddpm_ddim_equality_batched(self):
+        model_id = "google/ddpm-cifar10-32"
+
+        unet = UNet2DModel.from_pretrained(model_id)
+        ddpm_scheduler = DDPMScheduler(tensor_format="pt")
+        ddim_scheduler = DDIMScheduler(tensor_format="pt")
+
+        ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
+        ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
+
+        generator = torch.manual_seed(0)
+        ddpm_images = ddpm(batch_size=2, generator=generator, output_type="numpy")["sample"]
+
+        generator = torch.manual_seed(0)
+        ddim_images = ddim(batch_size=2, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")[
+            "sample"
+        ]
+
+        # the values aren't exactly equal, but the images look the same upon visual inspection
+        assert np.abs(ddpm_images - ddim_images).max() < 1e-1
diff --git a/tests/test_training.py b/tests/test_training.py
new file mode 100644
index 00000000..48903c37
--- /dev/null
+++ b/tests/test_training.py
@@ -0,0 +1,89 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel
+from diffusers.testing_utils import slow, torch_device
+from diffusers.training_utils import enable_full_determinism, set_seed
+
+
+torch.backends.cuda.matmul.allow_tf32 = False
+
+
+class TrainingTests(unittest.TestCase):
+    def get_model_optimizer(self, resolution=32):
+        set_seed(0)
+        model = UNet2DModel(sample_size=resolution, in_channels=3, out_channels=3)
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
+        return model, optimizer
+
+    @slow
+    def test_training_step_equality(self):
+        enable_full_determinism(0)
+
+        ddpm_scheduler = DDPMScheduler(
+            num_train_timesteps=1000,
+            beta_start=0.0001,
+            beta_end=0.02,
+            beta_schedule="linear",
+            clip_sample=True,
+            tensor_format="pt",
+        )
+        ddim_scheduler = DDIMScheduler(
+            num_train_timesteps=1000,
+            beta_start=0.0001,
+            beta_end=0.02,
+            beta_schedule="linear",
+            clip_sample=True,
+            tensor_format="pt",
+        )
+
+        assert ddpm_scheduler.num_train_timesteps == ddim_scheduler.num_train_timesteps
+
+        # shared batches for DDPM and DDIM
+        set_seed(0)
+        clean_images = [torch.randn((4, 3, 32, 32)).clip(-1, 1).to(torch_device) for _ in range(4)]
+        noise = [torch.randn((4, 3, 32, 32)).to(torch_device) for _ in range(4)]
+        timesteps = [torch.randint(0, 1000, (4,)).long().to(torch_device) for _ in range(4)]
+
+        # train with a DDPM scheduler
+        model, optimizer = self.get_model_optimizer(resolution=32)
+        model.train().to(torch_device)
+        for i in range(4):
+            optimizer.zero_grad()
+            ddpm_noisy_images = ddpm_scheduler.add_noise(clean_images[i], noise[i], timesteps[i])
+            ddpm_noise_pred = model(ddpm_noisy_images, timesteps[i])["sample"]
+            loss = torch.nn.functional.mse_loss(ddpm_noise_pred, noise[i])
+            loss.backward()
+            optimizer.step()
+        del model, optimizer
+
+        # recreate the model and optimizer, and retry with DDIM
+        model, optimizer = self.get_model_optimizer(resolution=32)
+        model.train().to(torch_device)
+        for i in range(4):
+            optimizer.zero_grad()
+            ddim_noisy_images = ddim_scheduler.add_noise(clean_images[i], noise[i], timesteps[i])
+            ddim_noise_pred = model(ddim_noisy_images, timesteps[i])["sample"]
+            loss = torch.nn.functional.mse_loss(ddim_noise_pred, noise[i])
+            loss.backward()
+            optimizer.step()
+        del model, optimizer
+
+        self.assertTrue(torch.allclose(ddpm_noisy_images, ddim_noisy_images, atol=1e-5))
+        self.assertTrue(torch.allclose(ddpm_noise_pred, ddim_noise_pred, atol=1e-5))
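
For context, a minimal usage sketch of the new training_utils helpers outside the test suite. It assumes the same model/scheduler setup as tests/test_training.py; the seed, batch shapes, and learning rate are illustrative placeholders, and the random tensors stand in for a real dataloader.

import torch

from diffusers import DDPMScheduler, UNet2DModel
from diffusers.training_utils import enable_full_determinism

# seed python/numpy/torch and force deterministic CUDA/cuDNN kernels
enable_full_determinism(0)

model = UNet2DModel(sample_size=32, in_channels=3, out_channels=3)
scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)

model.train()
for _ in range(4):  # a few steps on random data, stand-in for a real dataloader
    clean_images = torch.randn((4, 3, 32, 32)).clip(-1, 1)
    noise = torch.randn((4, 3, 32, 32))
    timesteps = torch.randint(0, 1000, (4,)).long()

    optimizer.zero_grad()
    noisy_images = scheduler.add_noise(clean_images, noise, timesteps)
    noise_pred = model(noisy_images, timesteps)["sample"]
    loss = torch.nn.functional.mse_loss(noise_pred, noise)
    loss.backward()
    optimizer.step()

When only seeding is needed, set_seed(seed) is the lighter-weight option; enable_full_determinism(seed) additionally sets the CUDA/cuDNN environment flags and deterministic-algorithm mode shown in the diff.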