fixes
This commit is contained in:
parent
fed3431f03
commit
a2772fc668
|
@ -84,8 +84,8 @@ parser.add_argument('--clip_penultimate', type=bool_t, default='False', help='Us
|
||||||
parser.add_argument('--output_bucket_info', type=bool_t, default='False', help='Outputs bucket information and exits')
|
parser.add_argument('--output_bucket_info', type=bool_t, default='False', help='Outputs bucket information and exits')
|
||||||
parser.add_argument('--resize', type=bool_t, default='False', help="Resizes dataset's images to the appropriate bucket dimensions.")
|
parser.add_argument('--resize', type=bool_t, default='False', help="Resizes dataset's images to the appropriate bucket dimensions.")
|
||||||
parser.add_argument('--use_xformers', type=bool_t, default='False', help='Use memory efficient attention')
|
parser.add_argument('--use_xformers', type=bool_t, default='False', help='Use memory efficient attention')
|
||||||
parser.add_argument('--wandb', dest='enablewandb', type=str, default='True', help='Enable WeightsAndBiases Reporting')
|
parser.add_argument('--wandb', dest='enablewandb', type=bool_t, default='True', help='Enable WeightsAndBiases Reporting')
|
||||||
parser.add_argument('--inference', dest='enableinference', type=str, default='True', help='Enable Inference during training (Consumes 2GB of VRAM)')
|
parser.add_argument('--inference', dest='enableinference', type=bool_t, default='True', help='Enable Inference during training (Consumes 2GB of VRAM)')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
def setup():
|
def setup():
|
||||||
|
@ -523,10 +523,11 @@ def main():
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
os.makedirs(args.output_path, exist_ok=True)
|
os.makedirs(args.output_path, exist_ok=True)
|
||||||
|
|
||||||
|
mode = 'enabled'
|
||||||
if args.enablewandb:
|
if args.enablewandb:
|
||||||
run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb')
|
mode = 'disabled'
|
||||||
else:
|
|
||||||
run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb', mode="disabled")
|
run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb', mode=mode)
|
||||||
|
|
||||||
# Inform the user of host, and various versions -- useful for debugging issues.
|
# Inform the user of host, and various versions -- useful for debugging issues.
|
||||||
print("RUN_NAME:", args.run_name)
|
print("RUN_NAME:", args.run_name)
|
||||||
|
@ -539,10 +540,13 @@ def main():
|
||||||
print("FP16:", args.fp16)
|
print("FP16:", args.fp16)
|
||||||
print("RESOLUTION:", args.resolution)
|
print("RESOLUTION:", args.resolution)
|
||||||
|
|
||||||
if args.hf_token is None:
|
|
||||||
|
if args.hf_token is not None:
|
||||||
|
print('It is recommended to set the HF_API_TOKEN environment variable instead of passing it as a command line argument since WandB will automatically log it.')
|
||||||
|
else:
|
||||||
try:
|
try:
|
||||||
args.hf_token = os.environ['HF_API_TOKEN']
|
args.hf_token = os.environ['HF_API_TOKEN']
|
||||||
print('It is recommended to set the HF_API_TOKEN environment variable instead of passing it as a command line argument since WandB will automatically log it.')
|
print("HF Token set via enviroment variable")
|
||||||
except Exception:
|
except Exception:
|
||||||
print("No HF Token detected in arguments or enviroment variable, setting it to none (as in string)")
|
print("No HF Token detected in arguments or enviroment variable, setting it to none (as in string)")
|
||||||
args.hf_token = "none"
|
args.hf_token = "none"
|
||||||
|
|
Loading…
Reference in New Issue