From 9b5b96a50b74c28e2c0804c8743167d2ec13f56f Mon Sep 17 00:00:00 2001
From: Damian Stewart
Date: Tue, 15 Aug 2023 20:47:34 +0200
Subject: [PATCH] fixes for ZTSNR training

---
 train.py                  | 20 +++++++++++---------
 utils/sample_generator.py |  5 ++++-
 2 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/train.py b/train.py
index 57c36a6..cc55821 100644
--- a/train.py
+++ b/train.py
@@ -457,16 +457,17 @@ def main(args):
         vae = pipe.vae
         unet = pipe.unet
         del pipe
-    
+
+    # leave the inference scheduler alone
+    inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
+
     if args.zero_frequency_noise_ratio == -1.0:
         # use zero terminal SNR, currently backdoor way to enable it by setting ZFN to -1, still in testing
         from utils.unet_utils import enforce_zero_terminal_snr
         temp_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
         trained_betas = enforce_zero_terminal_snr(temp_scheduler.betas).numpy().tolist()
-        reference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
         noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
     else:
-        reference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
         noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
 
     tokenizer = CLIPTokenizer.from_pretrained(model_root_folder, subfolder="tokenizer", use_fast=False)
@@ -588,7 +589,8 @@ def main(args):
             batch_size=max(1,args.batch_size//2),
             default_sample_steps=args.sample_steps,
             use_xformers=is_xformers_available() and not args.disable_xformers,
-            use_penultimate_clip_layer=(args.clip_skip >= 2)
+            use_penultimate_clip_layer=(args.clip_skip >= 2),
+            guidance_rescale = 0.7 if args.zero_frequency_noise_ratio == -1 else 0
         )
 
     """
@@ -727,7 +729,7 @@ def main(args):
                     text_encoder=text_encoder,
                     tokenizer=tokenizer,
                     vae=vae,
-                    diffusers_scheduler_config=reference_scheduler.config
+                    diffusers_scheduler_config=inference_scheduler.config,
                 ).to(device)
                 sample_generator.generate_samples(inference_pipe, global_step)
 
@@ -844,12 +846,12 @@ def main(args):
                 last_epoch_saved_time = time.time()
                 logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
                 save_path = make_save_path(epoch, global_step)
-                __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
+                __save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
 
             if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
                 logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
                 save_path = make_save_path(epoch, global_step)
-                __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
+                __save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
 
             plugin_runner.run_on_step_end(epoch=epoch,
                                           global_step=global_step,
@@ -888,7 +890,7 @@ def main(args):
         # end of training
         epoch = args.max_epochs
         save_path = make_save_path(epoch, global_step, prepend=("" if args.no_prepend_last else "last-"))
-        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
+        __save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
 
         total_elapsed_time = time.time() - training_start_time
         logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
@@ -898,7 +900,7 @@ def main(args):
     except Exception as ex:
         logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
         save_path = make_save_path(epoch, global_step, prepend="errored-")
-        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
+        __save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
         logging.info(f"{Fore.LIGHTYELLOW_EX}Model saved, re-raising exception and exiting. Exception was:{Style.RESET_ALL}{Fore.LIGHTRED_EX} {ex} {Style.RESET_ALL}")
         raise ex
 
diff --git a/utils/sample_generator.py b/utils/sample_generator.py
index dd13bf8..a39413f 100644
--- a/utils/sample_generator.py
+++ b/utils/sample_generator.py
@@ -87,7 +87,8 @@ class SampleGenerator:
                  default_seed: int,
                  default_sample_steps: int,
                  use_xformers: bool,
-                 use_penultimate_clip_layer: bool):
+                 use_penultimate_clip_layer: bool,
+                 guidance_rescale: float = 0):
         self.log_folder = log_folder
         self.log_writer = log_writer
         self.batch_size = batch_size
@@ -96,6 +97,7 @@ class SampleGenerator:
         self.show_progress_bars = False
         self.generate_pretrain_samples = False
         self.use_penultimate_clip_layer = use_penultimate_clip_layer
+        self.guidance_rescale = guidance_rescale
 
         self.default_resolution = default_resolution
         self.default_seed = default_seed
@@ -228,6 +230,7 @@ class SampleGenerator:
                     generator=generators,
                     width=size[0],
                     height=size[1],
+                    guidance_rescale=self.guidance_rescale
                 ).images
 
             for image in images:
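Notes on the ZTSNR pieces this patch touches (added context, not part of the diff above). The training DDPMScheduler gets betas rescaled by utils/unet_utils.enforce_zero_terminal_snr, while the newly separated inference_scheduler keeps the stock DDIM config for checkpoint saving and sample generation. The beta rescaling presumably follows Algorithm 1 of Lin et al., "Common Diffusion Noise Schedules and Sample Steps Are Flawed"; the sketch below illustrates that algorithm under that assumption and is not the repo's actual implementation:

    import torch

    def enforce_zero_terminal_snr_sketch(betas: torch.Tensor) -> torch.Tensor:
        # Work in sqrt(alpha_bar) space.
        alphas = 1.0 - betas
        alphas_bar_sqrt = alphas.cumprod(dim=0).sqrt()

        # Shift so the final sqrt(alpha_bar) becomes exactly 0 (zero terminal SNR),
        # then rescale so the first value is left unchanged.
        first = alphas_bar_sqrt[0].clone()
        last = alphas_bar_sqrt[-1].clone()
        alphas_bar_sqrt -= last
        alphas_bar_sqrt *= first / (first - last)

        # Convert sqrt(alpha_bar) back to per-step betas.
        alphas_bar = alphas_bar_sqrt ** 2
        alphas = torch.cat([alphas_bar[0:1], alphas_bar[1:] / alphas_bar[:-1]])
        return 1.0 - alphas

The guidance_rescale=0.7 forwarded through SampleGenerator to the diffusers pipeline call appears to be the rescaled classifier-free guidance from the same paper, which counteracts the over-exposed samples a zero-terminal-SNR model otherwise tends to produce; it is only applied when args.zero_frequency_noise_ratio == -1, i.e. when ZTSNR training is active.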