diff --git a/train.py b/train.py index 5a0e20e..6d3c157 100644 --- a/train.py +++ b/train.py @@ -691,7 +691,10 @@ def main(args): logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}") time.sleep(2) # give opportunity to ctrl-C again to cancel save __save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision) - exit(_SIGTERM_EXIT_CODE) + exit(_SIGTERM_EXIT_CODE) + else: + # non-main threads (i.e. dataloader workers) should exit cleanly + exit(0) signal.signal(signal.SIGINT, sigterm_handler)