This commit is contained in:
Victor Hall 2022-12-18 20:52:58 -05:00
parent 8908fa75a9
commit e2ee3da452
1 changed files with 3 additions and 1 deletions

View File

@ -152,6 +152,7 @@ def main(args):
gpu = GPU()
if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
logging.info(" no checkpointing specified, defaulting to 20 minutes")
args.ckpt_every_n_minutes = 20
if args.ckpt_every_n_minutes is None or args.ckpt_every_n_minutes < 1:
@ -436,7 +437,7 @@ def main(args):
interrupted=True
global global_step
#TODO: save model on ctrl-c
interrupted_checkpoint_path = os.path.join(f"logs/{log_folder}/interrupted-gs{global_step}.ckpt") interrupted_checkpoint_path = os.path.join(f"{log_folder}/interrupted-gs{global_step}")
print()
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
@ -452,6 +453,7 @@ def main(args):
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
logging.info(f" Pretraining GPU Memory: {gpu_used_mem} / {gpu_total_mem} MB")
logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes")
logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
scaler = torch.cuda.amp.GradScaler(
enabled=False,