diff --git a/train.py b/train.py
index 4d06ce1..1949fc8 100644
--- a/train.py
+++ b/train.py
@@ -178,12 +178,12 @@ def set_args_12gb(args):
     if not args.gradient_checkpointing:
         logging.info(" - Overiding gradient checkpointing to True")
         args.gradient_checkpointing = True
-    if args.batch_size != 1:
-        logging.info(" - Overiding batch size to 1")
-        args.batch_size = 1
+    if args.batch_size > 2:
+        logging.info(" - Overriding batch size to max 2")
+        args.batch_size = 2
         args.grad_accum = 1
     if args.resolution > 512:
-        logging.info(" - Overiding resolution to 512")
+        logging.info(" - Overriding resolution to max 512")
         args.resolution = 512
 
 def find_last_checkpoint(logdir):
@@ -214,6 +214,12 @@ def setup_args(args):
     Sets defaults for missing args (possible if missing from json config)
     Forces some args to be set based on others for compatibility reasons
     """
+    if args.disable_amp:
+        logging.warning(f"{Fore.LIGHTYELLOW_EX} Disabling AMP, not recommended.{Style.RESET_ALL}")
+        args.amp = False
+    else:
+        args.amp = True
+
     if args.disable_unet_training and args.disable_textenc_training:
         raise ValueError("Both unet and textenc are disabled, nothing to train")
 
@@ -939,13 +945,14 @@ if __name__ == "__main__":
         print("No config file specified, using command line args")
         argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
 
-    argparser.add_argument("--amp", action="store_true", default=False, help="Enables automatic mixed precision compute, recommended on")
+    argparser.add_argument("--amp", action="store_true", default=False, help="deprecated, use --disable_amp if you wish to disable AMP")
    argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
    argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20")
    argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
    argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3, 4])
    argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
    argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are")
+    argparser.add_argument("--disable_amp", action="store_true", default=False, help="disables automatic mixed precision training (def: False)")
    argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False)")
    argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED")
    argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
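
Note on the new flag semantics (a minimal standalone sketch, not part of the diff): after this change AMP is enabled by default and is only turned off via --disable_amp, while the old --amp flag remains as a deprecated no-op. The stripped-down parser below is a hypothetical stand-in for the full EveryDream2 argparser, mirroring the setup_args() logic added above.

# Sketch: AMP resolution under the new flags (assumed standalone example)
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--amp", action="store_true", default=False)          # deprecated, ignored
parser.add_argument("--disable_amp", action="store_true", default=False)  # new opt-out flag

args = parser.parse_args([])         # no flags passed
args.amp = not args.disable_amp      # mirrors the new setup_args() logic
assert args.amp                      # AMP is on by default

args = parser.parse_args(["--disable_amp"])
args.amp = not args.disable_amp
assert not args.amp                  # AMP is off only when explicitly disabled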