Merge pull request #220 from damian0815/fix_ztsnr_samples_and_save

Fixes for zero terminal SNR (ZTSNR) training: keep an untouched DDIM scheduler for sample generation and checkpoint saving, and apply guidance rescale (0.7) to samples when ZTSNR is enabled.
Victor Hall, 2023-09-03 17:57:27 -04:00 (committed via GitHub)
commit 3954182e45
2 changed files with 15 additions and 10 deletions
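Background: zero terminal SNR (ZTSNR) training rescales the noise schedule so that the final timestep carries no signal at all, following Lin et al., "Common Diffusion Noise Schedules and Sample Steps are Flawed" (2023). The helper utils/unet_utils.enforce_zero_terminal_snr used below is not part of this diff; the paper's Algorithm 1, which it presumably implements, shifts and rescales sqrt(alpha_bar) so the terminal value is exactly zero. A sketch:

    import torch

    def enforce_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
        # convert betas to sqrt of the cumulative product of alphas
        alphas = 1.0 - betas
        alphas_bar_sqrt = alphas.cumprod(0).sqrt()
        # remember the old first and last values
        first = alphas_bar_sqrt[0].clone()
        last = alphas_bar_sqrt[-1].clone()
        # shift so the terminal value is zero, then rescale to keep the first value
        alphas_bar_sqrt -= last
        alphas_bar_sqrt *= first / (first - last)
        # convert back to betas
        alphas_bar = alphas_bar_sqrt ** 2
        alphas = torch.cat([alphas_bar[0:1], alphas_bar[1:] / alphas_bar[:-1]])
        return 1.0 - alphas

With beta at the last step forced to 1, the model sees pure noise at the final timestep during training, which is why sampling and checkpoint saving need the adjustments in this diff.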

train.py

@@ -458,15 +458,16 @@ def main(args):
     unet = pipe.unet
     del pipe

+    # leave the inference scheduler alone
+    inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
+
     if args.zero_frequency_noise_ratio == -1.0:
         # use zero terminal SNR, currently backdoor way to enable it by setting ZFN to -1, still in testing
         from utils.unet_utils import enforce_zero_terminal_snr
         temp_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
         trained_betas = enforce_zero_terminal_snr(temp_scheduler.betas).numpy().tolist()
-        reference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
         noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
     else:
-        reference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
         noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")

     tokenizer = CLIPTokenizer.from_pretrained(model_root_folder, subfolder="tokenizer", use_fast=False)
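Net effect of this hunk: the ZTSNR betas now feed only the training noise_scheduler (DDPM), while inference_scheduler preserves the checkpoint's original DDIM config for sampling and saving. A quick sanity check, not part of the commit, that the rescaled schedule really reaches zero terminal SNR (the model path is a placeholder):

    import torch
    from diffusers import DDIMScheduler
    from utils.unet_utils import enforce_zero_terminal_snr

    model_root_folder = "/path/to/stable-diffusion-checkpoint"  # placeholder
    temp_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
    trained_betas = enforce_zero_terminal_snr(temp_scheduler.betas)
    alphas_cumprod = torch.cumprod(1.0 - trained_betas, dim=0)
    assert torch.isclose(alphas_cumprod[-1], torch.tensor(0.0), atol=1e-6)  # SNR at t=T is zero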
@@ -588,7 +589,8 @@ def main(args):
         batch_size=max(1,args.batch_size//2),
         default_sample_steps=args.sample_steps,
         use_xformers=is_xformers_available() and not args.disable_xformers,
-        use_penultimate_clip_layer=(args.clip_skip >= 2)
+        use_penultimate_clip_layer=(args.clip_skip >= 2),
+        guidance_rescale=0.7 if args.zero_frequency_noise_ratio == -1 else 0
     )
 """
@@ -727,7 +729,7 @@ def main(args):
                 text_encoder=text_encoder,
                 tokenizer=tokenizer,
                 vae=vae,
-                diffusers_scheduler_config=reference_scheduler.config
+                diffusers_scheduler_config=inference_scheduler.config,
             ).to(device)

             sample_generator.generate_samples(inference_pipe, global_step)
@@ -844,12 +846,12 @@ def main(args):
                     last_epoch_saved_time = time.time()
                     logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
                     save_path = make_save_path(epoch, global_step)
-                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
+                    __save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)

                 if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
                     logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
                     save_path = make_save_path(epoch, global_step)
-                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
+                    __save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)

                 plugin_runner.run_on_step_end(epoch=epoch,
                                               global_step=global_step,
@@ -888,7 +890,7 @@ def main(args):
         # end of training
         epoch = args.max_epochs
         save_path = make_save_path(epoch, global_step, prepend=("" if args.no_prepend_last else "last-"))
-        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
+        __save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)

         total_elapsed_time = time.time() - training_start_time
         logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
@@ -898,7 +900,7 @@ def main(args):
     except Exception as ex:
         logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
         save_path = make_save_path(epoch, global_step, prepend="errored-")
-        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
+        __save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
         logging.info(f"{Fore.LIGHTYELLOW_EX}Model saved, re-raising exception and exiting. Exception was:{Style.RESET_ALL}{Fore.LIGHTRED_EX} {ex} {Style.RESET_ALL}")
         raise ex

utils/sample_generator.py

@@ -87,7 +87,8 @@ class SampleGenerator:
                  default_seed: int,
                  default_sample_steps: int,
                  use_xformers: bool,
-                 use_penultimate_clip_layer: bool):
+                 use_penultimate_clip_layer: bool,
+                 guidance_rescale: float = 0):
         self.log_folder = log_folder
         self.log_writer = log_writer
         self.batch_size = batch_size
@@ -96,6 +97,7 @@ class SampleGenerator:
         self.show_progress_bars = False
         self.generate_pretrain_samples = False
         self.use_penultimate_clip_layer = use_penultimate_clip_layer
+        self.guidance_rescale = guidance_rescale

         self.default_resolution = default_resolution
         self.default_seed = default_seed
@@ -228,6 +230,7 @@ class SampleGenerator:
                         generator=generators,
                         width=size[0],
                         height=size[1],
+                        guidance_rescale=self.guidance_rescale
                     ).images

                 for image in images:
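For reference, recent diffusers releases accept the same keyword directly on the pipeline call, so a standalone generation equivalent to the snippet above would look roughly like this (the model id is illustrative):

    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    image = pipe("a photo of an astronaut riding a horse",
                 guidance_scale=7.5,
                 guidance_rescale=0.7).images[0]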