Merge pull request #137 from tjennings/main

Add support for saving the optimizer state with checkpoints and reloading it on resume
This commit is contained in:
Victor Hall 2023-04-14 14:21:35 -04:00 committed by GitHub
commit f9958320f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 45 additions and 35 deletions

1
.gitignore vendored
View File

@ -13,3 +13,4 @@
/.vscode/** /.vscode/**
.ssh_config .ssh_config
*inference*.yaml *inference*.yaml
.idea

View File

@ -394,7 +394,7 @@ def main(args):
os.makedirs(log_folder) os.makedirs(log_folder)
@torch.no_grad() @torch.no_grad()
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, yaml_name, save_full_precision=False): def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, optimizer, save_ckpt_dir, yaml_name, save_full_precision=False, save_optimizer_flag=False):
""" """
Save the model to disk Save the model to disk
""" """
@ -432,10 +432,11 @@ def main(args):
logging.info(f" * Saving yaml to {yaml_save_path}") logging.info(f" * Saving yaml to {yaml_save_path}")
shutil.copyfile(yaml_name, yaml_save_path) shutil.copyfile(yaml_name, yaml_save_path)
# optimizer_path = os.path.join(save_path, "optimizer.pt")
# if self.save_optimizer_flag: if save_optimizer_flag:
# logging.info(f" Saving optimizer state to {save_path}") optimizer_path = os.path.join(save_path, "optimizer.pt")
# self.save_optimizer(self.ctx.optimizer, optimizer_path) logging.info(f" Saving optimizer state to {save_path}")
save_optimizer(optimizer, optimizer_path)
try: try:
@ -446,6 +447,10 @@ def main(args):
text_encoder = CLIPTextModel.from_pretrained(model_root_folder, subfolder="text_encoder") text_encoder = CLIPTextModel.from_pretrained(model_root_folder, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_root_folder, subfolder="vae") vae = AutoencoderKL.from_pretrained(model_root_folder, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(model_root_folder, subfolder="unet") unet = UNet2DConditionModel.from_pretrained(model_root_folder, subfolder="unet")
optimizer_state_path = os.path.join(args.resume_ckpt, "optimizer.pt")
if not os.path.exists(optimizer_state_path):
optimizer_state_path = None
else: else:
# try to download from HF using resume_ckpt as a repo id # try to download from HF using resume_ckpt as a repo id
downloaded = try_download_model_from_hf(repo_id=args.resume_ckpt) downloaded = try_download_model_from_hf(repo_id=args.resume_ckpt)
@ -588,6 +593,10 @@ def main(args):
amsgrad=False, amsgrad=False,
) )
if optimizer_state_path is not None:
logging.info(f"Loading optimizer state from {optimizer_state_path}")
load_optimizer(optimizer, optimizer_state_path)
log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr, curr_text_encoder_lr) log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr, curr_text_encoder_lr)
image_train_items = resolve_image_train_items(args, log_folder) image_train_items = resolve_image_train_items(args, log_folder)
@ -673,7 +682,7 @@ def main(args):
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}") logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}") logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
time.sleep(2) # give opportunity to ctrl-C again to cancel save time.sleep(2) # give opportunity to ctrl-C again to cancel save
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision) __save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, args.save_full_precision, args.save_optimizer)
exit(_SIGTERM_EXIT_CODE) exit(_SIGTERM_EXIT_CODE)
else: else:
# non-main threads (i.e. dataloader workers) should exit cleanly # non-main threads (i.e. dataloader workers) should exit cleanly
@ -887,12 +896,12 @@ def main(args):
last_epoch_saved_time = time.time() last_epoch_saved_time = time.time()
logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}") logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}") save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision) __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs: if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}") logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}") save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision) __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
del batch del batch
global_step += 1 global_step += 1
@ -922,7 +931,7 @@ def main(args):
# end of training # end of training
save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}") save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision) __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
total_elapsed_time = time.time() - training_start_time total_elapsed_time = time.time() - training_start_time
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}") logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
@ -932,7 +941,7 @@ def main(args):
except Exception as ex: except Exception as ex:
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}") logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}") save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision) __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
raise ex raise ex
logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}") logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")