This commit is contained in:
Patrick von Platen 2022-06-22 12:38:44 +00:00
commit e45dae7dc0
3 changed files with 22 additions and 16 deletions

View File

@ -10,7 +10,7 @@ python -m torch.distributed.launch \
train_unconditional.py \ train_unconditional.py \
--dataset="huggan/flowers-102-categories" \ --dataset="huggan/flowers-102-categories" \
--resolution=64 \ --resolution=64 \
--output_path="flowers-ddpm" \ --output_dir="flowers-ddpm" \
--batch_size=16 \ --batch_size=16 \
--num_epochs=100 \ --num_epochs=100 \
--gradient_accumulation_steps=1 \ --gradient_accumulation_steps=1 \
@ -34,7 +34,7 @@ python -m torch.distributed.launch \
train_unconditional.py \ train_unconditional.py \
--dataset="huggan/pokemon" \ --dataset="huggan/pokemon" \
--resolution=64 \ --resolution=64 \
--output_path="pokemon-ddpm" \ --output_dir="pokemon-ddpm" \
--batch_size=16 \ --batch_size=16 \
--num_epochs=100 \ --num_epochs=100 \
--gradient_accumulation_steps=1 \ --gradient_accumulation_steps=1 \

View File

@ -39,7 +39,7 @@ def main(args):
resamp_with_conv=True, resamp_with_conv=True,
resolution=args.resolution, resolution=args.resolution,
) )
noise_scheduler = DDPMScheduler(timesteps=1000) noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt")
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
augmentations = Compose( augmentations = Compose(
@ -93,15 +93,13 @@ def main(args):
pbar.set_description(f"Epoch {epoch}") pbar.set_description(f"Epoch {epoch}")
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
clean_images = batch["input"] clean_images = batch["input"]
noisy_images = torch.empty_like(clean_images) noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
noise_samples = torch.empty_like(clean_images)
bsz = clean_images.shape[0] bsz = clean_images.shape[0]
timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long() timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
for idx in range(bsz):
noise = torch.randn(clean_images.shape[1:]).to(clean_images.device) # add noise onto the clean images according to the noise magnitude at each timestep
noise_samples[idx] = noise # (this is the forward diffusion process)
noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx]) noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps)
if step % args.gradient_accumulation_steps != 0: if step % args.gradient_accumulation_steps != 0:
with accelerator.no_sync(model): with accelerator.no_sync(model):
@ -146,7 +144,7 @@ def main(args):
# save image # save image
test_dir = os.path.join(args.output_dir, "test_samples") test_dir = os.path.join(args.output_dir, "test_samples")
os.makedirs(test_dir, exist_ok=True) os.makedirs(test_dir, exist_ok=True)
image_pil.save(f"{test_dir}/{epoch}.png") image_pil.save(f"{test_dir}/{epoch:04d}.png")
# save the model # save the model
if args.push_to_hub: if args.push_to_hub:

View File

@ -17,6 +17,7 @@
import math import math
import numpy as np import numpy as np
import torch
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
@ -142,11 +143,18 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return pred_prev_sample return pred_prev_sample
def forward_step(self, original_sample, noise, t): def training_step(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor):
sqrt_alpha_prod = self.alphas_cumprod[t] ** 0.5 if timesteps.dim() != 1:
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[t]) ** 0.5 raise ValueError("`timesteps` must be a 1D tensor")
noisy_sample = sqrt_alpha_prod * original_sample + sqrt_one_minus_alpha_prod * noise
return noisy_sample device = original_samples.device
batch_size = original_samples.shape[0]
timesteps = timesteps.reshape(batch_size, 1, 1, 1)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
noisy_samples = sqrt_alpha_prod.to(device) * original_samples + sqrt_one_minus_alpha_prod.to(device) * noise
return noisy_samples
def __len__(self): def __len__(self):
return self.config.timesteps return self.config.timesteps