enable xformers for sd1 models if amp enabled
This commit is contained in:
parent
879f5bf33d
commit
ed025d27b6
9
train.py
9
train.py
|
@ -230,9 +230,9 @@ def setup_args(args):
|
||||||
# find the last checkpoint in the logdir
|
# find the last checkpoint in the logdir
|
||||||
args.resume_ckpt = find_last_checkpoint(args.logdir)
|
args.resume_ckpt = find_last_checkpoint(args.logdir)
|
||||||
|
|
||||||
if args.ed1_mode and not args.disable_xformers:
|
if args.ed1_mode and not args.amp and not args.disable_xformers:
|
||||||
args.disable_xformers = True
|
args.disable_xformers = True
|
||||||
logging.info(" ED1 mode: Overiding disable_xformers to True")
|
logging.info(" ED1 mode without amp: Overiding disable_xformers to True")
|
||||||
|
|
||||||
if args.lowvram:
|
if args.lowvram:
|
||||||
set_args_12gb(args)
|
set_args_12gb(args)
|
||||||
|
@ -909,14 +909,13 @@ if __name__ == "__main__":
|
||||||
with open(args.config, 'rt') as f:
|
with open(args.config, 'rt') as f:
|
||||||
t_args = argparse.Namespace()
|
t_args = argparse.Namespace()
|
||||||
t_args.__dict__.update(json.load(f))
|
t_args.__dict__.update(json.load(f))
|
||||||
print(t_args.__dict__)
|
|
||||||
update_old_args(t_args) # update args to support older configs
|
update_old_args(t_args) # update args to support older configs
|
||||||
print(t_args.__dict__)
|
print(f" args: \n{t_args.__dict__}")
|
||||||
args = argparser.parse_args(namespace=t_args)
|
args = argparser.parse_args(namespace=t_args)
|
||||||
else:
|
else:
|
||||||
print("No config file specified, using command line args")
|
print("No config file specified, using command line args")
|
||||||
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
|
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
|
||||||
argparser.add_argument("--amp", action="store_true", default=False, help="use floating point 16 bit training, experimental, reduces quality")
|
argparser.add_argument("--amp", action="store_true", default=False, help="Enables automatic mixed precision compute, recommended on")
|
||||||
argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
|
argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
|
||||||
argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20")
|
argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20")
|
||||||
argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
|
argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
|
||||||
|
|
Loading…
Reference in New Issue