defaulting amp to on now

Victor Hall 2023-03-10 21:35:47 -05:00
parent c829521c34
commit 9389a90c67
1 changed file with 12 additions and 5 deletions


@@ -178,12 +178,12 @@ def set_args_12gb(args):
     if not args.gradient_checkpointing:
         logging.info(" - Overriding gradient checkpointing to True")
         args.gradient_checkpointing = True
-    if args.batch_size != 1:
-        logging.info(" - Overriding batch size to 1")
-        args.batch_size = 1
+    if args.batch_size > 2:
+        logging.info(" - Overriding batch size to max 2")
+        args.batch_size = 2
     args.grad_accum = 1
     if args.resolution > 512:
-        logging.info(" - Overriding resolution to 512")
+        logging.info(" - Overriding resolution to max 512")
         args.resolution = 512

 def find_last_checkpoint(logdir):
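For reference, a minimal sketch of the new 12 GB preset behavior, assuming set_args_12gb is in scope (e.g. imported from the training script) and using a hypothetical Namespace in place of the real parsed args: oversized values are now clamped instead of batch size being forced to 1.

import argparse

# Hypothetical args object; the real one comes from argparse / the json config.
args = argparse.Namespace(gradient_checkpointing=False, batch_size=4,
                          grad_accum=4, resolution=768)

set_args_12gb(args)  # the function patched in the hunk above

# After this commit: gradient checkpointing is forced on, batch size is
# capped at 2, grad_accum is forced to 1, and resolution is capped at 512.
print(args.batch_size, args.grad_accum, args.resolution)  # -> 2 1 512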
@@ -214,6 +214,12 @@ def setup_args(args):
     Sets defaults for missing args (possible if missing from json config)
     Forces some args to be set based on others for compatibility reasons
     """
+    if args.disable_amp:
+        logging.warning(f"{Fore.LIGHTYELLOW_EX} Disabling AMP, not recommended.{Style.RESET_ALL}")
+        args.amp = False
+    else:
+        args.amp = True
+
     if args.disable_unet_training and args.disable_textenc_training:
         raise ValueError("Both unet and textenc are disabled, nothing to train")
@ -939,13 +945,14 @@ if __name__ == "__main__":
print("No config file specified, using command line args")
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
argparser.add_argument("--amp", action="store_true", default=False, help="Enables automatic mixed precision compute, recommended on")
argparser.add_argument("--amp", action="store_true", default=False, help="deprecated, use --disable_amp if you wish to disable AMP")
argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20")
argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3, 4])
argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are")
argparser.add_argument("--disable_amp", action="store_true", default=False, help="disables training of text encoder (def: False)")
argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False)")
argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED")
argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")