From 743c7cccaed3a2f53d2c318daa8956d999cabf6f Mon Sep 17 00:00:00 2001 From: Victor Hall Date: Sun, 16 Apr 2023 18:48:44 -0400 Subject: [PATCH] print args after cleaning, set attn slicing for sd15 if not using amp --- train.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/train.py b/train.py index f00cbf5..b0a6ca8 100644 --- a/train.py +++ b/train.py @@ -368,6 +368,8 @@ def main(args): """ log_time = setup_local_logger(args) args = setup_args(args) + print(f" Args:") + pprint.pprint(vars(args)) if args.notebook: from tqdm.notebook import tqdm @@ -484,8 +486,11 @@ def main(args): logging.warning("failed to load xformers, using attention slicing instead") unet.set_attention_slice("auto") pass + elif (not args.amp and is_sd1attn): + logging.info("AMP is disabled but model is SD1.X, using attention slicing instead of xformers") + unet.set_attention_slice("auto") else: - logging.info("xformers disabled, using attention slicing instead") + logging.info("xformers disabled via arg, using attention slicing instead") unet.set_attention_slice("auto") vae = vae.to(device, dtype=torch.float16 if args.amp else torch.float32) @@ -967,7 +972,7 @@ if __name__ == "__main__": print("No config file specified, using command line args") argparser = argparse.ArgumentParser(description="EveryDream2 Training options") - argparser.add_argument("--amp", action="store_true", default=False, help="deprecated, use --disable_amp if you wish to disable AMP") + argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP") argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)") argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20") argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?") @@ -1016,6 +1021,5 @@ if __name__ == "__main__": # load CLI args to overwrite existing config args args = argparser.parse_args(args=argv, namespace=args) - print(f" Args:") - pprint.pprint(vars(args)) + main(args)