Fix checkpoint save when an epoch has only 1 step; revert epsilon for AMP
This commit is contained in:
parent
6e7e7f9e1f
commit
2f48460691
5
train.py
5
train.py
|
@@ -599,7 +599,7 @@ def main(args):
|
||||||
betas = (0.9, 0.999)
|
betas = (0.9, 0.999)
|
||||||
epsilon = 1e-8
|
epsilon = 1e-8
|
||||||
if args.amp:
|
if args.amp:
|
||||||
epsilon = 2e-8
|
epsilon = 1e-8
|
||||||
|
|
||||||
weight_decay = 0.01
|
weight_decay = 0.01
|
||||||
if args.useadam8bit:
|
if args.useadam8bit:
|
||||||
|
@@ -775,6 +775,7 @@ def main(args):
|
||||||
latents = latents[0].sample() * 0.18215
|
latents = latents[0].sample() * 0.18215
|
||||||
|
|
||||||
if zero_frequency_noise_ratio > 0.0:
|
if zero_frequency_noise_ratio > 0.0:
|
||||||
|
# see https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||||
zero_frequency_noise = zero_frequency_noise_ratio * torch.randn(latents.shape[0], latents.shape[1], 1, 1, device=latents.device)
|
zero_frequency_noise = zero_frequency_noise_ratio * torch.randn(latents.shape[0], latents.shape[1], 1, 1, device=latents.device)
|
||||||
noise = torch.randn_like(latents) + zero_frequency_noise
|
noise = torch.randn_like(latents) + zero_frequency_noise
|
||||||
else:
|
else:
|
||||||
|
@@ -924,7 +925,7 @@ def main(args):
|
||||||
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
|
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
|
||||||
|
|
||||||
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 1 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
|
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
|
||||||
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
|
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
|
||||||
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
|
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
|
||||||
|
|
Loading…
Reference in New Issue