stuff
This commit is contained in:
parent
8908fa75a9
commit
e2ee3da452
4
train.py
4
train.py
|
@ -152,6 +152,7 @@ def main(args):
|
||||||
gpu = GPU()
|
gpu = GPU()
|
||||||
|
|
||||||
if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
|
if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
|
||||||
|
logging.info(" no checkpointing specified, defaulting to 20 minutes")
|
||||||
args.ckpt_every_n_minutes = 20
|
args.ckpt_every_n_minutes = 20
|
||||||
|
|
||||||
if args.ckpt_every_n_minutes is None or args.ckpt_every_n_minutes < 1:
|
if args.ckpt_every_n_minutes is None or args.ckpt_every_n_minutes < 1:
|
||||||
|
@ -436,7 +437,7 @@ def main(args):
|
||||||
interrupted=True
|
interrupted=True
|
||||||
global global_step
|
global global_step
|
||||||
#TODO: save model on ctrl-c
|
#TODO: save model on ctrl-c
|
||||||
interrupted_checkpoint_path = os.path.join(f"logs/{log_folder}/interrupted-gs{global_step}.ckpt")
|
interrupted_checkpoint_path = os.path.join(f"{log_folder}/interrupted-gs{global_step}")
|
||||||
print()
|
print()
|
||||||
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
|
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
|
||||||
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
|
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
|
||||||
|
@ -452,6 +453,7 @@ def main(args):
|
||||||
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
|
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
|
||||||
logging.info(f" Pretraining GPU Memory: {gpu_used_mem} / {gpu_total_mem} MB")
|
logging.info(f" Pretraining GPU Memory: {gpu_used_mem} / {gpu_total_mem} MB")
|
||||||
logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes")
|
logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes")
|
||||||
|
logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
|
||||||
|
|
||||||
scaler = torch.cuda.amp.GradScaler(
|
scaler = torch.cuda.amp.GradScaler(
|
||||||
enabled=False,
|
enabled=False,
|
||||||
|
|
Loading…
Reference in New Issue