Merge pull request #48 from damian0815/feat_cli_args_override_json_file

CLI args override values in JSON file
Victor Hall 2023-02-07 20:31:21 -05:00 committed by GitHub
commit 7f6145098a
1 changed file with 50 additions and 76 deletions


@@ -965,48 +965,23 @@ def main(args):
     logging.info(f"{Fore.LIGHTWHITE_EX} **** Finished training ****{Style.RESET_ALL}")
     logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
-def update_old_args(t_args):
-    """
-    Update old args to new args to deal with json config loading and missing args for compatibility
-    """
-    if not hasattr(t_args, "shuffle_tags"):
-        print(f" Config json is missing 'shuffle_tags' flag")
-        t_args.__dict__["shuffle_tags"] = False
-    if not hasattr(t_args, "save_full_precision"):
-        print(f" Config json is missing 'save_full_precision' flag")
-        t_args.__dict__["save_full_precision"] = False
-    if not hasattr(t_args, "notebook"):
-        print(f" Config json is missing 'notebook' flag")
-        t_args.__dict__["notebook"] = False
-    if not hasattr(t_args, "disable_unet_training"):
-        print(f" Config json is missing 'disable_unet_training' flag")
-        t_args.__dict__["disable_unet_training"] = False
-    if not hasattr(t_args, "rated_dataset"):
-        print(f" Config json is missing 'rated_dataset' flag")
-        t_args.__dict__["rated_dataset"] = False
-    if not hasattr(t_args, "rated_dataset_target_dropout_percent"):
-        print(f" Config json is missing 'rated_dataset_target_dropout_percent' flag")
-        t_args.__dict__["rated_dataset_target_dropout_percent"] = 50
-    if not hasattr(t_args, "validation_config"):
-        print(f" Config json is missing 'validation_config'")
-        t_args.__dict__["validation_config"] = None
 if __name__ == "__main__":
     supported_resolutions = [256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
     supported_precisions = ['fp16', 'fp32']
     argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
     argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
-    args, _ = argparser.parse_known_args()
+    args, argv = argparser.parse_known_args()
     if args.config is not None:
-        print(f"Loading training config from {args.config}, all other command options will be ignored!")
+        print(f"Loading training config from {args.config}.")
         with open(args.config, 'rt') as f:
-            t_args = argparse.Namespace()
-            t_args.__dict__.update(json.load(f))
-            update_old_args(t_args) # update args to support older configs
-            args = argparser.parse_args(namespace=t_args)
+            args.__dict__.update(json.load(f))
+        if len(argv) > 0:
+            print(f"Config .json loaded but there are additional CLI arguments -- these will override values in {args.config}.")
     else:
         print("No config file specified, using command line args")
     argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
     argparser.add_argument("--amp", action="store_true", default=False, help="Enables automatic mixed precision compute, recommended on")
     argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
@@ -1034,7 +1009,7 @@ if __name__ == "__main__":
     argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
     argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
     argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
-    argparser.add_argument("--resume_ckpt", type=str, required=True, default="sd_v1-5_vae.ckpt", help="The checkpoint to resume from, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1 ")
+    argparser.add_argument("--resume_ckpt", type=str, required=not ('resume_ckpt' in args), default="sd_v1-5_vae.ckpt", help="The checkpoint to resume from, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1 ")
     argparser.add_argument("--sample_prompts", type=str, default="sample_prompts.txt", help="File with prompts to generate test samples from (def: sample_prompts.txt)")
     argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)")
     argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
@@ -1045,14 +1020,13 @@ if __name__ == "__main__":
     argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
     argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
     argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
+    argparser.add_argument("--validation_config", type=str, default="validation_default.json", help="validation config file (def: validation_config.json)")
     argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
     argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
     argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
     argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
-    args, _ = argparser.parse_known_args()
+    # load CLI args to overwrite existing config args
+    args = argparser.parse_args(args=argv, namespace=args)
     print(f" Args:")
-    pprint.pprint(args.__dict__)
+    pprint.pprint(vars(args))
     main(args)
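
The override works because argparse only applies a default (or requires an option) when the destination attribute is not already present in the namespace it is handed, so values loaded from the JSON config survive the final parse unless the same option is passed again on the command line. A minimal standalone sketch of the pattern follows; the option names, train.json, and sketch.py below are illustrative and not the full EveryDream2 argument set:

import argparse
import json

# Pre-parser that only knows about --config; parse_known_args() returns the
# remaining CLI tokens in argv so they can be re-parsed later.
pre_parser = argparse.ArgumentParser(add_help=False)
pre_parser.add_argument("--config", type=str, default=None)
args, argv = pre_parser.parse_known_args()

if args.config is not None:
    # JSON values land in the namespace first.
    with open(args.config, "rt") as f:
        args.__dict__.update(json.load(f))

# Full parser (illustrative subset of options).
parser = argparse.ArgumentParser(description="training options (sketch)")
parser.add_argument("--batch_size", type=int, default=2)
# Only insist on --resume_ckpt if the config did not already provide it.
parser.add_argument("--resume_ckpt", type=str, required=not ("resume_ckpt" in args), default="sd_v1-5_vae.ckpt")

# Re-parse only the leftover CLI tokens into the same namespace: anything given
# on the command line overwrites the JSON value, anything absent keeps the JSON
# value, and argparse defaults fill in only attributes that are still missing.
args = parser.parse_args(args=argv, namespace=args)
print(vars(args))

With a train.json containing {"batch_size": 8, "resume_ckpt": "sd_v1-5_vae.ckpt"}, running python sketch.py --config train.json --batch_size 4 ends up with batch_size 4 (the CLI wins) while resume_ckpt keeps its JSON value; without the extra flag, batch_size stays 8. Before this change, passing --config meant all other command-line options were ignored.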