From a2772fc668bd9b9a924efdf8b70dd50d0b163a4e Mon Sep 17 00:00:00 2001 From: chavinlo <85657083+chavinlo@users.noreply.github.com> Date: Wed, 16 Nov 2022 10:55:38 -0500 Subject: [PATCH] fixes --- trainer/diffusers_trainer.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py index b5698b4..8ed36e4 100644 --- a/trainer/diffusers_trainer.py +++ b/trainer/diffusers_trainer.py @@ -84,8 +84,8 @@ parser.add_argument('--clip_penultimate', type=bool_t, default='False', help='Us parser.add_argument('--output_bucket_info', type=bool_t, default='False', help='Outputs bucket information and exits') parser.add_argument('--resize', type=bool_t, default='False', help="Resizes dataset's images to the appropriate bucket dimensions.") parser.add_argument('--use_xformers', type=bool_t, default='False', help='Use memory efficient attention') -parser.add_argument('--wandb', dest='enablewandb', type=str, default='True', help='Enable WeightsAndBiases Reporting') -parser.add_argument('--inference', dest='enableinference', type=str, default='True', help='Enable Inference during training (Consumes 2GB of VRAM)') +parser.add_argument('--wandb', dest='enablewandb', type=bool_t, default='True', help='Enable WeightsAndBiases Reporting') +parser.add_argument('--inference', dest='enableinference', type=bool_t, default='True', help='Enable Inference during training (Consumes 2GB of VRAM)') args = parser.parse_args() def setup(): @@ -523,10 +523,11 @@ def main(): if rank == 0: os.makedirs(args.output_path, exist_ok=True) + mode = 'enabled' if args.enablewandb: - run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb') - else: - run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb', mode="disabled") + mode = 'disabled' + + run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb', mode=mode) # Inform the user of host, and various versions -- useful for debugging issues. print("RUN_NAME:", args.run_name) @@ -539,10 +540,13 @@ def main(): print("FP16:", args.fp16) print("RESOLUTION:", args.resolution) - if args.hf_token is None: + + if args.hf_token is not None: + print('It is recommended to set the HF_API_TOKEN environment variable instead of passing it as a command line argument since WandB will automatically log it.') + else: try: args.hf_token = os.environ['HF_API_TOKEN'] - print('It is recommended to set the HF_API_TOKEN environment variable instead of passing it as a command line argument since WandB will automatically log it.') + print("HF Token set via enviroment variable") except Exception: print("No HF Token detected in arguments or enviroment variable, setting it to none (as in string)") args.hf_token = "none"