Merge pull request #48 from damian0815/feat_cli_args_override_json_file
CLI args override values in JSON file
This commit is contained in:
commit
7f6145098a
46
train.py
46
train.py
|
@ -965,48 +965,23 @@ def main(args):
|
||||||
logging.info(f"{Fore.LIGHTWHITE_EX} **** Finished training ****{Style.RESET_ALL}")
|
logging.info(f"{Fore.LIGHTWHITE_EX} **** Finished training ****{Style.RESET_ALL}")
|
||||||
logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
|
logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
|
||||||
|
|
||||||
def update_old_args(t_args):
|
|
||||||
"""
|
|
||||||
Update old args to new args to deal with json config loading and missing args for compatibility
|
|
||||||
"""
|
|
||||||
if not hasattr(t_args, "shuffle_tags"):
|
|
||||||
print(f" Config json is missing 'shuffle_tags' flag")
|
|
||||||
t_args.__dict__["shuffle_tags"] = False
|
|
||||||
if not hasattr(t_args, "save_full_precision"):
|
|
||||||
print(f" Config json is missing 'save_full_precision' flag")
|
|
||||||
t_args.__dict__["save_full_precision"] = False
|
|
||||||
if not hasattr(t_args, "notebook"):
|
|
||||||
print(f" Config json is missing 'notebook' flag")
|
|
||||||
t_args.__dict__["notebook"] = False
|
|
||||||
if not hasattr(t_args, "disable_unet_training"):
|
|
||||||
print(f" Config json is missing 'disable_unet_training' flag")
|
|
||||||
t_args.__dict__["disable_unet_training"] = False
|
|
||||||
if not hasattr(t_args, "rated_dataset"):
|
|
||||||
print(f" Config json is missing 'rated_dataset' flag")
|
|
||||||
t_args.__dict__["rated_dataset"] = False
|
|
||||||
if not hasattr(t_args, "rated_dataset_target_dropout_percent"):
|
|
||||||
print(f" Config json is missing 'rated_dataset_target_dropout_percent' flag")
|
|
||||||
t_args.__dict__["rated_dataset_target_dropout_percent"] = 50
|
|
||||||
if not hasattr(t_args, "validation_config"):
|
|
||||||
print(f" Config json is missing 'validation_config'")
|
|
||||||
t_args.__dict__["validation_config"] = None
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
supported_resolutions = [256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
|
supported_resolutions = [256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
|
||||||
supported_precisions = ['fp16', 'fp32']
|
supported_precisions = ['fp16', 'fp32']
|
||||||
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
|
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
|
||||||
argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
|
argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
|
||||||
args, _ = argparser.parse_known_args()
|
args, argv = argparser.parse_known_args()
|
||||||
|
|
||||||
if args.config is not None:
|
if args.config is not None:
|
||||||
print(f"Loading training config from {args.config}, all other command options will be ignored!")
|
print(f"Loading training config from {args.config}.")
|
||||||
with open(args.config, 'rt') as f:
|
with open(args.config, 'rt') as f:
|
||||||
t_args = argparse.Namespace()
|
args.__dict__.update(json.load(f))
|
||||||
t_args.__dict__.update(json.load(f))
|
if len(argv) > 0:
|
||||||
update_old_args(t_args) # update args to support older configs
|
print(f"Config .json loaded but there are additional CLI arguments -- these will override values in {args.config}.")
|
||||||
args = argparser.parse_args(namespace=t_args)
|
|
||||||
else:
|
else:
|
||||||
print("No config file specified, using command line args")
|
print("No config file specified, using command line args")
|
||||||
|
|
||||||
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
|
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
|
||||||
argparser.add_argument("--amp", action="store_true", default=False, help="Enables automatic mixed precision compute, recommended on")
|
argparser.add_argument("--amp", action="store_true", default=False, help="Enables automatic mixed precision compute, recommended on")
|
||||||
argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
|
argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
|
||||||
|
@ -1034,7 +1009,7 @@ if __name__ == "__main__":
|
||||||
argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
|
argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
|
||||||
argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
|
argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
|
||||||
argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
|
argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
|
||||||
argparser.add_argument("--resume_ckpt", type=str, required=True, default="sd_v1-5_vae.ckpt", help="The checkpoint to resume from, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1 ")
|
argparser.add_argument("--resume_ckpt", type=str, required=not ('resume_ckpt' in args), default="sd_v1-5_vae.ckpt", help="The checkpoint to resume from, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1 ")
|
||||||
argparser.add_argument("--sample_prompts", type=str, default="sample_prompts.txt", help="File with prompts to generate test samples from (def: sample_prompts.txt)")
|
argparser.add_argument("--sample_prompts", type=str, default="sample_prompts.txt", help="File with prompts to generate test samples from (def: sample_prompts.txt)")
|
||||||
argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)")
|
argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)")
|
||||||
argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
|
argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
|
||||||
|
@ -1045,14 +1020,13 @@ if __name__ == "__main__":
|
||||||
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
|
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
|
||||||
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
|
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
|
||||||
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
|
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
|
||||||
argparser.add_argument("--validation_config", type=str, default="validation_default.json", help="validation config file (def: validation_config.json)")
|
|
||||||
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
|
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
|
||||||
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
|
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
|
||||||
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
|
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
|
||||||
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
|
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
|
||||||
|
|
||||||
args, _ = argparser.parse_known_args()
|
# load CLI args to overwrite existing config args
|
||||||
|
args = argparser.parse_args(args=argv, namespace=args)
|
||||||
print(f" Args:")
|
print(f" Args:")
|
||||||
pprint.pprint(args.__dict__)
|
pprint.pprint(vars(args))
|
||||||
main(args)
|
main(args)
|
||||||
|
|
Loading…
Reference in New Issue