print args after cleaning, set attn slicing for sd15 if not using amp

Victor Hall 2023-04-16 18:48:44 -04:00
parent e3e30a5599
commit 743c7cccae
1 changed file with 8 additions and 4 deletions

@@ -368,6 +368,8 @@ def main(args):
     """
     log_time = setup_local_logger(args)
     args = setup_args(args)
+    print(f" Args:")
+    pprint.pprint(vars(args))
     if args.notebook:
         from tqdm.notebook import tqdm
@@ -484,8 +486,11 @@ def main(args):
                 logging.warning("failed to load xformers, using attention slicing instead")
                 unet.set_attention_slice("auto")
                 pass
+        elif (not args.amp and is_sd1attn):
+            logging.info("AMP is disabled but model is SD1.X, using attention slicing instead of xformers")
+            unet.set_attention_slice("auto")
     else:
-        logging.info("xformers disabled, using attention slicing instead")
+        logging.info("xformers disabled via arg, using attention slicing instead")
         unet.set_attention_slice("auto")
     vae = vae.to(device, dtype=torch.float16 if args.amp else torch.float32)
@@ -967,7 +972,7 @@ if __name__ == "__main__":
         print("No config file specified, using command line args")
     argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
-    argparser.add_argument("--amp", action="store_true", default=False, help="deprecated, use --disable_amp if you wish to disable AMP")
+    argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP")
     argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
     argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20")
     argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
@@ -1016,6 +1021,5 @@ if __name__ == "__main__":
     # load CLI args to overwrite existing config args
     args = argparser.parse_args(args=argv, namespace=args)
-    print(f" Args:")
-    pprint.pprint(vars(args))
     main(args)
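
For readers skimming the second hunk, the attention-backend choice it produces reads roughly as follows. This is a hedged sketch rather than the verbatim train.py code: the standalone function name select_attention_backend is illustrative only, while args, unet, and is_sd1attn are the objects already present in the surrounding training code, and unet.enable_xformers_memory_efficient_attention() / unet.set_attention_slice("auto") are the diffusers calls used in the diff.

import logging

def select_attention_backend(unet, args, is_sd1attn: bool) -> None:
    """Sketch of the xformers / attention-slicing choice after this commit."""
    if not args.disable_xformers:
        if (args.amp and is_sd1attn) or (not is_sd1attn):
            try:
                # xformers is attempted when AMP is on for SD1.x attention,
                # or when the model does not use SD1.x attention at all
                unet.enable_xformers_memory_efficient_attention()
                logging.info("Enabled xformers")
            except Exception:
                logging.warning("failed to load xformers, using attention slicing instead")
                unet.set_attention_slice("auto")
        elif not args.amp and is_sd1attn:
            # branch added by this commit: SD1.x models without AMP skip xformers
            logging.info("AMP is disabled but model is SD1.X, using attention slicing instead of xformers")
            unet.set_attention_slice("auto")
    else:
        logging.info("xformers disabled via arg, using attention slicing instead")
        unet.set_attention_slice("auto")

In effect, with the deprecated --amp flag now defaulting to True (third hunk), the new slicing branch presumably only triggers for SD1.x models when AMP has been explicitly turned off, or when xformers is disabled or fails to load.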