Update the training examples (#102)
* New unet, gradient accumulation * Save every n epochs * Remove find_unused_params, hooray! * Update examples * Switch to DDPM completely
This commit is contained in:
parent
6b275fca49
commit
76f9b52289
|
@ -5,18 +5,17 @@
|
|||
The command to train a DDPM UNet model on the Oxford Flowers dataset:
|
||||
|
||||
```bash
|
||||
python -m torch.distributed.launch \
|
||||
--nproc_per_node 4 \
|
||||
train_unconditional.py \
|
||||
accelerate launch train_unconditional.py \
|
||||
--dataset="huggan/flowers-102-categories" \
|
||||
--resolution=64 \
|
||||
--output_dir="flowers-ddpm" \
|
||||
--batch_size=16 \
|
||||
--output_dir="ddpm-ema-flowers-64" \
|
||||
--train_batch_size=16 \
|
||||
--num_epochs=100 \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--lr=1e-4 \
|
||||
--warmup_steps=500 \
|
||||
--mixed_precision=no
|
||||
--learning_rate=1e-4 \
|
||||
--lr_warmup_steps=500 \
|
||||
--mixed_precision=no \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
A full training run takes 2 hours on 4xV100 GPUs.
|
||||
|
@ -29,18 +28,17 @@ A full training run takes 2 hours on 4xV100 GPUs.
|
|||
The command to train a DDPM UNet model on the Pokemon dataset:
|
||||
|
||||
```bash
|
||||
python -m torch.distributed.launch \
|
||||
--nproc_per_node 4 \
|
||||
train_unconditional.py \
|
||||
accelerate launch train_unconditional.py \
|
||||
--dataset="huggan/pokemon" \
|
||||
--resolution=64 \
|
||||
--output_dir="pokemon-ddpm" \
|
||||
--batch_size=16 \
|
||||
--output_dir="ddpm-ema-pokemon-64" \
|
||||
--train_batch_size=16 \
|
||||
--num_epochs=100 \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--lr=1e-4 \
|
||||
--warmup_steps=500 \
|
||||
--mixed_precision=no
|
||||
--learning_rate=1e-4 \
|
||||
--lr_warmup_steps=500 \
|
||||
--mixed_precision=no \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
A full training run takes 2 hours on 4xV100 GPUs.
|
||||
|
|
|
@ -4,10 +4,10 @@ import os
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from accelerate import Accelerator, DistributedDataParallelKwargs
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from datasets import load_dataset
|
||||
from diffusers import DDIMPipeline, DDIMScheduler, UNetModel
|
||||
from diffusers import DDPMPipeline, DDPMScheduler, UNetUnconditionalModel
|
||||
from diffusers.hub_utils import init_git_repo, push_to_hub
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import EMAModel
|
||||
|
@ -27,25 +27,37 @@ logger = get_logger(__name__)
|
|||
|
||||
|
||||
def main(args):
|
||||
ddp_unused_params = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
logging_dir = os.path.join(args.output_dir, args.logging_dir)
|
||||
accelerator = Accelerator(
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with="tensorboard",
|
||||
logging_dir=logging_dir,
|
||||
kwargs_handlers=[ddp_unused_params],
|
||||
)
|
||||
|
||||
model = UNetModel(
|
||||
attn_resolutions=(16,),
|
||||
ch=128,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
dropout=0.0,
|
||||
model = UNetUnconditionalModel(
|
||||
image_size=args.resolution,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
num_res_blocks=2,
|
||||
resamp_with_conv=True,
|
||||
resolution=args.resolution,
|
||||
block_channels=(128, 128, 256, 256, 512, 512),
|
||||
down_blocks=(
|
||||
"UNetResDownBlock2D",
|
||||
"UNetResDownBlock2D",
|
||||
"UNetResDownBlock2D",
|
||||
"UNetResDownBlock2D",
|
||||
"UNetResAttnDownBlock2D",
|
||||
"UNetResDownBlock2D",
|
||||
),
|
||||
up_blocks=(
|
||||
"UNetResUpBlock2D",
|
||||
"UNetResAttnUpBlock2D",
|
||||
"UNetResUpBlock2D",
|
||||
"UNetResUpBlock2D",
|
||||
"UNetResUpBlock2D",
|
||||
"UNetResUpBlock2D",
|
||||
),
|
||||
)
|
||||
noise_scheduler = DDIMScheduler(timesteps=1000, tensor_format="pt")
|
||||
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
|
||||
optimizer = torch.optim.AdamW(
|
||||
model.parameters(),
|
||||
lr=args.learning_rate,
|
||||
|
@ -92,19 +104,6 @@ def main(args):
|
|||
run = os.path.split(__file__)[-1].split(".")[0]
|
||||
accelerator.init_trackers(run)
|
||||
|
||||
# Train!
|
||||
is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
|
||||
world_size = torch.distributed.get_world_size() if is_distributed else 1
|
||||
total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * world_size
|
||||
max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataloader.dataset)}")
|
||||
logger.info(f" Num Epochs = {args.num_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {max_steps}")
|
||||
|
||||
global_step = 0
|
||||
for epoch in range(args.num_epochs):
|
||||
model.train()
|
||||
|
@ -112,45 +111,37 @@ def main(args):
|
|||
progress_bar.set_description(f"Epoch {epoch}")
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
clean_images = batch["input"]
|
||||
noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
|
||||
# Sample noise that we'll add to the images
|
||||
noise = torch.randn(clean_images.shape).to(clean_images.device)
|
||||
bsz = clean_images.shape[0]
|
||||
timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
0, noise_scheduler.num_train_timesteps, (bsz,), device=clean_images.device
|
||||
).long()
|
||||
|
||||
# add noise onto the clean images according to the noise magnitude at each timestep
|
||||
# Add noise to the clean images according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_images = noise_scheduler.add_noise(clean_images, noise_samples, timesteps)
|
||||
noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
|
||||
|
||||
if step % args.gradient_accumulation_steps != 0:
|
||||
with accelerator.no_sync(model):
|
||||
output = model(noisy_images, timesteps)
|
||||
# predict the noise residual
|
||||
loss = F.mse_loss(output, noise_samples)
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
accelerator.backward(loss)
|
||||
else:
|
||||
output = model(noisy_images, timesteps)
|
||||
# predict the noise residual
|
||||
loss = F.mse_loss(output, noise_samples)
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
with accelerator.accumulate(model):
|
||||
# Predict the noise residual
|
||||
noise_pred = model(noisy_images, timesteps)["sample"]
|
||||
loss = F.mse_loss(noise_pred, noise)
|
||||
accelerator.backward(loss)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
||||
|
||||
accelerator.clip_grad_norm_(model.parameters(), 1.0)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
ema_model.step(model)
|
||||
if args.use_ema:
|
||||
ema_model.step(model)
|
||||
optimizer.zero_grad()
|
||||
|
||||
progress_bar.update(1)
|
||||
progress_bar.set_postfix(
|
||||
loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"], ema_decay=ema_model.decay
|
||||
)
|
||||
accelerator.log(
|
||||
{
|
||||
"train_loss": loss.detach().item(),
|
||||
"epoch": epoch,
|
||||
"ema_decay": ema_model.decay,
|
||||
"step": global_step,
|
||||
},
|
||||
step=global_step,
|
||||
)
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
|
||||
if args.use_ema:
|
||||
logs["ema_decay"] = ema_model.decay
|
||||
progress_bar.set_postfix(**logs)
|
||||
accelerator.log(logs, step=global_step)
|
||||
global_step += 1
|
||||
progress_bar.close()
|
||||
|
||||
|
@ -159,14 +150,14 @@ def main(args):
|
|||
# Generate a sample image for visual inspection
|
||||
if accelerator.is_main_process:
|
||||
with torch.no_grad():
|
||||
pipeline = DDIMPipeline(
|
||||
unet=accelerator.unwrap_model(ema_model.averaged_model),
|
||||
noise_scheduler=noise_scheduler,
|
||||
pipeline = DDPMPipeline(
|
||||
unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model),
|
||||
scheduler=noise_scheduler,
|
||||
)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
# run pipeline in inference (sample random noise and denoise)
|
||||
images = pipeline(generator=generator, batch_size=args.eval_batch_size, num_inference_steps=50)
|
||||
images = pipeline(generator=generator, batch_size=args.eval_batch_size)
|
||||
|
||||
# denormalize the images and save to tensorboard
|
||||
images_processed = (images.cpu() + 1.0) * 127.5
|
||||
|
@ -174,11 +165,12 @@ def main(args):
|
|||
|
||||
accelerator.trackers[0].writer.add_images("test_samples", images_processed, epoch)
|
||||
|
||||
# save the model
|
||||
if args.push_to_hub:
|
||||
push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
|
||||
else:
|
||||
pipeline.save_pretrained(args.output_dir)
|
||||
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
|
||||
# save the model
|
||||
if args.push_to_hub:
|
||||
push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
|
||||
else:
|
||||
pipeline.save_pretrained(args.output_dir)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
accelerator.end_training()
|
||||
|
@ -188,12 +180,13 @@ if __name__ == "__main__":
|
|||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument("--local_rank", type=int, default=-1)
|
||||
parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
|
||||
parser.add_argument("--output_dir", type=str, default="ddpm-model")
|
||||
parser.add_argument("--output_dir", type=str, default="ddpm-flowers-64")
|
||||
parser.add_argument("--overwrite_output_dir", action="store_true")
|
||||
parser.add_argument("--resolution", type=int, default=64)
|
||||
parser.add_argument("--train_batch_size", type=int, default=16)
|
||||
parser.add_argument("--eval_batch_size", type=int, default=16)
|
||||
parser.add_argument("--num_epochs", type=int, default=100)
|
||||
parser.add_argument("--save_model_epochs", type=int, default=5)
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-4)
|
||||
parser.add_argument("--lr_scheduler", type=str, default="cosine")
|
||||
|
@ -202,6 +195,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--adam_beta2", type=float, default=0.999)
|
||||
parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
|
||||
parser.add_argument("--adam_epsilon", type=float, default=1e-3)
|
||||
parser.add_argument("--use_ema", action="store_true", default=True)
|
||||
parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
|
||||
parser.add_argument("--ema_power", type=float, default=3 / 4)
|
||||
parser.add_argument("--ema_max_decay", type=float, default=0.9999)
|
||||
|
|
Loading…
Reference in New Issue