print args after cleaning, set attn slicing for sd15 if not using amp
This commit is contained in:
parent
e3e30a5599
commit
743c7cccae
12
train.py
12
train.py
|
@ -368,6 +368,8 @@ def main(args):
|
|||
"""
|
||||
log_time = setup_local_logger(args)
|
||||
args = setup_args(args)
|
||||
print(f" Args:")
|
||||
pprint.pprint(vars(args))
|
||||
|
||||
if args.notebook:
|
||||
from tqdm.notebook import tqdm
|
||||
|
@ -484,8 +486,11 @@ def main(args):
|
|||
logging.warning("failed to load xformers, using attention slicing instead")
|
||||
unet.set_attention_slice("auto")
|
||||
pass
|
||||
elif (not args.amp and is_sd1attn):
|
||||
logging.info("AMP is disabled but model is SD1.X, using attention slicing instead of xformers")
|
||||
unet.set_attention_slice("auto")
|
||||
else:
|
||||
logging.info("xformers disabled, using attention slicing instead")
|
||||
logging.info("xformers disabled via arg, using attention slicing instead")
|
||||
unet.set_attention_slice("auto")
|
||||
|
||||
vae = vae.to(device, dtype=torch.float16 if args.amp else torch.float32)
|
||||
|
@ -967,7 +972,7 @@ if __name__ == "__main__":
|
|||
print("No config file specified, using command line args")
|
||||
|
||||
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
|
||||
argparser.add_argument("--amp", action="store_true", default=False, help="deprecated, use --disable_amp if you wish to disable AMP")
|
||||
argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP")
|
||||
argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
|
||||
argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20")
|
||||
argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
|
||||
|
@ -1016,6 +1021,5 @@ if __name__ == "__main__":
|
|||
|
||||
# load CLI args to overwrite existing config args
|
||||
args = argparser.parse_args(args=argv, namespace=args)
|
||||
print(f" Args:")
|
||||
pprint.pprint(vars(args))
|
||||
|
||||
main(args)
|
||||
|
|
Loading…
Reference in New Issue