defaulting amp to on now
This commit is contained in:
parent
c829521c34
commit
9389a90c67
17
train.py
17
train.py
|
@ -178,12 +178,12 @@ def set_args_12gb(args):
|
|||
if not args.gradient_checkpointing:
|
||||
logging.info(" - Overiding gradient checkpointing to True")
|
||||
args.gradient_checkpointing = True
|
||||
if args.batch_size != 1:
|
||||
logging.info(" - Overiding batch size to 1")
|
||||
args.batch_size = 1
|
||||
if args.batch_size > 2:
|
||||
logging.info(" - Overiding batch size to max 2")
|
||||
args.batch_size = 2
|
||||
args.grad_accum = 1
|
||||
if args.resolution > 512:
|
||||
logging.info(" - Overiding resolution to 512")
|
||||
logging.info(" - Overiding resolution to max 512")
|
||||
args.resolution = 512
|
||||
|
||||
def find_last_checkpoint(logdir):
|
||||
|
@ -214,6 +214,12 @@ def setup_args(args):
|
|||
Sets defaults for missing args (possible if missing from json config)
|
||||
Forces some args to be set based on others for compatibility reasons
|
||||
"""
|
||||
if args.disable_amp:
|
||||
logging.warning(f"{Fore.LIGHTYELLOW_EX} Disabling AMP, not recommended.{Style.RESET_ALL}")
|
||||
args.amp= False
|
||||
else:
|
||||
args.amp = True
|
||||
|
||||
if args.disable_unet_training and args.disable_textenc_training:
|
||||
raise ValueError("Both unet and textenc are disabled, nothing to train")
|
||||
|
||||
|
@ -939,13 +945,14 @@ if __name__ == "__main__":
|
|||
print("No config file specified, using command line args")
|
||||
|
||||
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
|
||||
argparser.add_argument("--amp", action="store_true", default=False, help="Enables automatic mixed precision compute, recommended on")
|
||||
argparser.add_argument("--amp", action="store_true", default=False, help="deprecated, use --disable_amp if you wish to disable AMP")
|
||||
argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
|
||||
argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20")
|
||||
argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
|
||||
argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3, 4])
|
||||
argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
|
||||
argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are")
|
||||
argparser.add_argument("--disable_amp", action="store_true", default=False, help="disables training of text encoder (def: False)")
|
||||
argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False)")
|
||||
argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED")
|
||||
argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
|
||||
|
|
Loading…
Reference in New Issue