Merge pull request #137 from tjennings/main
support for saving the optimizer state
commit f9958320f6
.gitignore
@@ -13,3 +13,4 @@
 /.vscode/**
 .ssh_config
 *inference*.yaml
+.idea
train.py (29 changed lines)
@@ -394,7 +394,7 @@ def main(args):
     os.makedirs(log_folder)
 
 @torch.no_grad()
-def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, yaml_name, save_full_precision=False):
+def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, optimizer, save_ckpt_dir, yaml_name, save_full_precision=False, save_optimizer_flag=False):
     """
     Save the model to disk
     """
@@ -432,10 +432,11 @@ def main(args):
     logging.info(f" * Saving yaml to {yaml_save_path}")
     shutil.copyfile(yaml_name, yaml_save_path)
 
-    # optimizer_path = os.path.join(save_path, "optimizer.pt")
-    # if self.save_optimizer_flag:
-    #     logging.info(f" Saving optimizer state to {save_path}")
-    #     self.save_optimizer(self.ctx.optimizer, optimizer_path)
+    if save_optimizer_flag:
+        optimizer_path = os.path.join(save_path, "optimizer.pt")
+        logging.info(f" Saving optimizer state to {save_path}")
+        save_optimizer(optimizer, optimizer_path)
+
     try:
 
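Note: the save_optimizer helper called above is not defined in any of the hunks shown here. A minimal sketch of its likely shape, assuming it simply serializes the optimizer's state_dict with torch.save (an assumption, not the PR's confirmed implementation):

import torch

def save_optimizer(optimizer: torch.optim.Optimizer, path: str):
    # Persist the optimizer's full state (per-parameter moment buffers,
    # step counts, and hyperparameters) so training can resume exactly.
    torch.save(optimizer.state_dict(), path)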
@@ -446,6 +447,10 @@ def main(args):
         text_encoder = CLIPTextModel.from_pretrained(model_root_folder, subfolder="text_encoder")
         vae = AutoencoderKL.from_pretrained(model_root_folder, subfolder="vae")
         unet = UNet2DConditionModel.from_pretrained(model_root_folder, subfolder="unet")
+
+        optimizer_state_path = os.path.join(args.resume_ckpt, "optimizer.pt")
+        if not os.path.exists(optimizer_state_path):
+            optimizer_state_path = None
     else:
         # try to download from HF using resume_ckpt as a repo id
         downloaded = try_download_model_from_hf(repo_id=args.resume_ckpt)
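In other words, when --resume_ckpt points at a local checkpoint directory, the code now also looks for an optimizer.pt sitting next to the model components, and silently falls back to a fresh optimizer when none is found. The assumed on-disk layout (only optimizer.pt is new; tokenizer and scheduler folders follow the standard diffusers pipeline layout, which these hunks do not show directly):

ckpts/myproject-ep04-gs01000/
    model_index.json
    scheduler/
    text_encoder/
    tokenizer/
    unet/
    vae/
    optimizer.pt    # new: written when save_optimizer_flag is set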
@@ -588,6 +593,10 @@ def main(args):
         amsgrad=False,
     )
 
+    if optimizer_state_path is not None:
+        logging.info(f"Loading optimizer state from {optimizer_state_path}")
+        load_optimizer(optimizer, optimizer_state_path)
+
     log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr, curr_text_encoder_lr)
 
     image_train_items = resolve_image_train_items(args, log_folder)
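As with save_optimizer, the load_optimizer helper is not shown in these hunks. A plausible sketch, assuming it restores the saved state_dict into the freshly constructed optimizer:

import torch

def load_optimizer(optimizer: torch.optim.Optimizer, path: str):
    # Restore moment buffers and step counts saved by save_optimizer,
    # so resumed training continues with the same effective dynamics.
    optimizer.load_state_dict(torch.load(path))

Loading happens after the optimizer is built with the configured betas, epsilon, and weight decay, so the restored per-parameter state is applied on top of those settings.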
@@ -673,7 +682,7 @@ def main(args):
             logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
             logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
             time.sleep(2) # give opportunity to ctrl-C again to cancel save
-            __save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
+            __save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, args.save_full_precision, args.save_optimizer)
             exit(_SIGTERM_EXIT_CODE)
         else:
             # non-main threads (i.e. dataloader workers) should exit cleanly
@@ -887,12 +896,12 @@ def main(args):
                     last_epoch_saved_time = time.time()
                     logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
                     save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
-                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
+                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
 
                 if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
                     logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
                     save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
-                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
+                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
 
                 del batch
                 global_step += 1
@@ -922,7 +931,7 @@ def main(args):
         # end of training
 
         save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
-        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
+        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
 
         total_elapsed_time = time.time() - training_start_time
         logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
@@ -932,7 +941,7 @@ def main(args):
     except Exception as ex:
         logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
         save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
-        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
+        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
         raise ex
 
     logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
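Taken together, every checkpoint path (periodic, per-epoch, end-of-training, CTRL-C, and error saves) can now write an optimizer.pt alongside the model, and a later run that points --resume_ckpt at that directory picks the state back up automatically. A hypothetical invocation, assuming --save_optimizer is the boolean flag wired to args.save_optimizer (the argparse change itself is not shown in these hunks; the paths are illustrative):

python train.py --save_optimizer --resume_ckpt logs/myproject/ckpts/myproject-ep04-gs01000

One practical consequence: for Adam-family optimizers the saved state holds two moment tensors per trained parameter, so optimizer.pt is roughly twice the size of the trained weights, which is worth budgeting for when keeping many checkpoints.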