diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 53b0d532..af3d0ddc 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -173,6 +173,16 @@ def parse_args(): parser.add_argument( "--hub_private_repo", action="store_true", help="Whether or not to create a private repository." ) + parser.add_argument( + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + help=( + "Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)" + " for experiment tracking and logging of model metrics and model checkpoints" + ), + ) parser.add_argument( "--logging_dir", type=str, @@ -248,7 +258,7 @@ def main(args): accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, - log_with="tensorboard", + log_with=args.logger, logging_dir=logging_dir, ) @@ -477,9 +487,11 @@ def main(args): # denormalize the images and save to tensorboard images_processed = (images * 255).round().astype("uint8") - accelerator.trackers[0].writer.add_images( - "test_samples", images_processed.transpose(0, 3, 1, 2), epoch - ) + + if args.logger == "tensorboard": + accelerator.get_tracker("tensorboard").add_images( + "test_samples", images_processed.transpose(0, 3, 1, 2), epoch + ) if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1: # save the model