From 847daf25c7e461795932099c5097eb8ac489645c Mon Sep 17 00:00:00 2001
From: Prathik Rao
Date: Mon, 19 Dec 2022 14:58:55 -0800
Subject: [PATCH] update train_unconditional_ort.py (#1775)

* reflect changes

* run make style

Co-authored-by: Prathik Rao
Co-authored-by: Prathik Rao
---
 .../train_unconditional_ort.py                     | 71 +++++++++++++++++--
 1 file changed, 65 insertions(+), 6 deletions(-)

diff --git a/examples/unconditional_image_generation/train_unconditional_ort.py b/examples/unconditional_image_generation/train_unconditional_ort.py
index b5974b84..34b5434d 100644
--- a/examples/unconditional_image_generation/train_unconditional_ort.py
+++ b/examples/unconditional_image_generation/train_unconditional_ort.py
@@ -174,6 +174,16 @@ def parse_args():
     parser.add_argument(
         "--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
     )
+    parser.add_argument(
+        "--logger",
+        type=str,
+        default="tensorboard",
+        choices=["tensorboard", "wandb"],
+        help=(
+            "Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)"
+            " for experiment tracking and logging of model metrics and model checkpoints"
+        ),
+    )
     parser.add_argument(
         "--logging_dir",
         type=str,
@@ -195,7 +205,6 @@ def parse_args():
             "and an Nvidia Ampere GPU."
         ),
     )
-
     parser.add_argument(
         "--prediction_type",
         type=str,
@@ -206,6 +215,24 @@ def parse_args():

     parser.add_argument("--ddpm_num_steps", type=int, default=1000)
     parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
+    parser.add_argument(
+        "--checkpointing_steps",
+        type=int,
+        default=500,
+        help=(
+            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+            " training using `--resume_from_checkpoint`."
+        ),
+    )
+    parser.add_argument(
+        "--resume_from_checkpoint",
+        type=str,
+        default=None,
+        help=(
+            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+        ),
+    )

     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -233,7 +260,7 @@ def main(args):
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
-        log_with="tensorboard",
+        log_with=args.logger,
         logging_dir=logging_dir,
     )

@@ -321,6 +348,7 @@ def main(args):
     model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         model, optimizer, train_dataloader, lr_scheduler
     )
+    accelerator.register_for_checkpointing(lr_scheduler)

     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

@@ -353,11 +381,34 @@ def main(args):
         accelerator.init_trackers(run)

     global_step = 0
-    for epoch in range(args.num_epochs):
+    first_epoch = 0
+    if args.resume_from_checkpoint:
+        if args.resume_from_checkpoint != "latest":
+            path = os.path.basename(args.resume_from_checkpoint)
+        else:
+            # Get the most recent checkpoint
+            dirs = os.listdir(args.output_dir)
+            dirs = [d for d in dirs if d.startswith("checkpoint")]
+            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+            path = dirs[-1]
+        accelerator.print(f"Resuming from checkpoint {path}")
+        accelerator.load_state(os.path.join(args.output_dir, path))
+        global_step = int(path.split("-")[1])
+        resume_global_step = global_step * args.gradient_accumulation_steps
+        first_epoch = resume_global_step // num_update_steps_per_epoch
+        resume_step = resume_global_step % num_update_steps_per_epoch
+
+    for epoch in range(first_epoch, args.num_epochs):
         model.train()
         progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
         progress_bar.set_description(f"Epoch {epoch}")
         for step, batch in enumerate(train_dataloader):
+            # Skip steps until we reach the resumed step
+            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+                if step % args.gradient_accumulation_steps == 0:
+                    progress_bar.update(1)
+                continue
+
             clean_images = batch["input"]
             # Sample noise that we'll add to the images
             noise = torch.randn(clean_images.shape).to(clean_images.device)
@@ -404,6 +455,12 @@ def main(args):
                 progress_bar.update(1)
                 global_step += 1

+                if global_step % args.checkpointing_steps == 0:
+                    if accelerator.is_main_process:
+                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                        accelerator.save_state(save_path)
+                        logger.info(f"Saved state to {save_path}")
+
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
             if args.use_ema:
                 logs["ema_decay"] = ema_model.decay
@@ -431,9 +488,11 @@ def main(args):

                 # denormalize the images and save to tensorboard
                 images_processed = (images * 255).round().astype("uint8")
-                accelerator.trackers[0].writer.add_images(
-                    "test_samples", images_processed.transpose(0, 3, 1, 2), epoch
-                )
+
+                if args.logger == "tensorboard":
+                    accelerator.get_tracker("tensorboard").add_images(
+                        "test_samples", images_processed.transpose(0, 3, 1, 2), epoch
+                    )

             if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
                 # save the model
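
For context when reviewing the resume logic above, here is a minimal standalone sketch (not part of the patch) of the bookkeeping it performs: the optimizer-update count is recovered from the `checkpoint-{global_step}` directory name, scaled by gradient accumulation back into dataloader steps, and split into an epoch index plus an in-epoch offset of batches to skip. The helper name `resume_position` and the example numbers are illustrative assumptions, not identifiers from the diffusers script.

    # Standalone sketch of the resume bookkeeping added by this patch; the helper
    # name and example numbers are illustrative, not part of the diffusers script.
    def resume_position(checkpoint_name, gradient_accumulation_steps, num_update_steps_per_epoch):
        # Checkpoint directories are named "checkpoint-<global_step>", so the
        # optimizer-update count is recovered by splitting on "-".
        global_step = int(checkpoint_name.split("-")[1])
        # Same arithmetic as the patch: scale back to dataloader steps, then split
        # into the epoch to resume at and the steps to skip inside that epoch.
        resume_global_step = global_step * gradient_accumulation_steps
        first_epoch = resume_global_step // num_update_steps_per_epoch
        resume_step = resume_global_step % num_update_steps_per_epoch
        return global_step, first_epoch, resume_step

    # With no gradient accumulation and 400 updates per epoch, a checkpoint taken
    # at update 1000 resumes at epoch 2, skipping the first 200 batches.
    print(resume_position("checkpoint-1000", 1, 400))  # (1000, 2, 200)

Note also that the `--resume_from_checkpoint "latest"` path sorts candidate directories numerically via `int(x.split("-")[1])` rather than lexicographically, so `checkpoint-1000` correctly outranks `checkpoint-999`.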