add option to save every n epochs or every n minutes, with a warning if both are set

Victor Hall 2022-12-18 17:24:54 -05:00
parent 2015778279
commit c20f262b92
1 changed file with 19 additions and 8 deletions


@@ -51,6 +51,7 @@ from utils.gpu import GPU
 
 _GRAD_ACCUM_STEPS = 1 # future use...
 _SIGTERM_EXIT_CODE = 130
+_VERY_LARGE_NUMBER = 1e9
 
 def convert_to_hf(ckpt_path):
     hf_cache = os.path.join("ckpt_cache", os.path.basename(ckpt_path))
@@ -59,8 +60,12 @@ def convert_to_hf(ckpt_path):
     if not os.path.exists(hf_cache):
         os.makedirs(hf_cache)
         logging.info(f"Converting {ckpt_path} to Diffusers format")
-        import utils.convert_original_stable_diffusion_to_diffusers as convert
-        convert.convert(ckpt_path, f"ckpt_cache/{ckpt_path}")
+        try:
+            import utils.convert_original_stable_diffusion_to_diffusers as convert
+            convert.convert(ckpt_path, f"ckpt_cache/{ckpt_path}")
+        except:
+            logging.info("Please manually convert the checkpoint to Diffusers format, see readme.")
+            exit()
         return hf_cache
     elif os.path.isdir(hf_cache):
         return hf_cache
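
Review note: the bare except: above also swallows KeyboardInterrupt and SystemExit. A minimal sketch of a narrower variant, not what this commit ships, that records the failure before exiting (ckpt_path is assumed to come from the enclosing function):

import logging

try:
    import utils.convert_original_stable_diffusion_to_diffusers as convert
    convert.convert(ckpt_path, f"ckpt_cache/{ckpt_path}")
except Exception:  # narrower than a bare except: Ctrl-C still propagates
    logging.exception("Automatic conversion to Diffusers format failed")
    logging.info("Please manually convert the checkpoint to Diffusers format, see readme.")
    raise SystemExit(1)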
@@ -147,8 +152,15 @@ def main(args):
     set_seed(seed)
     gpu = GPU()
 
-    if args.ckpt_every_n_minutes < 1:
-        args.ckpt_every_n_minutes = 99999
+    if args.ckpt_every_n_minutes is None or args.ckpt_every_n_minutes < 1:
+        args.ckpt_every_n_minutes = _VERY_LARGE_NUMBER
+
+    if args.save_every_n_epochs is None or args.save_every_n_epochs < 1:
+        args.save_every_n_epochs = _VERY_LARGE_NUMBER
+
+    if args.save_every_n_epochs < _VERY_LARGE_NUMBER and args.ckpt_every_n_minutes < _VERY_LARGE_NUMBER:
+        logging.warning(f"{Fore.YELLOW}Both save_every_n_epochs and ckpt_every_n_minutes are set, this will potentially spam a lot of checkpoints{Style.RESET_ALL}")
+        logging.warning(f"{Fore.YELLOW}save_every_n_epochs: {args.save_every_n_epochs}, ckpt_every_n_minutes: {args.ckpt_every_n_minutes}{Style.RESET_ALL}")
 
     if args.cond_dropout > 0.26:
         logging.warning(f"{Fore.YELLOW}cond_dropout is set fairly high: {args.cond_dropout}, make sure this was intended{Style.RESET_ALL}")
@@ -425,10 +437,9 @@ def main(args):
         interrupted_checkpoint_path = os.path.join(f"logs/{log_folder}/interrupted-gs{global_step}.ckpt")
         print()
         logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
-        logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, exiting{Style.RESET_ALL}")
+        logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
         logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
-        save_path = os.path.join(f"logs/interrupted.ckpt")
-        #__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, scheduler, vae)
+        __save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, scheduler, vae)
         exit(_SIGTERM_EXIT_CODE)
 
     signal.signal(signal.SIGINT, sigterm_handler)
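
With the save call uncommented, the handler now writes the checkpoint its banner announces instead of exiting without saving. A self-contained sketch of the same SIGINT pattern, with a stub standing in for __save_model (which needs the live model objects):

import signal
import sys

_SIGTERM_EXIT_CODE = 130

def save_stub(path):
    # Stand-in for __save_model(path, unet, text_encoder, tokenizer, scheduler, vae)
    print(f"saving checkpoint to {path}")

def sigint_handler(signum, frame):
    path = "logs/interrupted.ckpt"  # illustrative path only
    print(f"CTRL-C received, attempting to save model to {path}")
    save_stub(path)
    sys.exit(_SIGTERM_EXIT_CODE)

signal.signal(signal.SIGINT, sigint_handler)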
@@ -667,7 +678,7 @@ if __name__ == "__main__":
     argparser.add_argument("--log_step", type=int, default=25, help="How often to log training stats, def: 25, recommend default")
     argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
     argparser.add_argument("--ckpt_every_n_minutes", type=int, default=20, help="Save checkpoint every n minutes, def: 20")
-    argparser.add_argument("--save_every_n_epochs", type=int, default=9999, help="Save checkpoint every n epochs, def: 9999")
+    argparser.add_argument("--save_every_n_epochs", type=int, default=0, help="Save checkpoint every n epochs, def: 0 (disabled)")
     argparser.add_argument("--lr", type=float, default=None, help="Learning rate, if using scheduler is maximum LR at top of curve")
     argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer")
     argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")