2022-06-13 08:50:30 -06:00
|
|
|
import random
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
import PIL.Image
|
|
|
|
from accelerate import Accelerator
|
|
|
|
from datasets import load_dataset
|
|
|
|
from diffusers import DDPM, DDPMScheduler, UNetModel
|
2022-06-14 03:33:24 -06:00
|
|
|
from torchvision.transforms import (
|
|
|
|
Compose,
|
|
|
|
InterpolationMode,
|
|
|
|
Lambda,
|
|
|
|
RandomCrop,
|
|
|
|
RandomHorizontalFlip,
|
|
|
|
RandomVerticalFlip,
|
|
|
|
Resize,
|
|
|
|
ToTensor,
|
|
|
|
)
|
2022-06-13 08:50:30 -06:00
|
|
|
from tqdm.auto import tqdm
|
|
|
|
from transformers import get_linear_schedule_with_warmup
|
|
|
|
|
|
|
|
|
|
|
|
def set_seed(seed, *, deterministic=False):
    """Seed every random number generator used by this script.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's global RNG and Python's
    ``random`` module with the same value so that runs are repeatable.

    Args:
        seed: Integer seed passed to every generator.
        deterministic: If True, additionally force cuDNN to use deterministic
            algorithms and disable its benchmark autotuner. Off by default
            (matching the previous behavior, where these settings were
            commented out) because it can slow down training.
    """
    if deterministic:
        # Trade speed for bit-for-bit reproducibility on cuDNN-backed ops.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
|
|
|
|
|
|
|
|
|
|
|
|
# Seed all RNGs up front so model weight initialisation is reproducible.
set_seed(0)

accelerator = Accelerator()

# UNet configured for 32-pixel images (the same size the augmentation
# pipeline resizes/crops to).
model = UNetModel(
    attn_resolutions=(16,),
    ch=128,
    ch_mult=(1, 2, 2, 2),
    dropout=0.0,
    num_res_blocks=2,
    resamp_with_conv=True,
    resolution=32,
)

# Diffusion noise schedule; training samples timesteps in [0, 1000).
noise_scheduler = DDPMScheduler(timesteps=1000)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0003)

# Training hyper-parameters.
num_epochs, batch_size, gradient_accumulation_steps = 100, 64, 2
|
2022-06-13 08:50:30 -06:00
|
|
|
|
|
|
|
# Training-time preprocessing: resize (int size -> shorter side) to 32 px,
# random horizontal/vertical flips, a 32x32 random crop, then convert to a
# tensor and rescale from [0, 1] to [-1, 1].
augmentations = Compose(
    transforms=[
        Resize(size=32, interpolation=InterpolationMode.BILINEAR),
        RandomHorizontalFlip(p=0.5),
        RandomVerticalFlip(p=0.5),
        RandomCrop(size=32),
        ToTensor(),
        Lambda(lambda t: 2 * t - 1),
    ]
)

dataset = load_dataset(path="huggan/flowers-102-categories", split="train")
|
2022-06-13 08:50:30 -06:00
|
|
|
|
|
|
|
|
|
|
|
def transforms(examples):
    """Batch transform installed via ``dataset.set_transform``.

    Converts every PIL image in ``examples["image"]`` to RGB, runs it through
    the module-level ``augmentations`` pipeline, and returns the resulting
    tensors as a list under the ``"input"`` key.
    """
    processed = []
    for pil_image in examples["image"]:
        rgb = pil_image.convert("RGB")
        processed.append(augmentations(rgb))
    return {"input": processed}
|
|
|
|
|
|
|
|
|
|
|
|
# Apply `transforms` on the fly whenever examples are fetched.
dataset.set_transform(transforms)
train_dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)

# 500 warm-up steps followed by linear decay. The total is counted in
# optimizer updates: batches per run divided by the accumulation factor.
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=len(train_dataloader) * num_epochs // gradient_accumulation_steps,
)

# Hand everything to Accelerate for device placement / distributed wrapping.
(
    model,
    optimizer,
    train_dataloader,
    lr_scheduler,
) = accelerator.prepare(model, optimizer, train_dataloader, lr_scheduler)
|
|
|
|
|
|
|
|
# Main training loop: one pass over the data per epoch, then sample one image
# with the current weights and save the pipeline plus the sample to disk.
for epoch in range(num_epochs):
    model.train()
    pbar = tqdm(total=len(train_dataloader), unit="ba", disable=not accelerator.is_local_main_process)
    pbar.set_description(f"Epoch {epoch}")
    losses = []
    for step, batch in enumerate(train_dataloader):
        clean_images = batch["input"]
        noisy_images = torch.empty_like(clean_images)
        noise_samples = torch.empty_like(clean_images)
        bsz = clean_images.shape[0]

        # One random diffusion timestep per sample.
        timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
        for idx in range(bsz):
            # Noise matches the per-sample shape of the batch rather than a
            # hard-coded (3, 32, 32).
            noise = torch.randn(clean_images.shape[1:]).to(clean_images.device)
            noise_samples[idx] = noise
            noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx])

        # Update the weights once every `gradient_accumulation_steps` batches.
        # The batches in between only accumulate gradients, and no_sync skips
        # the cross-process gradient sync for them. With N == 1 this steps on
        # every batch.
        if (step + 1) % gradient_accumulation_steps != 0:
            with accelerator.no_sync(model):
                output = model(noisy_images, timesteps)
                # predict the noise
                loss = F.l1_loss(output, noise_samples)
                accelerator.backward(loss)
        else:
            output = model(noisy_images, timesteps)
            # Same objective as the accumulation branch: predict the noise
            # (previously compared against clean_images by mistake).
            loss = F.l1_loss(output, noise_samples)
            accelerator.backward(loss)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        loss = loss.detach().item()
        losses.append(loss)
        pbar.update(1)
        pbar.set_postfix(loss=loss, avg_loss=np.mean(losses), lr=optimizer.param_groups[0]["lr"])
    pbar.close()

    # Drop gradients left by an incomplete accumulation window at the end of
    # the epoch. They were accumulated under no_sync, so applying them here (as
    # the old stray optimizer.step() did) would be unsynced, unclipped and
    # invisible to the LR scheduler.
    optimizer.zero_grad()

    # eval: only the main process samples and writes files, using the
    # unwrapped model so the saved weights carry no distributed wrapper.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        model.eval()
        with torch.no_grad():
            pipeline = DDPM(unet=accelerator.unwrap_model(model), noise_scheduler=noise_scheduler)
            generator = torch.Generator().manual_seed(0)
            # run pipeline in inference (sample random noise and denoise)
            image = pipeline(generator=generator)

            # process image to PIL: map [-1, 1] to [0, 255]. Clamp first so
            # out-of-range samples do not wrap around in the uint8 cast.
            image_processed = image.cpu().permute(0, 2, 3, 1)
            image_processed = ((image_processed + 1.0) * 127.5).clamp(0, 255)
            image_processed = image_processed.type(torch.uint8).numpy()
            image_pil = PIL.Image.fromarray(image_processed[0])

            # save image
            pipeline.save_pretrained("./flowers-ddpm")
            image_pil.save(f"./flowers-ddpm/test_{epoch}.png")
|