Merge pull request #137 from tjennings/main

support for saving the optimizer state
This commit is contained in:
Victor Hall 2023-04-14 14:21:35 -04:00 committed by GitHub
commit f9958320f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 45 additions and 35 deletions

1
.gitignore vendored
View File

@ -13,3 +13,4 @@
/.vscode/**
.ssh_config
*inference*.yaml
.idea

View File

@ -394,7 +394,7 @@ def main(args):
os.makedirs(log_folder)
@torch.no_grad()
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, yaml_name, save_full_precision=False):
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, optimizer, save_ckpt_dir, yaml_name, save_full_precision=False, save_optimizer_flag=False):
"""
Save the model to disk
"""
@ -432,10 +432,11 @@ def main(args):
logging.info(f" * Saving yaml to {yaml_save_path}")
shutil.copyfile(yaml_name, yaml_save_path)
# optimizer_path = os.path.join(save_path, "optimizer.pt")
# if self.save_optimizer_flag:
# logging.info(f" Saving optimizer state to {save_path}")
# self.save_optimizer(self.ctx.optimizer, optimizer_path)
if save_optimizer_flag:
optimizer_path = os.path.join(save_path, "optimizer.pt")
logging.info(f" Saving optimizer state to {save_path}")
save_optimizer(optimizer, optimizer_path)
try:
@ -446,6 +447,10 @@ def main(args):
text_encoder = CLIPTextModel.from_pretrained(model_root_folder, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_root_folder, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(model_root_folder, subfolder="unet")
optimizer_state_path = os.path.join(args.resume_ckpt, "optimizer.pt")
if not os.path.exists(optimizer_state_path):
optimizer_state_path = None
else:
# try to download from HF using resume_ckpt as a repo id
downloaded = try_download_model_from_hf(repo_id=args.resume_ckpt)
@ -588,6 +593,10 @@ def main(args):
amsgrad=False,
)
if optimizer_state_path is not None:
logging.info(f"Loading optimizer state from {optimizer_state_path}")
load_optimizer(optimizer, optimizer_state_path)
log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr, curr_text_encoder_lr)
image_train_items = resolve_image_train_items(args, log_folder)
@ -673,7 +682,7 @@ def main(args):
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
time.sleep(2) # give opportunity to ctrl-C again to cancel save
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, args.save_full_precision, args.save_optimizer)
exit(_SIGTERM_EXIT_CODE)
else:
# non-main threads (i.e. dataloader workers) should exit cleanly
@ -887,12 +896,12 @@ def main(args):
last_epoch_saved_time = time.time()
logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
del batch
global_step += 1
@ -922,7 +931,7 @@ def main(args):
# end of training
save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
total_elapsed_time = time.time() - training_start_time
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
@ -932,7 +941,7 @@ def main(args):
except Exception as ex:
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
raise ex
logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")