fix ckpt save if only 1 step per epoch, revert epsilon for amp

This commit is contained in:
Victor Hall 2023-02-18 11:23:02 -05:00
parent 6e7e7f9e1f
commit 2f48460691
1 changed files with 5 additions and 4 deletions

View File

@ -599,7 +599,7 @@ def main(args):
betas = (0.9, 0.999)
epsilon = 1e-8
if args.amp:
epsilon = 2e-8
epsilon = 1e-8
weight_decay = 0.01
if args.useadam8bit:
@ -775,6 +775,7 @@ def main(args):
latents = latents[0].sample() * 0.18215
if zero_frequency_noise_ratio > 0.0:
# see https://www.crosslabs.org//blog/diffusion-with-offset-noise
zero_frequency_noise = zero_frequency_noise_ratio * torch.randn(latents.shape[0], latents.shape[1], 1, 1, device=latents.device)
noise = torch.randn_like(latents) + zero_frequency_noise
else:
@ -924,7 +925,7 @@ def main(args):
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 1 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)