From bb3066428537da6263676448e737f315203d986c Mon Sep 17 00:00:00 2001
From: anton-l
Date: Tue, 14 Jun 2022 11:33:24 +0200
Subject: [PATCH] Move the training example

---
 Makefile                                    |  2 +-
 .../trainers => examples}/training_ddpm.py  | 29 ++++++++++++-------
 tests/test_modeling_utils.py                |  2 +-
 3 files changed, 21 insertions(+), 12 deletions(-)
 rename {src/diffusers/trainers => examples}/training_ddpm.py (87%)

diff --git a/Makefile b/Makefile
index dad06117..ddf143b6 100644
--- a/Makefile
+++ b/Makefile
@@ -3,7 +3,7 @@
 # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
 export PYTHONPATH = src
 
-check_dirs := tests src utils
+check_dirs := examples tests src utils
 
 modified_only_fixup:
 	$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
diff --git a/src/diffusers/trainers/training_ddpm.py b/examples/training_ddpm.py
similarity index 87%
rename from src/diffusers/trainers/training_ddpm.py
rename to examples/training_ddpm.py
index bc2a4d10..b3ba111c 100644
--- a/src/diffusers/trainers/training_ddpm.py
+++ b/examples/training_ddpm.py
@@ -8,14 +8,23 @@ import PIL.Image
 from accelerate import Accelerator
 from datasets import load_dataset
 from diffusers import DDPM, DDPMScheduler, UNetModel
-from torchvision.transforms import InterpolationMode, CenterCrop, Compose, Lambda, RandomRotation, RandomHorizontalFlip, Resize, ToTensor
+from torchvision.transforms import (
+    Compose,
+    InterpolationMode,
+    Lambda,
+    RandomCrop,
+    RandomHorizontalFlip,
+    RandomVerticalFlip,
+    Resize,
+    ToTensor,
+)
 from tqdm.auto import tqdm
 from transformers import get_linear_schedule_with_warmup
 
 
 def set_seed(seed):
-    #torch.backends.cudnn.deterministic = True
-    #torch.backends.cudnn.benchmark = False
+    # torch.backends.cudnn.deterministic = True
+    # torch.backends.cudnn.benchmark = False
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     np.random.seed(seed)
@@ -33,7 +42,7 @@ model = UNetModel(
     dropout=0.0,
     num_res_blocks=2,
     resamp_with_conv=True,
-    resolution=32
+    resolution=32,
 )
 noise_scheduler = DDPMScheduler(timesteps=1000)
 optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
@@ -44,15 +53,15 @@ gradient_accumulation_steps = 2
 
 augmentations = Compose(
     [
-        RandomHorizontalFlip(),
-        RandomRotation(15, interpolation=InterpolationMode.BILINEAR, fill=1),
         Resize(32, interpolation=InterpolationMode.BILINEAR),
-        CenterCrop(32),
+        RandomHorizontalFlip(),
+        RandomVerticalFlip(),
+        RandomCrop(32),
         ToTensor(),
         Lambda(lambda x: x * 2 - 1),
     ]
 )
-dataset = load_dataset("huggan/pokemon", split="train")
+dataset = load_dataset("huggan/flowers-102-categories", split="train")
 
 
 def transforms(examples):
@@ -127,5 +136,5 @@ for epoch in range(num_epochs):
         image_pil = PIL.Image.fromarray(image_processed[0])
 
     # save image
-    pipeline.save_pretrained("./poke-ddpm")
-    image_pil.save(f"./poke-ddpm/test_{epoch}.png")
+    pipeline.save_pretrained("./flowers-ddpm")
+    image_pil.save(f"./flowers-ddpm/test_{epoch}.png")
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 6c119479..417ef353 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -19,7 +19,7 @@ import unittest
 
 import torch
 
-from diffusers import DDIM, DDPM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel, PNDM, PNDMScheduler
+from diffusers import DDIM, DDPM, PNDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, PNDMScheduler, UNetModel
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.testing_utils import floats_tensor, slow, torch_device