Add training and batched inference test for DDPM vs DDIM (#140)
* Add torch_device to the VE pipeline * Mark the training test with slow
This commit is contained in:
parent
89f2011ced
commit
6c15636b0b
|
@ -11,9 +11,9 @@ from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
|
|||
from .optimization import (
|
||||
get_constant_schedule,
|
||||
get_constant_schedule_with_warmup,
|
||||
get_linear_schedule_with_warmup,
|
||||
get_cosine_schedule_with_warmup,
|
||||
get_cosine_with_hard_restarts_schedule_with_warmup,
|
||||
get_linear_schedule_with_warmup,
|
||||
get_polynomial_decay_schedule_with_warmup,
|
||||
get_scheduler,
|
||||
)
|
||||
|
|
|
@ -1,8 +1,44 @@
|
|||
import copy
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def enable_full_determinism(seed: int):
|
||||
"""
|
||||
Helper function for reproducible behavior during distributed training. See
|
||||
- https://pytorch.org/docs/stable/notes/randomness.html for pytorch
|
||||
"""
|
||||
# set seed first
|
||||
set_seed(seed)
|
||||
|
||||
# Enable PyTorch deterministic mode. This potentially requires either the environment
|
||||
# variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
|
||||
# depending on the CUDA version, so we set them both here
|
||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
|
||||
torch.use_deterministic_algorithms(True)
|
||||
|
||||
# Enable CUDNN deterministic mode
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
def set_seed(seed: int):
|
||||
"""
|
||||
Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
|
||||
Args:
|
||||
seed (`int`): The seed to set.
|
||||
"""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
# ^^ safe to call this function even if cuda is not available
|
||||
|
||||
|
||||
class EMAModel:
|
||||
"""
|
||||
Exponential Moving Average of models weights
|
||||
|
|
|
@ -876,3 +876,45 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
assert image.shape == (1, 256, 256, 3)
|
||||
expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@slow
|
||||
def test_ddpm_ddim_equality(self):
|
||||
model_id = "google/ddpm-cifar10-32"
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
ddpm_scheduler = DDPMScheduler(tensor_format="pt")
|
||||
ddim_scheduler = DDIMScheduler(tensor_format="pt")
|
||||
|
||||
ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
|
||||
ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
ddpm_image = ddpm(generator=generator, output_type="numpy")["sample"]
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")["sample"]
|
||||
|
||||
# the values aren't exactly equal, but the images look the same upon visual inspection
|
||||
assert np.abs(ddpm_image - ddim_image).max() < 1e-1
|
||||
|
||||
@slow
|
||||
def test_ddpm_ddim_equality_batched(self):
|
||||
model_id = "google/ddpm-cifar10-32"
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
ddpm_scheduler = DDPMScheduler(tensor_format="pt")
|
||||
ddim_scheduler = DDIMScheduler(tensor_format="pt")
|
||||
|
||||
ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
|
||||
ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
ddpm_images = ddpm(batch_size=2, generator=generator, output_type="numpy")["sample"]
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
ddim_images = ddim(batch_size=2, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")[
|
||||
"sample"
|
||||
]
|
||||
|
||||
# the values aren't exactly equal, but the images look the same upon visual inspection
|
||||
assert np.abs(ddpm_images - ddim_images).max() < 1e-1
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel
|
||||
from diffusers.testing_utils import slow, torch_device
|
||||
from diffusers.training_utils import enable_full_determinism, set_seed
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class TrainingTests(unittest.TestCase):
|
||||
def get_model_optimizer(self, resolution=32):
|
||||
set_seed(0)
|
||||
model = UNet2DModel(sample_size=resolution, in_channels=3, out_channels=3)
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
|
||||
return model, optimizer
|
||||
|
||||
@slow
|
||||
def test_training_step_equality(self):
|
||||
enable_full_determinism(0)
|
||||
|
||||
ddpm_scheduler = DDPMScheduler(
|
||||
num_train_timesteps=1000,
|
||||
beta_start=0.0001,
|
||||
beta_end=0.02,
|
||||
beta_schedule="linear",
|
||||
clip_sample=True,
|
||||
tensor_format="pt",
|
||||
)
|
||||
ddim_scheduler = DDIMScheduler(
|
||||
num_train_timesteps=1000,
|
||||
beta_start=0.0001,
|
||||
beta_end=0.02,
|
||||
beta_schedule="linear",
|
||||
clip_sample=True,
|
||||
tensor_format="pt",
|
||||
)
|
||||
|
||||
assert ddpm_scheduler.num_train_timesteps == ddim_scheduler.num_train_timesteps
|
||||
|
||||
# shared batches for DDPM and DDIM
|
||||
set_seed(0)
|
||||
clean_images = [torch.randn((4, 3, 32, 32)).clip(-1, 1).to(torch_device) for _ in range(4)]
|
||||
noise = [torch.randn((4, 3, 32, 32)).to(torch_device) for _ in range(4)]
|
||||
timesteps = [torch.randint(0, 1000, (4,)).long().to(torch_device) for _ in range(4)]
|
||||
|
||||
# train with a DDPM scheduler
|
||||
model, optimizer = self.get_model_optimizer(resolution=32)
|
||||
model.train().to(torch_device)
|
||||
for i in range(4):
|
||||
optimizer.zero_grad()
|
||||
ddpm_noisy_images = ddpm_scheduler.add_noise(clean_images[i], noise[i], timesteps[i])
|
||||
ddpm_noise_pred = model(ddpm_noisy_images, timesteps[i])["sample"]
|
||||
loss = torch.nn.functional.mse_loss(ddpm_noise_pred, noise[i])
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
del model, optimizer
|
||||
|
||||
# recreate the model and optimizer, and retry with DDIM
|
||||
model, optimizer = self.get_model_optimizer(resolution=32)
|
||||
model.train().to(torch_device)
|
||||
for i in range(4):
|
||||
optimizer.zero_grad()
|
||||
ddim_noisy_images = ddim_scheduler.add_noise(clean_images[i], noise[i], timesteps[i])
|
||||
ddim_noise_pred = model(ddim_noisy_images, timesteps[i])["sample"]
|
||||
loss = torch.nn.functional.mse_loss(ddim_noise_pred, noise[i])
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
del model, optimizer
|
||||
|
||||
self.assertTrue(torch.allclose(ddpm_noisy_images, ddim_noisy_images, atol=1e-5))
|
||||
self.assertTrue(torch.allclose(ddpm_noise_pred, ddim_noise_pred, atol=1e-5))
|
Loading…
Reference in New Issue