From 927880e1fce9e8b27a480c548b7197bd5b2f8f56 Mon Sep 17 00:00:00 2001
From: Damian Stewart
Date: Mon, 6 Feb 2023 07:40:59 +0100
Subject: [PATCH 1/2] allow cli args to override config values

---
 train.py | 101 ++++++++++++++++++++++++++++---------------------------
 1 file changed, 51 insertions(+), 50 deletions(-)

diff --git a/train.py b/train.py
index e11ab99..df67761 100644
--- a/train.py
+++ b/train.py
@@ -995,62 +995,63 @@ if __name__ == "__main__":
     supported_precisions = ['fp16', 'fp32']
     argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
     argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
-    args, _ = argparser.parse_known_args()
+    args, argv = argparser.parse_known_args()

     if args.config is not None:
-        print(f"Loading training config from {args.config}, all other command options will be ignored!")
+        print(f"Loading training config from {args.config}.")
         with open(args.config, 'rt') as f:
-            t_args = argparse.Namespace()
-            t_args.__dict__.update(json.load(f))
-            update_old_args(t_args) # update args to support older configs
-            args = argparser.parse_args(namespace=t_args)
+            args.__dict__.update(json.load(f))
+            update_old_args(args) # update args to support older configs
+        if len(argv) > 0:
+            print(f"Config .json loaded but there are additional CLI arguments -- these will override values in {args.config}.")
     else:
         print("No config file specified, using command line args")
-        argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
-        argparser.add_argument("--amp", action="store_true", default=False, help="Enables automatic mixed precision compute, recommended on")
-        argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
-        argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20")
-        argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
-        argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3, 4])
-        argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
-        argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are")
-        argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False)")
-        argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED")
-        argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
-        argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
-        argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1)")
-        argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)")
-        argparser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation factor (def: 1), (ex, 2)")
argparser.add_argument("--hf_repo_subfolder", type=str, default=None, help="Subfolder inside the huggingface repo to download, if the model is not in the root of the repo.") - argparser.add_argument("--logdir", type=str, default="logs", help="folder to save logs to (def: logs)") - argparser.add_argument("--log_step", type=int, default=25, help="How often to log training stats, def: 25, recommend default!") - argparser.add_argument("--lowvram", action="store_true", default=False, help="automatically overrides various args to support 12GB gpu") - argparser.add_argument("--lr", type=float, default=None, help="Learning rate, if using scheduler is maximum LR at top of curve") - argparser.add_argument("--lr_decay_steps", type=int, default=0, help="Steps to reach minimum LR, default: automatically set") - argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"]) - argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant") - argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for") - argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)") - argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'") - argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions) - argparser.add_argument("--resume_ckpt", type=str, required=True, default="sd_v1-5_vae.ckpt", help="The checkpoint to resume from, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1 ") - argparser.add_argument("--sample_prompts", type=str, default="sample_prompts.txt", help="File with prompts to generate test samples from (def: sample_prompts.txt)") - argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)") - argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)") - argparser.add_argument("--save_every_n_epochs", type=int, default=None, help="Save checkpoint every n epochs, def: 0 (disabled)") - argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32") - argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later") - argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)") - argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random") - argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets") - argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!") - argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY") - 
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)") - argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs") - argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)") + argparser = argparse.ArgumentParser(description="EveryDream2 Training options") + argparser.add_argument("--amp", action="store_true", default=False, help="Enables automatic mixed precision compute, recommended on") + argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)") + argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20") + argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?") + argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3, 4]) + argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)") + argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are") + argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False)") + argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED") + argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)") + argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!") + argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1)") + argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)") + argparser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation factor (def: 1), (ex, 2)") + argparser.add_argument("--hf_repo_subfolder", type=str, default=None, help="Subfolder inside the huggingface repo to download, if the model is not in the root of the repo.") + argparser.add_argument("--logdir", type=str, default="logs", help="folder to save logs to (def: logs)") + argparser.add_argument("--log_step", type=int, default=25, help="How often to log training stats, def: 25, recommend default!") + argparser.add_argument("--lowvram", action="store_true", default=False, help="automatically overrides various args to support 12GB gpu") + argparser.add_argument("--lr", type=float, default=None, help="Learning rate, if using scheduler is maximum LR at top of curve") + argparser.add_argument("--lr_decay_steps", type=int, default=0, help="Steps to reach minimum LR, default: automatically set") + argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"]) + 
argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant") + argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for") + argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)") + argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'") + argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions) + argparser.add_argument("--resume_ckpt", type=str, required=not ('resume_ckpt' in args), default="sd_v1-5_vae.ckpt", help="The checkpoint to resume from, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1 ") + argparser.add_argument("--sample_prompts", type=str, default="sample_prompts.txt", help="File with prompts to generate test samples from (def: sample_prompts.txt)") + argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)") + argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)") + argparser.add_argument("--save_every_n_epochs", type=int, default=None, help="Save checkpoint every n epochs, def: 0 (disabled)") + argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32") + argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later") + argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)") + argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random") + argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets") + argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!") + argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY") + argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)") + argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs") + argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)") - args, _ = argparser.parse_known_args() + # load CLI args to overwrite existing config args + args = argparser.parse_args(args=argv, namespace=args) print(f" Args:") - pprint.pprint(args.__dict__) + pprint.pprint(vars(args)) main(args) From e0ca75cc96372ee735cb6ed635b5ed7e4e582fc6 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Tue, 7 Feb 2023 13:46:19 +0100 Subject: [PATCH 2/2] remove redundant update_old_args() function --- train.py | 24 ------------------------ 1 file changed, 
From e0ca75cc96372ee735cb6ed635b5ed7e4e582fc6 Mon Sep 17 00:00:00 2001
From: Damian Stewart
Date: Tue, 7 Feb 2023 13:46:19 +0100
Subject: [PATCH 2/2] remove redundant update_old_args() function

---
 train.py | 24 ------------------------
 1 file changed, 24 deletions(-)

diff --git a/train.py b/train.py
index df67761..8508ea1 100644
--- a/train.py
+++ b/train.py
@@ -966,29 +966,6 @@ def main(args):
     logging.info(f"{Fore.LIGHTWHITE_EX} **** Finished training ****{Style.RESET_ALL}")
     logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")

-def update_old_args(t_args):
-    """
-    Update old args to new args to deal with json config loading and missing args for compatibility
-    """
-    if not hasattr(t_args, "shuffle_tags"):
-        print(f" Config json is missing 'shuffle_tags' flag")
-        t_args.__dict__["shuffle_tags"] = False
-    if not hasattr(t_args, "save_full_precision"):
-        print(f" Config json is missing 'save_full_precision' flag")
-        t_args.__dict__["save_full_precision"] = False
-    if not hasattr(t_args, "notebook"):
-        print(f" Config json is missing 'notebook' flag")
-        t_args.__dict__["notebook"] = False
-    if not hasattr(t_args, "disable_unet_training"):
-        print(f" Config json is missing 'disable_unet_training' flag")
-        t_args.__dict__["disable_unet_training"] = False
-    if not hasattr(t_args, "rated_dataset"):
-        print(f" Config json is missing 'rated_dataset' flag")
-        t_args.__dict__["rated_dataset"] = False
-    if not hasattr(t_args, "rated_dataset_target_dropout_percent"):
-        print(f" Config json is missing 'rated_dataset_target_dropout_percent' flag")
-        t_args.__dict__["rated_dataset_target_dropout_percent"] = 50
-

 if __name__ == "__main__":
     supported_resolutions = [256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
@@ -1001,7 +978,6 @@ if __name__ == "__main__":
         print(f"Loading training config from {args.config}.")
         with open(args.config, 'rt') as f:
             args.__dict__.update(json.load(f))
-            update_old_args(args) # update args to support older configs
         if len(argv) > 0:
             print(f"Config .json loaded but there are additional CLI arguments -- these will override values in {args.config}.")
     else:
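The removal above is safe because of the namespace merge introduced in PATCH 1/2: when parse_args() receives namespace=args, argparse writes an argument's default onto the namespace only where the attribute is absent, so options missing from an old config .json are backfilled automatically, which is exactly the compatibility shim update_old_args() performed by hand. A small sketch of that behavior, using a toy parser and an empty namespace standing in for an old config that predates both options:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--shuffle_tags", action="store_true", default=False)
    parser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50)

    # An "old" config namespace that never defined either option.
    old = argparse.Namespace()

    # With no CLI flags, parse_args() backfills both attributes from the
    # parser defaults instead of leaving them to raise AttributeError later.
    args = parser.parse_args(args=[], namespace=old)
    print(args.shuffle_tags, args.rated_dataset_target_dropout_percent)  # False 50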