This commit is contained in:
chavinlo 2022-11-16 10:55:38 -05:00
parent fed3431f03
commit a2772fc668
1 changed files with 11 additions and 7 deletions

View File

@ -84,8 +84,8 @@ parser.add_argument('--clip_penultimate', type=bool_t, default='False', help='Us
parser.add_argument('--output_bucket_info', type=bool_t, default='False', help='Outputs bucket information and exits') parser.add_argument('--output_bucket_info', type=bool_t, default='False', help='Outputs bucket information and exits')
parser.add_argument('--resize', type=bool_t, default='False', help="Resizes dataset's images to the appropriate bucket dimensions.") parser.add_argument('--resize', type=bool_t, default='False', help="Resizes dataset's images to the appropriate bucket dimensions.")
parser.add_argument('--use_xformers', type=bool_t, default='False', help='Use memory efficient attention') parser.add_argument('--use_xformers', type=bool_t, default='False', help='Use memory efficient attention')
parser.add_argument('--wandb', dest='enablewandb', type=str, default='True', help='Enable WeightsAndBiases Reporting') parser.add_argument('--wandb', dest='enablewandb', type=bool_t, default='True', help='Enable WeightsAndBiases Reporting')
parser.add_argument('--inference', dest='enableinference', type=str, default='True', help='Enable Inference during training (Consumes 2GB of VRAM)') parser.add_argument('--inference', dest='enableinference', type=bool_t, default='True', help='Enable Inference during training (Consumes 2GB of VRAM)')
args = parser.parse_args() args = parser.parse_args()
def setup(): def setup():
@ -523,10 +523,11 @@ def main():
if rank == 0: if rank == 0:
os.makedirs(args.output_path, exist_ok=True) os.makedirs(args.output_path, exist_ok=True)
mode = 'enabled'
if args.enablewandb: if args.enablewandb:
run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb') mode = 'disabled'
else:
run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb', mode="disabled") run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb', mode=mode)
# Inform the user of host, and various versions -- useful for debugging issues. # Inform the user of host, and various versions -- useful for debugging issues.
print("RUN_NAME:", args.run_name) print("RUN_NAME:", args.run_name)
@ -539,10 +540,13 @@ def main():
print("FP16:", args.fp16) print("FP16:", args.fp16)
print("RESOLUTION:", args.resolution) print("RESOLUTION:", args.resolution)
if args.hf_token is None:
if args.hf_token is not None:
print('It is recommended to set the HF_API_TOKEN environment variable instead of passing it as a command line argument since WandB will automatically log it.')
else:
try: try:
args.hf_token = os.environ['HF_API_TOKEN'] args.hf_token = os.environ['HF_API_TOKEN']
print('It is recommended to set the HF_API_TOKEN environment variable instead of passing it as a command line argument since WandB will automatically log it.') print("HF Token set via enviroment variable")
except Exception: except Exception:
print("No HF Token detected in arguments or enviroment variable, setting it to none (as in string)") print("No HF Token detected in arguments or enviroment variable, setting it to none (as in string)")
args.hf_token = "none" args.hf_token = "none"