Merge pull request #220 from damian0815/fix_ztsnr_samples_and_save
fixes for ZTSNR training
This commit is contained in:
commit
3954182e45
18
train.py
18
train.py
|
@ -458,15 +458,16 @@ def main(args):
|
|||
unet = pipe.unet
|
||||
del pipe
|
||||
|
||||
# leave the inference scheduler alone
|
||||
inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
|
||||
|
||||
if args.zero_frequency_noise_ratio == -1.0:
|
||||
# use zero terminal SNR, currently backdoor way to enable it by setting ZFN to -1, still in testing
|
||||
from utils.unet_utils import enforce_zero_terminal_snr
|
||||
temp_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
|
||||
trained_betas = enforce_zero_terminal_snr(temp_scheduler.betas).numpy().tolist()
|
||||
reference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
|
||||
else:
|
||||
reference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained(model_root_folder, subfolder="tokenizer", use_fast=False)
|
||||
|
@ -588,7 +589,8 @@ def main(args):
|
|||
batch_size=max(1,args.batch_size//2),
|
||||
default_sample_steps=args.sample_steps,
|
||||
use_xformers=is_xformers_available() and not args.disable_xformers,
|
||||
use_penultimate_clip_layer=(args.clip_skip >= 2)
|
||||
use_penultimate_clip_layer=(args.clip_skip >= 2),
|
||||
guidance_rescale = 0.7 if args.zero_frequency_noise_ratio == -1 else 0
|
||||
)
|
||||
|
||||
"""
|
||||
|
@ -727,7 +729,7 @@ def main(args):
|
|||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
diffusers_scheduler_config=reference_scheduler.config
|
||||
diffusers_scheduler_config=inference_scheduler.config,
|
||||
).to(device)
|
||||
sample_generator.generate_samples(inference_pipe, global_step)
|
||||
|
||||
|
@ -844,12 +846,12 @@ def main(args):
|
|||
last_epoch_saved_time = time.time()
|
||||
logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
|
||||
save_path = make_save_path(epoch, global_step)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
|
||||
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
|
||||
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
|
||||
save_path = make_save_path(epoch, global_step)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
|
||||
plugin_runner.run_on_step_end(epoch=epoch,
|
||||
global_step=global_step,
|
||||
|
@ -888,7 +890,7 @@ def main(args):
|
|||
# end of training
|
||||
epoch = args.max_epochs
|
||||
save_path = make_save_path(epoch, global_step, prepend=("" if args.no_prepend_last else "last-"))
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
|
||||
total_elapsed_time = time.time() - training_start_time
|
||||
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
|
||||
|
@ -898,7 +900,7 @@ def main(args):
|
|||
except Exception as ex:
|
||||
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
|
||||
save_path = make_save_path(epoch, global_step, prepend="errored-")
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
logging.info(f"{Fore.LIGHTYELLOW_EX}Model saved, re-raising exception and exiting. Exception was:{Style.RESET_ALL}{Fore.LIGHTRED_EX} {ex} {Style.RESET_ALL}")
|
||||
raise ex
|
||||
|
||||
|
|
|
@ -87,7 +87,8 @@ class SampleGenerator:
|
|||
default_seed: int,
|
||||
default_sample_steps: int,
|
||||
use_xformers: bool,
|
||||
use_penultimate_clip_layer: bool):
|
||||
use_penultimate_clip_layer: bool,
|
||||
guidance_rescale: float = 0):
|
||||
self.log_folder = log_folder
|
||||
self.log_writer = log_writer
|
||||
self.batch_size = batch_size
|
||||
|
@ -96,6 +97,7 @@ class SampleGenerator:
|
|||
self.show_progress_bars = False
|
||||
self.generate_pretrain_samples = False
|
||||
self.use_penultimate_clip_layer = use_penultimate_clip_layer
|
||||
self.guidance_rescale = guidance_rescale
|
||||
|
||||
self.default_resolution = default_resolution
|
||||
self.default_seed = default_seed
|
||||
|
@ -228,6 +230,7 @@ class SampleGenerator:
|
|||
generator=generators,
|
||||
width=size[0],
|
||||
height=size[1],
|
||||
guidance_rescale=self.guidance_rescale
|
||||
).images
|
||||
|
||||
for image in images:
|
||||
|
|
Loading…
Reference in New Issue