diff --git a/train.py b/train.py
index 04194f1..a7495b7 100644
--- a/train.py
+++ b/train.py
@@ -45,7 +45,6 @@
 from accelerate.utils import set_seed
 import wandb
 from torch.utils.tensorboard import SummaryWriter
-from tqdm.auto import tqdm
 
 import keyboard
 
@@ -278,7 +277,12 @@ def main(args):
     """
     log_time = setup_local_logger(args)
     args = setup_args(args)
-
+
+    if args.notebook:
+        from tqdm.notebook import tqdm
+    else:
+        from tqdm.auto import tqdm
+
     seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
     set_seed(seed)
     gpu = GPU()
@@ -292,7 +296,7 @@ def main(args):
         os.makedirs(log_folder)
 
     @torch.no_grad()
-    def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir):
+    def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, save_full_precision=False):
         """
         Save the model to disk
         """
@@ -317,9 +321,11 @@ def main(args):
             sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
         else:
            sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
+
+        half = not save_full_precision
 
         logging.info(f" * Saving SD model to {sd_ckpt_full}")
-        converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=True)
+        converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)
 
         # optimizer_path = os.path.join(save_path, "optimizer.pt")
         # if self.save_optimizer_flag:
@@ -575,7 +581,7 @@ def main(args):
            logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
            logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
            time.sleep(2) # give opportunity to ctrl-C again to cancel save
-           __save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir)
+           __save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
            exit(_SIGTERM_EXIT_CODE)
 
     signal.signal(signal.SIGINT, sigterm_handler)
@@ -771,7 +777,7 @@ def main(args):
                    append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
                    torch.cuda.empty_cache()
 
-                if keyboard.is_pressed("ctrl+alt+page up") or ((global_step + 1) % args.sample_steps == 0):
+                if (not args.notebook and keyboard.is_pressed("ctrl+alt+page up")) or ((global_step + 1) % args.sample_steps == 0):
                    pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=sample_scheduler, vae=vae)
                    pipe = pipe.to(device)
 
@@ -793,12 +799,12 @@ def main(args):
                    last_epoch_saved_time = time.time()
                    logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
                    save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
-                   __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir)
+                   __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
 
                if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 1 and epoch < args.max_epochs - 1:
                    logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
                    save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
-                   __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir)
+                   __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
 
                del batch
                global_step += 1
@@ -819,7 +825,7 @@ def main(args):
 
        # end of training
        save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
-       __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir)
+       __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
 
        total_elapsed_time = time.time() - training_start_time
        logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
@@ -829,7 +835,7 @@ def main(args):
    except Exception as ex:
        logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
        save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
-       __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir)
+       __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
        raise ex
 
    logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
@@ -844,6 +850,12 @@ def update_old_args(t_args):
    if not hasattr(t_args, "shuffle_tags"):
        print(f" Config json is missing 'shuffle_tags'")
        t_args.__dict__["shuffle_tags"] = False
+   if not hasattr(t_args, "save_full_precision"):
+       print(f" Config json is missing 'save_full_precision'")
+       t_args.__dict__["save_full_precision"] = False
+   if not hasattr(t_args, "notebook"):
+       print(f" Config json is missing 'notebook'")
+       t_args.__dict__["notebook"] = False
 
 if __name__ == "__main__":
    supported_resolutions = [256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
@@ -898,6 +910,8 @@
    argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
    argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
    argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
+   argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32")
+   argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and use tqdm.notebook for Jupyter notebooks (def: False)")
 
    args = argparser.parse_args()
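
For reviewers who want to try the two new behaviors outside the full trainer, here is a minimal, standalone sketch of what the patch adds: a `--notebook` flag that switches the tqdm flavor (and, in train.py, also disables `keyboard` polling), and a `--save_full_precision` flag whose inverse becomes the `half` argument to the checkpoint converter. `get_progress_bar` is a hypothetical helper used only for illustration; it is not code from this patch.

```python
# Standalone sketch (not part of the patch): conditional tqdm import plus the
# fp16/fp32 toggle that __save_model derives from --save_full_precision.
import argparse


def get_progress_bar(notebook: bool):
    """Pick the tqdm flavor: widget-based bars in Jupyter, console bars otherwise."""
    if notebook:
        from tqdm.notebook import tqdm  # widget bar; needs ipywidgets in a notebook
    else:
        from tqdm.auto import tqdm
    return tqdm


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--notebook", action="store_true", default=False)
    parser.add_argument("--save_full_precision", action="store_true", default=False)
    args = parser.parse_args()

    tqdm = get_progress_bar(args.notebook)
    for _ in tqdm(range(3), desc="demo"):
        pass  # progress bar renders in the chosen environment

    # In the patch, __save_model passes half = not save_full_precision to the
    # converter, so checkpoints stay fp16 unless --save_full_precision is given.
    half = not args.save_full_precision
    print(f"checkpoint conversion would use half={half}")
```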