Merge pull request #137 from tjennings/main
support for saving the optimizer state
This commit is contained in:
commit
f9958320f6
|
@ -13,3 +13,4 @@
|
|||
/.vscode/**
|
||||
.ssh_config
|
||||
*inference*.yaml
|
||||
.idea
|
||||
|
|
29
train.py
29
train.py
|
@ -394,7 +394,7 @@ def main(args):
|
|||
os.makedirs(log_folder)
|
||||
|
||||
@torch.no_grad()
|
||||
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, yaml_name, save_full_precision=False):
|
||||
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, optimizer, save_ckpt_dir, yaml_name, save_full_precision=False, save_optimizer_flag=False):
|
||||
"""
|
||||
Save the model to disk
|
||||
"""
|
||||
|
@ -432,10 +432,11 @@ def main(args):
|
|||
logging.info(f" * Saving yaml to {yaml_save_path}")
|
||||
shutil.copyfile(yaml_name, yaml_save_path)
|
||||
|
||||
# optimizer_path = os.path.join(save_path, "optimizer.pt")
|
||||
# if self.save_optimizer_flag:
|
||||
# logging.info(f" Saving optimizer state to {save_path}")
|
||||
# self.save_optimizer(self.ctx.optimizer, optimizer_path)
|
||||
|
||||
if save_optimizer_flag:
|
||||
optimizer_path = os.path.join(save_path, "optimizer.pt")
|
||||
logging.info(f" Saving optimizer state to {save_path}")
|
||||
save_optimizer(optimizer, optimizer_path)
|
||||
|
||||
try:
|
||||
|
||||
|
@ -446,6 +447,10 @@ def main(args):
|
|||
text_encoder = CLIPTextModel.from_pretrained(model_root_folder, subfolder="text_encoder")
|
||||
vae = AutoencoderKL.from_pretrained(model_root_folder, subfolder="vae")
|
||||
unet = UNet2DConditionModel.from_pretrained(model_root_folder, subfolder="unet")
|
||||
|
||||
optimizer_state_path = os.path.join(args.resume_ckpt, "optimizer.pt")
|
||||
if not os.path.exists(optimizer_state_path):
|
||||
optimizer_state_path = None
|
||||
else:
|
||||
# try to download from HF using resume_ckpt as a repo id
|
||||
downloaded = try_download_model_from_hf(repo_id=args.resume_ckpt)
|
||||
|
@ -588,6 +593,10 @@ def main(args):
|
|||
amsgrad=False,
|
||||
)
|
||||
|
||||
if optimizer_state_path is not None:
|
||||
logging.info(f"Loading optimizer state from {optimizer_state_path}")
|
||||
load_optimizer(optimizer, optimizer_state_path)
|
||||
|
||||
log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr, curr_text_encoder_lr)
|
||||
|
||||
image_train_items = resolve_image_train_items(args, log_folder)
|
||||
|
@ -673,7 +682,7 @@ def main(args):
|
|||
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
|
||||
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
|
||||
time.sleep(2) # give opportunity to ctrl-C again to cancel save
|
||||
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
|
||||
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, args.save_full_precision, args.save_optimizer)
|
||||
exit(_SIGTERM_EXIT_CODE)
|
||||
else:
|
||||
# non-main threads (i.e. dataloader workers) should exit cleanly
|
||||
|
@ -887,12 +896,12 @@ def main(args):
|
|||
last_epoch_saved_time = time.time()
|
||||
logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
|
||||
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
|
||||
|
||||
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
|
||||
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
|
||||
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
|
||||
|
||||
del batch
|
||||
global_step += 1
|
||||
|
@ -922,7 +931,7 @@ def main(args):
|
|||
# end of training
|
||||
|
||||
save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
|
||||
|
||||
total_elapsed_time = time.time() - training_start_time
|
||||
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
|
||||
|
@ -932,7 +941,7 @@ def main(args):
|
|||
except Exception as ex:
|
||||
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
|
||||
save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
|
||||
raise ex
|
||||
|
||||
logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
|
||||
|
|
Loading…
Reference in New Issue