Update the training examples (#102)

* New unet, gradient accumulation

* Save every n epochs

* Remove find_unused_params, hooray!

* Update examples

* Switch to DDPM completely
Anton Lozhkov 2022-07-20 19:51:23 +02:00 committed by GitHub
parent 6b275fca49
commit 76f9b52289
2 changed files with 73 additions and 81 deletions
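
The headline change in the training loop (second file below) is replacing the manual `no_sync`/loss-division bookkeeping with Accelerate's `accumulate` context manager. As a rough illustration of that pattern, here is a minimal, self-contained sketch; the toy linear model, random data, and the accumulation factor of 4 are placeholders for illustration and are not part of this commit:

```python
# Minimal sketch of the accelerator.accumulate() pattern adopted by this commit.
# The model, data, and gradient_accumulation_steps value below are illustrative only.
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataloader = DataLoader(TensorDataset(torch.randn(64, 8)), batch_size=4)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for (batch,) in dataloader:
    with accelerator.accumulate(model):
        loss = model(batch).pow(2).mean()  # stand-in loss
        accelerator.backward(loss)         # loss scaling and gradient sync handled internally
        optimizer.step()                   # effectively a no-op until an accumulation boundary
        optimizer.zero_grad()
```

This is why the explicit `loss = loss / args.gradient_accumulation_steps` division and the `accelerator.no_sync` branch are gone from the updated loop.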


@@ -5,18 +5,17 @@
 The command to train a DDPM UNet model on the Oxford Flowers dataset:
 ```bash
-python -m torch.distributed.launch \
---nproc_per_node 4 \
-train_unconditional.py \
+accelerate launch train_unconditional.py \
 --dataset="huggan/flowers-102-categories" \
 --resolution=64 \
---output_dir="flowers-ddpm" \
---batch_size=16 \
+--output_dir="ddpm-ema-flowers-64" \
+--train_batch_size=16 \
 --num_epochs=100 \
 --gradient_accumulation_steps=1 \
---lr=1e-4 \
---warmup_steps=500 \
---mixed_precision=no
+--learning_rate=1e-4 \
+--lr_warmup_steps=500 \
+--mixed_precision=no \
+--push_to_hub
 ```
 A full training run takes 2 hours on 4xV100 GPUs.
@@ -29,18 +28,17 @@ A full training run takes 2 hours on 4xV100 GPUs.
 The command to train a DDPM UNet model on the Pokemon dataset:
 ```bash
-python -m torch.distributed.launch \
---nproc_per_node 4 \
-train_unconditional.py \
+accelerate launch train_unconditional.py \
 --dataset="huggan/pokemon" \
 --resolution=64 \
---output_dir="pokemon-ddpm" \
---batch_size=16 \
+--output_dir="ddpm-ema-pokemon-64" \
+--train_batch_size=16 \
 --num_epochs=100 \
 --gradient_accumulation_steps=1 \
---lr=1e-4 \
---warmup_steps=500 \
---mixed_precision=no
+--learning_rate=1e-4 \
+--lr_warmup_steps=500 \
+--mixed_precision=no \
+--push_to_hub
 ```
 A full training run takes 2 hours on 4xV100 GPUs.
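
For reference, once either run above finishes, the saved folder (or the Hub repo created by `--push_to_hub`) can be loaded back for sampling. This is a hedged sketch only: the `ddpm-ema-flowers-64` path simply mirrors the `--output_dir` in the first command, the call signature copies the evaluation step in the training script below, and the exact output format of the pipeline call depends on the installed diffusers version:

```python
# Illustrative only: load the trained checkpoint and draw a few samples.
# "ddpm-ema-flowers-64" is assumed to match --output_dir (or the pushed Hub repo id).
import torch
from diffusers import DDPMPipeline

pipeline = DDPMPipeline.from_pretrained("ddpm-ema-flowers-64")
generator = torch.manual_seed(0)
images = pipeline(generator=generator, batch_size=4)  # sample random noise and denoise
```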


@@ -4,10 +4,10 @@ import os
 import torch
 import torch.nn.functional as F
-from accelerate import Accelerator, DistributedDataParallelKwargs
+from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import load_dataset
-from diffusers import DDIMPipeline, DDIMScheduler, UNetModel
+from diffusers import DDPMPipeline, DDPMScheduler, UNetUnconditionalModel
 from diffusers.hub_utils import init_git_repo, push_to_hub
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
@@ -27,25 +27,37 @@ logger = get_logger(__name__)
 def main(args):
-    ddp_unused_params = DistributedDataParallelKwargs(find_unused_parameters=True)
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
     accelerator = Accelerator(
         mixed_precision=args.mixed_precision,
         log_with="tensorboard",
         logging_dir=logging_dir,
-        kwargs_handlers=[ddp_unused_params],
     )
-    model = UNetModel(
-        attn_resolutions=(16,),
-        ch=128,
-        ch_mult=(1, 2, 4, 8),
-        dropout=0.0,
+    model = UNetUnconditionalModel(
+        image_size=args.resolution,
         in_channels=3,
         out_channels=3,
         num_res_blocks=2,
-        resamp_with_conv=True,
-        resolution=args.resolution,
+        block_channels=(128, 128, 256, 256, 512, 512),
+        down_blocks=(
+            "UNetResDownBlock2D",
+            "UNetResDownBlock2D",
+            "UNetResDownBlock2D",
+            "UNetResDownBlock2D",
+            "UNetResAttnDownBlock2D",
+            "UNetResDownBlock2D",
+        ),
+        up_blocks=(
+            "UNetResUpBlock2D",
+            "UNetResAttnUpBlock2D",
+            "UNetResUpBlock2D",
+            "UNetResUpBlock2D",
+            "UNetResUpBlock2D",
+            "UNetResUpBlock2D",
+        ),
     )
-    noise_scheduler = DDIMScheduler(timesteps=1000, tensor_format="pt")
+    noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
     optimizer = torch.optim.AdamW(
         model.parameters(),
         lr=args.learning_rate,
@@ -92,19 +104,6 @@ def main(args):
     run = os.path.split(__file__)[-1].split(".")[0]
     accelerator.init_trackers(run)
-    # Train!
-    is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
-    world_size = torch.distributed.get_world_size() if is_distributed else 1
-    total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * world_size
-    max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
-    logger.info("***** Running training *****")
-    logger.info(f" Num examples = {len(train_dataloader.dataset)}")
-    logger.info(f" Num Epochs = {args.num_epochs}")
-    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
-    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
-    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
-    logger.info(f" Total optimization steps = {max_steps}")
     global_step = 0
     for epoch in range(args.num_epochs):
         model.train()
@@ -112,45 +111,37 @@ def main(args):
         progress_bar.set_description(f"Epoch {epoch}")
         for step, batch in enumerate(train_dataloader):
             clean_images = batch["input"]
-            noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
+            # Sample noise that we'll add to the images
+            noise = torch.randn(clean_images.shape).to(clean_images.device)
             bsz = clean_images.shape[0]
-            timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
+            # Sample a random timestep for each image
+            timesteps = torch.randint(
+                0, noise_scheduler.num_train_timesteps, (bsz,), device=clean_images.device
+            ).long()
-            # add noise onto the clean images according to the noise magnitude at each timestep
+            # Add noise to the clean images according to the noise magnitude at each timestep
             # (this is the forward diffusion process)
-            noisy_images = noise_scheduler.add_noise(clean_images, noise_samples, timesteps)
+            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
-            if step % args.gradient_accumulation_steps != 0:
-                with accelerator.no_sync(model):
-                    output = model(noisy_images, timesteps)
-                    # predict the noise residual
-                    loss = F.mse_loss(output, noise_samples)
-                    loss = loss / args.gradient_accumulation_steps
-                    accelerator.backward(loss)
-            else:
-                output = model(noisy_images, timesteps)
-                # predict the noise residual
-                loss = F.mse_loss(output, noise_samples)
-                loss = loss / args.gradient_accumulation_steps
+            with accelerator.accumulate(model):
+                # Predict the noise residual
+                noise_pred = model(noisy_images, timesteps)["sample"]
+                loss = F.mse_loss(noise_pred, noise)
                 accelerator.backward(loss)
-                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                 optimizer.step()
                 lr_scheduler.step()
-                ema_model.step(model)
+                if args.use_ema:
+                    ema_model.step(model)
                 optimizer.zero_grad()
             progress_bar.update(1)
-            progress_bar.set_postfix(
-                loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"], ema_decay=ema_model.decay
-            )
-            accelerator.log(
-                {
-                    "train_loss": loss.detach().item(),
-                    "epoch": epoch,
-                    "ema_decay": ema_model.decay,
-                    "step": global_step,
-                },
-                step=global_step,
-            )
+            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
+            if args.use_ema:
+                logs["ema_decay"] = ema_model.decay
+            progress_bar.set_postfix(**logs)
+            accelerator.log(logs, step=global_step)
             global_step += 1
         progress_bar.close()
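
For readers unfamiliar with the `add_noise` call in the hunk above: DDPM's forward (noising) process has the closed form x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, where alpha_bar_t is the cumulative product of (1 - beta). A self-contained sketch of that computation follows; the linear beta schedule here is an assumption for illustration and is not read from the scheduler's actual config:

```python
# Sketch of what noise_scheduler.add_noise performs conceptually (forward diffusion).
# Assumes a linear beta schedule purely for illustration.
import torch

num_train_timesteps = 1000
betas = torch.linspace(1e-4, 0.02, num_train_timesteps)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def add_noise(clean_images, noise, timesteps):
    # Per-sample coefficients, broadcast over (C, H, W)
    sqrt_alpha_bar = alphas_cumprod[timesteps].sqrt().view(-1, 1, 1, 1)
    sqrt_one_minus = (1.0 - alphas_cumprod[timesteps]).sqrt().view(-1, 1, 1, 1)
    return sqrt_alpha_bar * clean_images + sqrt_one_minus * noise

clean_images = torch.randn(2, 3, 64, 64)
noise = torch.randn_like(clean_images)
timesteps = torch.randint(0, num_train_timesteps, (2,))
noisy_images = add_noise(clean_images, noise, timesteps)
```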
@@ -159,14 +150,14 @@ def main(args):
         # Generate a sample image for visual inspection
         if accelerator.is_main_process:
             with torch.no_grad():
-                pipeline = DDIMPipeline(
-                    unet=accelerator.unwrap_model(ema_model.averaged_model),
-                    noise_scheduler=noise_scheduler,
+                pipeline = DDPMPipeline(
+                    unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model),
+                    scheduler=noise_scheduler,
                 )
                 generator = torch.manual_seed(0)
                 # run pipeline in inference (sample random noise and denoise)
-                images = pipeline(generator=generator, batch_size=args.eval_batch_size, num_inference_steps=50)
+                images = pipeline(generator=generator, batch_size=args.eval_batch_size)
                 # denormalize the images and save to tensorboard
                 images_processed = (images.cpu() + 1.0) * 127.5
@@ -174,11 +165,12 @@ def main(args):
                 accelerator.trackers[0].writer.add_images("test_samples", images_processed, epoch)
-            # save the model
-            if args.push_to_hub:
-                push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
-            else:
-                pipeline.save_pretrained(args.output_dir)
+            if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
+                # save the model
+                if args.push_to_hub:
+                    push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
+                else:
+                    pipeline.save_pretrained(args.output_dir)
         accelerator.wait_for_everyone()
     accelerator.end_training()
@@ -188,12 +180,13 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument("--local_rank", type=int, default=-1)
     parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
-    parser.add_argument("--output_dir", type=str, default="ddpm-model")
+    parser.add_argument("--output_dir", type=str, default="ddpm-flowers-64")
     parser.add_argument("--overwrite_output_dir", action="store_true")
     parser.add_argument("--resolution", type=int, default=64)
     parser.add_argument("--train_batch_size", type=int, default=16)
     parser.add_argument("--eval_batch_size", type=int, default=16)
     parser.add_argument("--num_epochs", type=int, default=100)
+    parser.add_argument("--save_model_epochs", type=int, default=5)
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
     parser.add_argument("--learning_rate", type=float, default=1e-4)
     parser.add_argument("--lr_scheduler", type=str, default="cosine")
@@ -202,6 +195,7 @@ if __name__ == "__main__":
     parser.add_argument("--adam_beta2", type=float, default=0.999)
    parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
     parser.add_argument("--adam_epsilon", type=float, default=1e-3)
+    parser.add_argument("--use_ema", action="store_true", default=True)
     parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
     parser.add_argument("--ema_power", type=float, default=3 / 4)
     parser.add_argument("--ema_max_decay", type=float, default=0.9999)