From ed025d27b639f909c42d00b66bd53ed34cc5f297 Mon Sep 17 00:00:00 2001
From: Victor Hall
Date: Mon, 16 Jan 2023 19:11:41 -0500
Subject: [PATCH] enable xformers for sd1 models if amp enabled

---
 train.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/train.py b/train.py
index a5de0b3..872a406 100644
--- a/train.py
+++ b/train.py
@@ -230,9 +230,9 @@ def setup_args(args):
         # find the last checkpoint in the logdir
         args.resume_ckpt = find_last_checkpoint(args.logdir)
 
-    if args.ed1_mode and not args.disable_xformers:
+    if args.ed1_mode and not args.amp and not args.disable_xformers:
         args.disable_xformers = True
-        logging.info(" ED1 mode: Overiding disable_xformers to True")
+        logging.info(" ED1 mode without amp: Overriding disable_xformers to True")
 
     if args.lowvram:
         set_args_12gb(args)
@@ -909,14 +909,13 @@ if __name__ == "__main__":
         with open(args.config, 'rt') as f:
             t_args = argparse.Namespace()
             t_args.__dict__.update(json.load(f))
-            print(t_args.__dict__)
             update_old_args(t_args) # update args to support older configs
-            print(t_args.__dict__)
+            print(f" args: \n{t_args.__dict__}")
             args = argparser.parse_args(namespace=t_args)
     else:
         print("No config file specified, using command line args")
         argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
-        argparser.add_argument("--amp", action="store_true", default=False, help="use floating point 16 bit training, experimental, reduces quality")
+        argparser.add_argument("--amp", action="store_true", default=False, help="Enables automatic mixed precision compute, recommended on")
         argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
         argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20")
         argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
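
Note: below is a minimal standalone sketch (not part of the patch) of the flag
resolution this change produces. The condition and log message are taken from
the setup_args hunk above; the helper name resolve_xformers and the demo
namespaces are hypothetical, added only to illustrate the new behavior.

    # Sketch of the ed1_mode / amp / disable_xformers interaction after this patch.
    import argparse
    import logging

    def resolve_xformers(args: argparse.Namespace) -> argparse.Namespace:
        # After this patch, ED1 mode only forces xformers off when AMP is also
        # disabled; ED1 mode with --amp leaves xformers available for SD1 models.
        if args.ed1_mode and not args.amp and not args.disable_xformers:
            args.disable_xformers = True
            logging.info(" ED1 mode without amp: Overriding disable_xformers to True")
        return args

    if __name__ == "__main__":
        # ED1 model with AMP enabled: xformers is no longer force-disabled.
        ns = argparse.Namespace(ed1_mode=True, amp=True, disable_xformers=False)
        print(resolve_xformers(ns).disable_xformers)  # -> False

        # ED1 model without AMP: behavior is unchanged, xformers stays disabled.
        ns = argparse.Namespace(ed1_mode=True, amp=False, disable_xformers=False)
        print(resolve_xformers(ns).disable_xformers)  # -> True

The first case prints False (xformers stays enabled) and the second prints
True, matching the subject line: xformers is enabled for SD1 models only when
AMP is on.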