# diffusers/examples/training_ddpm.py
import random
import numpy as np
import torch
import torch.nn.functional as F
import PIL.Image
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPM, DDPMScheduler, UNetModel
from torchvision.transforms import (
    Compose,
    InterpolationMode,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
    RandomVerticalFlip,
    Resize,
    ToTensor,
)
from tqdm.auto import tqdm
from transformers import get_linear_schedule_with_warmup


def set_seed(seed):
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


set_seed(0)

accelerator = Accelerator()
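
# A small unconditional UNet for 32x32 images; self-attention runs on the
# 16x16 feature maps (attn_resolutions).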
model = UNetModel(
    attn_resolutions=(16,),
    ch=128,
    ch_mult=(1, 2, 2, 2),
    dropout=0.0,
    num_res_blocks=2,
    resamp_with_conv=True,
    resolution=32,
)
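
# 1000-step DDPM noise schedule; its forward_step adds noise to a clean image
# at the sampled timestep during training.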
noise_scheduler = DDPMScheduler(timesteps=1000)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
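
# The effective per-process batch size is batch_size * gradient_accumulation_steps.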
num_epochs = 100
batch_size = 64
gradient_accumulation_steps = 2
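
# Augmentations: random flips and a 32x32 crop, then scale pixel values from
# [0, 1] to [-1, 1], the input range the diffusion model is trained on.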
augmentations = Compose(
    [
        Resize(32, interpolation=InterpolationMode.BILINEAR),
        RandomHorizontalFlip(),
        RandomVerticalFlip(),
        RandomCrop(32),
        ToTensor(),
        Lambda(lambda x: x * 2 - 1),
    ]
)
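
# Oxford Flowers-102 images hosted on the Hugging Face Hub.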
dataset = load_dataset("huggan/flowers-102-categories", split="train")


def transforms(examples):
    images = [augmentations(image.convert("RGB")) for image in examples["image"]]
    return {"input": images}


dataset.set_transform(transforms)

train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
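
# Linear warmup then linear decay; the step count is measured in optimizer
# updates, hence the division by gradient_accumulation_steps.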
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
)
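
# accelerator.prepare moves model, optimizer, and data to the right device(s)
# and wraps the model for distributed training when launched with multiple processes.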
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)
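
# Training loop: draw a random timestep per image, add noise via the forward
# process, and train the UNet to predict that noise with an L1 loss.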
for epoch in range(num_epochs):
    model.train()
    pbar = tqdm(total=len(train_dataloader), unit="ba")
    pbar.set_description(f"Epoch {epoch}")
    losses = []
    for step, batch in enumerate(train_dataloader):
        clean_images = batch["input"]
        noisy_images = torch.empty_like(clean_images)
        noise_samples = torch.empty_like(clean_images)
        bsz = clean_images.shape[0]
        timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
        for idx in range(bsz):
            noise = torch.randn((3, 32, 32)).to(clean_images.device)
            noise_samples[idx] = noise
            noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx])
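        # Gradient accumulation: intermediate microbatches run backward under
        # no_sync to skip the cross-process gradient all-reduce; the optimizer
        # only steps on the last microbatch of each window.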
        if step % gradient_accumulation_steps == 0:
            with accelerator.no_sync(model):
                output = model(noisy_images, timesteps)
                # predict the noise
                loss = F.l1_loss(output, noise_samples)
                accelerator.backward(loss)
        else:
            output = model(noisy_images, timesteps)
            # predict the noise (the target is noise_samples, not clean_images)
            loss = F.l1_loss(output, noise_samples)
            accelerator.backward(loss)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
        loss = loss.detach().item()
        losses.append(loss)
        pbar.update(1)
        pbar.set_postfix(loss=loss, avg_loss=np.mean(losses), lr=optimizer.param_groups[0]["lr"])
    pbar.close()
    # Flush gradients from a trailing incomplete accumulation window, if any.
    optimizer.step()
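
    # Sampling: build a DDPM pipeline from the current weights, denoise random
    # noise into a single image, and save the checkpoint plus a sample each epoch.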
    model.eval()
    with torch.no_grad():
        pipeline = DDPM(unet=model, noise_scheduler=noise_scheduler)
        generator = torch.Generator()
        generator = generator.manual_seed(0)
        # run pipeline in inference (sample random noise and denoise)
        image = pipeline(generator=generator)
        # process image to PIL
        image_processed = image.cpu().permute(0, 2, 3, 1)
        image_processed = (image_processed + 1.0) * 127.5
        image_processed = image_processed.type(torch.uint8).numpy()
        image_pil = PIL.Image.fromarray(image_processed[0])
        # save image
        pipeline.save_pretrained("./flowers-ddpm")
        image_pil.save(f"./flowers-ddpm/test_{epoch}.png")
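
# To launch (assuming accelerate has been configured with `accelerate config`):
#   accelerate launch training_ddpm.py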