[Examples] Update train_unconditional.py to include logging argument for Wandb (#1719)
Update train_unconditional.py Add logger flag to choose between tensorboard and wandb
This commit is contained in:
parent
ce1c27adc8
commit
9f657f106d
|
@ -173,6 +173,16 @@ def parse_args():
|
|||
parser.add_argument(
|
||||
"--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logger",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
choices=["tensorboard", "wandb"],
|
||||
help=(
|
||||
"Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)"
|
||||
" for experiment tracking and logging of model metrics and model checkpoints"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
|
@ -248,7 +258,7 @@ def main(args):
|
|||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with="tensorboard",
|
||||
log_with=args.logger,
|
||||
logging_dir=logging_dir,
|
||||
)
|
||||
|
||||
|
@ -477,9 +487,11 @@ def main(args):
|
|||
|
||||
# denormalize the images and save to tensorboard
|
||||
images_processed = (images * 255).round().astype("uint8")
|
||||
accelerator.trackers[0].writer.add_images(
|
||||
"test_samples", images_processed.transpose(0, 3, 1, 2), epoch
|
||||
)
|
||||
|
||||
if args.logger == "tensorboard":
|
||||
accelerator.get_tracker("tensorboard").add_images(
|
||||
"test_samples", images_processed.transpose(0, 3, 1, 2), epoch
|
||||
)
|
||||
|
||||
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
|
||||
# save the model
|
||||
|
|
Loading…
Reference in New Issue