Add option to save a checkpoint every n epochs or every n minutes, with a warning if both are enabled
This commit is contained in:
parent 2015778279
commit c20f262b92

train.py (27 changed lines)
@@ -51,6 +51,7 @@ from utils.gpu import GPU
 
 _GRAD_ACCUM_STEPS = 1 # future use...
 _SIGTERM_EXIT_CODE = 130
+_VERY_LARGE_NUMBER = 1e9
 
 def convert_to_hf(ckpt_path):
     hf_cache = os.path.join("ckpt_cache", os.path.basename(ckpt_path))
@@ -59,8 +60,12 @@ def convert_to_hf(ckpt_path):
     if not os.path.exists(hf_cache):
         os.makedirs(hf_cache)
         logging.info(f"Converting {ckpt_path} to Diffusers format")
-        import utils.convert_original_stable_diffusion_to_diffusers as convert
-        convert.convert(ckpt_path, f"ckpt_cache/{ckpt_path}")
+        try:
+            import utils.convert_original_stable_diffusion_to_diffusers as convert
+            convert.convert(ckpt_path, f"ckpt_cache/{ckpt_path}")
+        except:
+            logging.info("Please manually convert the checkpoint to Diffusers format, see readme.")
+            exit()
         return hf_cache
     elif os.path.isdir(hf_cache):
         return hf_cache
@@ -147,8 +152,15 @@ def main(args):
     set_seed(seed)
     gpu = GPU()
 
-    if args.ckpt_every_n_minutes < 1:
-        args.ckpt_every_n_minutes = 99999
+    if args.ckpt_every_n_minutes is None or args.ckpt_every_n_minutes < 1:
+        args.ckpt_every_n_minutes = _VERY_LARGE_NUMBER
+
+    if args.save_every_n_epochs is None or args.save_every_n_epochs < 1:
+        args.save_every_n_epochs = _VERY_LARGE_NUMBER
+
+    if args.save_every_n_epochs < _VERY_LARGE_NUMBER and args.ckpt_every_n_minutes < _VERY_LARGE_NUMBER:
+        logging.warning(f"{Fore.YELLOW}Both save_every_n_epochs and ckpt_every_n_minutes are set, this will potentially spam a lot of checkpoints{Style.RESET_ALL}")
+        logging.warning(f"{Fore.YELLOW}save_every_n_epochs: {args.save_every_n_epochs}, ckpt_every_n_minutes: {args.ckpt_every_n_minutes}{Style.RESET_ALL}")
 
     if args.cond_dropout > 0.26:
         logging.warning(f"{Fore.YELLOW}cond_dropout is set fairly high: {args.cond_dropout}, make sure this was intended{Style.RESET_ALL}")
@@ -425,10 +437,9 @@ def main(args):
         interrupted_checkpoint_path = os.path.join(f"logs/{log_folder}/interrupted-gs{global_step}.ckpt")
         print()
         logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
-        logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, exiting{Style.RESET_ALL}")
+        logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
         logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
-        save_path = os.path.join(f"logs/interrupted.ckpt")
-        #__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, scheduler, vae)
+        __save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, scheduler, vae)
         exit(_SIGTERM_EXIT_CODE)
 
     signal.signal(signal.SIGINT, sigterm_handler)
@@ -667,7 +678,7 @@ if __name__ == "__main__":
     argparser.add_argument("--log_step", type=int, default=25, help="How often to log training stats, def: 25, recommend default")
     argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
     argparser.add_argument("--ckpt_every_n_minutes", type=int, default=20, help="Save checkpoint every n minutes, def: 20")
-    argparser.add_argument("--save_every_n_epochs", type=int, default=9999, help="Save checkpoint every n epochs, def: 9999")
+    argparser.add_argument("--save_every_n_epochs", type=int, default=0, help="Save checkpoint every n epochs, def: 0 (disabled)")
     argparser.add_argument("--lr", type=float, default=None, help="Learning rate, if using scheduler is maximum LR at top of curve")
     argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer")
     argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
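
For context, here is a minimal standalone sketch of the interval handling this commit introduces. The helper name resolve_checkpoint_intervals and the trimmed two-flag argument list are illustrative only, not part of train.py; the guard logic mirrors the added hunk above.

import argparse
import logging

_VERY_LARGE_NUMBER = 1e9  # sentinel meaning "this checkpoint trigger is effectively disabled"

def resolve_checkpoint_intervals(args):
    # Unset or non-positive values disable that trigger by pushing it out of reach.
    if args.ckpt_every_n_minutes is None or args.ckpt_every_n_minutes < 1:
        args.ckpt_every_n_minutes = _VERY_LARGE_NUMBER
    if args.save_every_n_epochs is None or args.save_every_n_epochs < 1:
        args.save_every_n_epochs = _VERY_LARGE_NUMBER
    # If both triggers remain active, checkpoints are written on both the epoch
    # schedule and the wall-clock schedule, so warn about the potential volume.
    if args.save_every_n_epochs < _VERY_LARGE_NUMBER and args.ckpt_every_n_minutes < _VERY_LARGE_NUMBER:
        logging.warning("Both save_every_n_epochs and ckpt_every_n_minutes are set, "
                        "this will potentially spam a lot of checkpoints")
    return args

if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--ckpt_every_n_minutes", type=int, default=20)
    argparser.add_argument("--save_every_n_epochs", type=int, default=0)
    args = resolve_checkpoint_intervals(argparser.parse_args())
    print(f"ckpt_every_n_minutes={args.ckpt_every_n_minutes}, save_every_n_epochs={args.save_every_n_epochs}")

With the new defaults (ckpt_every_n_minutes=20, save_every_n_epochs=0), only the time-based trigger is active and no warning is emitted; passing positive values for both flags produces the warning.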