update train_unconditional_ort.py (#1775)

* reflect changes

* run make style

Co-authored-by: Prathik Rao <prathikrao@microsoft.com>
Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
This commit is contained in:
Prathik Rao 2022-12-19 14:58:55 -08:00 committed by GitHub
parent 9f8c915a75
commit 847daf25c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 65 additions and 6 deletions

View File

@ -174,6 +174,16 @@ def parse_args():
parser.add_argument(
"--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
)
parser.add_argument(
"--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
help=(
"Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)"
" for experiment tracking and logging of model metrics and model checkpoints"
),
)
parser.add_argument(
"--logging_dir",
type=str,
@ -195,7 +205,6 @@ def parse_args():
"and an Nvidia Ampere GPU."
),
)
parser.add_argument(
"--prediction_type",
type=str,
@ -206,6 +215,24 @@ def parse_args():
parser.add_argument("--ddpm_num_steps", type=int, default=1000)
parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
parser.add_argument(
"--checkpointing_steps",
type=int,
default=500,
help=(
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
" training using `--resume_from_checkpoint`."
),
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help=(
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@ -233,7 +260,7 @@ def main(args):
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with="tensorboard",
log_with=args.logger,
logging_dir=logging_dir,
)
@ -321,6 +348,7 @@ def main(args):
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, lr_scheduler
)
accelerator.register_for_checkpointing(lr_scheduler)
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@ -353,11 +381,34 @@ def main(args):
accelerator.init_trackers(run)
global_step = 0
for epoch in range(args.num_epochs):
first_epoch = 0
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1]
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
resume_global_step = global_step * args.gradient_accumulation_steps
first_epoch = resume_global_step // num_update_steps_per_epoch
resume_step = resume_global_step % num_update_steps_per_epoch
for epoch in range(first_epoch, args.num_epochs):
model.train()
progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
progress_bar.set_description(f"Epoch {epoch}")
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
continue
clean_images = batch["input"]
# Sample noise that we'll add to the images
noise = torch.randn(clean_images.shape).to(clean_images.device)
@ -404,6 +455,12 @@ def main(args):
progress_bar.update(1)
global_step += 1
if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
if args.use_ema:
logs["ema_decay"] = ema_model.decay
@ -431,7 +488,9 @@ def main(args):
# denormalize the images and save to tensorboard
images_processed = (images * 255).round().astype("uint8")
accelerator.trackers[0].writer.add_images(
if args.logger == "tensorboard":
accelerator.get_tracker("tensorboard").add_images(
"test_samples", images_processed.transpose(0, 3, 1, 2), epoch
)