Merge branch 'main' of https://github.com/huggingface/diffusers into main
This commit is contained in:
commit
e45dae7dc0
|
@ -10,7 +10,7 @@ python -m torch.distributed.launch \
|
|||
train_unconditional.py \
|
||||
--dataset="huggan/flowers-102-categories" \
|
||||
--resolution=64 \
|
||||
--output_path="flowers-ddpm" \
|
||||
--output_dir="flowers-ddpm" \
|
||||
--batch_size=16 \
|
||||
--num_epochs=100 \
|
||||
--gradient_accumulation_steps=1 \
|
||||
|
@ -34,7 +34,7 @@ python -m torch.distributed.launch \
|
|||
train_unconditional.py \
|
||||
--dataset="huggan/pokemon" \
|
||||
--resolution=64 \
|
||||
--output_path="pokemon-ddpm" \
|
||||
--output_dir="pokemon-ddpm" \
|
||||
--batch_size=16 \
|
||||
--num_epochs=100 \
|
||||
--gradient_accumulation_steps=1 \
|
||||
|
|
|
@ -39,7 +39,7 @@ def main(args):
|
|||
resamp_with_conv=True,
|
||||
resolution=args.resolution,
|
||||
)
|
||||
noise_scheduler = DDPMScheduler(timesteps=1000)
|
||||
noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt")
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
|
||||
|
||||
augmentations = Compose(
|
||||
|
@ -93,15 +93,13 @@ def main(args):
|
|||
pbar.set_description(f"Epoch {epoch}")
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
clean_images = batch["input"]
|
||||
noisy_images = torch.empty_like(clean_images)
|
||||
noise_samples = torch.empty_like(clean_images)
|
||||
noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
|
||||
bsz = clean_images.shape[0]
|
||||
|
||||
timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
|
||||
for idx in range(bsz):
|
||||
noise = torch.randn(clean_images.shape[1:]).to(clean_images.device)
|
||||
noise_samples[idx] = noise
|
||||
noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx])
|
||||
|
||||
# add noise onto the clean images according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps)
|
||||
|
||||
if step % args.gradient_accumulation_steps != 0:
|
||||
with accelerator.no_sync(model):
|
||||
|
@ -146,7 +144,7 @@ def main(args):
|
|||
# save image
|
||||
test_dir = os.path.join(args.output_dir, "test_samples")
|
||||
os.makedirs(test_dir, exist_ok=True)
|
||||
image_pil.save(f"{test_dir}/{epoch}.png")
|
||||
image_pil.save(f"{test_dir}/{epoch:04d}.png")
|
||||
|
||||
# save the model
|
||||
if args.push_to_hub:
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
@ -142,11 +143,18 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||
|
||||
return pred_prev_sample
|
||||
|
||||
def forward_step(self, original_sample, noise, t):
|
||||
sqrt_alpha_prod = self.alphas_cumprod[t] ** 0.5
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[t]) ** 0.5
|
||||
noisy_sample = sqrt_alpha_prod * original_sample + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_sample
|
||||
def training_step(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor):
|
||||
if timesteps.dim() != 1:
|
||||
raise ValueError("`timesteps` must be a 1D tensor")
|
||||
|
||||
device = original_samples.device
|
||||
batch_size = original_samples.shape[0]
|
||||
timesteps = timesteps.reshape(batch_size, 1, 1, 1)
|
||||
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
||||
noisy_samples = sqrt_alpha_prod.to(device) * original_samples + sqrt_one_minus_alpha_prod.to(device) * noise
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
return self.config.timesteps
|
||||
|
|
Loading…
Reference in New Issue