enable xformers for sd1 models if amp enabled
This commit is contained in:
parent
879f5bf33d
commit
ed025d27b6
9
train.py
9
train.py
|
@ -230,9 +230,9 @@ def setup_args(args):
|
|||
# find the last checkpoint in the logdir
|
||||
args.resume_ckpt = find_last_checkpoint(args.logdir)
|
||||
|
||||
if args.ed1_mode and not args.disable_xformers:
|
||||
if args.ed1_mode and not args.amp and not args.disable_xformers:
|
||||
args.disable_xformers = True
|
||||
logging.info(" ED1 mode: Overiding disable_xformers to True")
|
||||
logging.info(" ED1 mode without amp: Overiding disable_xformers to True")
|
||||
|
||||
if args.lowvram:
|
||||
set_args_12gb(args)
|
||||
|
@ -909,14 +909,13 @@ if __name__ == "__main__":
|
|||
with open(args.config, 'rt') as f:
|
||||
t_args = argparse.Namespace()
|
||||
t_args.__dict__.update(json.load(f))
|
||||
print(t_args.__dict__)
|
||||
update_old_args(t_args) # update args to support older configs
|
||||
print(t_args.__dict__)
|
||||
print(f" args: \n{t_args.__dict__}")
|
||||
args = argparser.parse_args(namespace=t_args)
|
||||
else:
|
||||
print("No config file specified, using command line args")
|
||||
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
|
||||
argparser.add_argument("--amp", action="store_true", default=False, help="use floating point 16 bit training, experimental, reduces quality")
|
||||
argparser.add_argument("--amp", action="store_true", default=False, help="Enables automatic mixed precision compute, recommended on")
|
||||
argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
|
||||
argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20")
|
||||
argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
|
||||
|
|
Loading…
Reference in New Issue