Move the training example
This commit is contained in:
parent
418888a566
commit
bb30664285
2
Makefile
2
Makefile
|
@ -3,7 +3,7 @@
|
|||
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
|
||||
export PYTHONPATH = src
|
||||
|
||||
check_dirs := tests src utils
|
||||
check_dirs := examples tests src utils
|
||||
|
||||
modified_only_fixup:
|
||||
$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
|
||||
|
|
|
@ -8,14 +8,23 @@ import PIL.Image
|
|||
from accelerate import Accelerator
|
||||
from datasets import load_dataset
|
||||
from diffusers import DDPM, DDPMScheduler, UNetModel
|
||||
from torchvision.transforms import InterpolationMode, CenterCrop, Compose, Lambda, RandomRotation, RandomHorizontalFlip, Resize, ToTensor
|
||||
from torchvision.transforms import (
|
||||
Compose,
|
||||
InterpolationMode,
|
||||
Lambda,
|
||||
RandomCrop,
|
||||
RandomHorizontalFlip,
|
||||
RandomVerticalFlip,
|
||||
Resize,
|
||||
ToTensor,
|
||||
)
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import get_linear_schedule_with_warmup
|
||||
|
||||
|
||||
def set_seed(seed):
|
||||
#torch.backends.cudnn.deterministic = True
|
||||
#torch.backends.cudnn.benchmark = False
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
# torch.backends.cudnn.benchmark = False
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
|
@ -33,7 +42,7 @@ model = UNetModel(
|
|||
dropout=0.0,
|
||||
num_res_blocks=2,
|
||||
resamp_with_conv=True,
|
||||
resolution=32
|
||||
resolution=32,
|
||||
)
|
||||
noise_scheduler = DDPMScheduler(timesteps=1000)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
|
||||
|
@ -44,15 +53,15 @@ gradient_accumulation_steps = 2
|
|||
|
||||
augmentations = Compose(
|
||||
[
|
||||
RandomHorizontalFlip(),
|
||||
RandomRotation(15, interpolation=InterpolationMode.BILINEAR, fill=1),
|
||||
Resize(32, interpolation=InterpolationMode.BILINEAR),
|
||||
CenterCrop(32),
|
||||
RandomHorizontalFlip(),
|
||||
RandomVerticalFlip(),
|
||||
RandomCrop(32),
|
||||
ToTensor(),
|
||||
Lambda(lambda x: x * 2 - 1),
|
||||
]
|
||||
)
|
||||
dataset = load_dataset("huggan/pokemon", split="train")
|
||||
dataset = load_dataset("huggan/flowers-102-categories", split="train")
|
||||
|
||||
|
||||
def transforms(examples):
|
||||
|
@ -127,5 +136,5 @@ for epoch in range(num_epochs):
|
|||
image_pil = PIL.Image.fromarray(image_processed[0])
|
||||
|
||||
# save image
|
||||
pipeline.save_pretrained("./poke-ddpm")
|
||||
image_pil.save(f"./poke-ddpm/test_{epoch}.png")
|
||||
pipeline.save_pretrained("./flowers-ddpm")
|
||||
image_pil.save(f"./flowers-ddpm/test_{epoch}.png")
|
|
@ -19,7 +19,7 @@ import unittest
|
|||
|
||||
import torch
|
||||
|
||||
from diffusers import DDIM, DDPM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel, PNDM, PNDMScheduler
|
||||
from diffusers import DDIM, DDPM, PNDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, PNDMScheduler, UNetModel
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.testing_utils import floats_tensor, slow, torch_device
|
||||
|
|
Loading…
Reference in New Issue