enable xformers for sd1 models if amp enabled

This commit is contained in:
Victor Hall 2023-01-16 19:11:41 -05:00
parent 879f5bf33d
commit ed025d27b6
1 changed files with 4 additions and 5 deletions

View File

@ -230,9 +230,9 @@ def setup_args(args):
# find the last checkpoint in the logdir # find the last checkpoint in the logdir
args.resume_ckpt = find_last_checkpoint(args.logdir) args.resume_ckpt = find_last_checkpoint(args.logdir)
if args.ed1_mode and not args.disable_xformers: if args.ed1_mode and not args.amp and not args.disable_xformers:
args.disable_xformers = True args.disable_xformers = True
logging.info(" ED1 mode: Overiding disable_xformers to True") logging.info(" ED1 mode without amp: Overiding disable_xformers to True")
if args.lowvram: if args.lowvram:
set_args_12gb(args) set_args_12gb(args)
@ -909,14 +909,13 @@ if __name__ == "__main__":
with open(args.config, 'rt') as f: with open(args.config, 'rt') as f:
t_args = argparse.Namespace() t_args = argparse.Namespace()
t_args.__dict__.update(json.load(f)) t_args.__dict__.update(json.load(f))
print(t_args.__dict__)
update_old_args(t_args) # update args to support older configs update_old_args(t_args) # update args to support older configs
print(t_args.__dict__) print(f" args: \n{t_args.__dict__}")
args = argparser.parse_args(namespace=t_args) args = argparser.parse_args(namespace=t_args)
else: else:
print("No config file specified, using command line args") print("No config file specified, using command line args")
argparser = argparse.ArgumentParser(description="EveryDream2 Training options") argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
argparser.add_argument("--amp", action="store_true", default=False, help="use floating point 16 bit training, experimental, reduces quality") argparser.add_argument("--amp", action="store_true", default=False, help="Enables automatic mixed precision compute, recommended on")
argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)") argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20") argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20")
argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?") argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")