Merge pull request #220 from damian0815/fix_ztsnr_samples_and_save
fixes for ZTSNR training
commit 3954182e45

train.py | 18 ++++++++++--------
@@ -458,15 +458,16 @@ def main(args):
     unet = pipe.unet
     del pipe
 
+    # leave the inference scheduler alone
+    inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
+
     if args.zero_frequency_noise_ratio == -1.0:
         # use zero terminal SNR, currently backdoor way to enable it by setting ZFN to -1, still in testing
         from utils.unet_utils import enforce_zero_terminal_snr
         temp_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
         trained_betas = enforce_zero_terminal_snr(temp_scheduler.betas).numpy().tolist()
-        reference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
         noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
     else:
-        reference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
         noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
 
     tokenizer = CLIPTokenizer.from_pretrained(model_root_folder, subfolder="tokenizer", use_fast=False)
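
For context, the enforce_zero_terminal_snr helper imported above rescales the beta schedule so the signal-to-noise ratio reaches exactly zero at the final timestep, per Algorithm 1 of "Common Diffusion Noise Schedules and Sample Steps Are Flawed" (Lin et al., 2023). A minimal sketch of that algorithm, assuming a 1-D tensor of betas; the repository's actual implementation in utils/unet_utils.py may differ in detail:

import torch

def enforce_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    # Rescale a beta schedule so SNR(T) == 0 (Lin et al. 2023, Algorithm 1).
    alphas = 1.0 - betas
    alphas_bar_sqrt = alphas.cumprod(0).sqrt()

    # Shift so sqrt(alpha_bar_T) becomes 0, then scale so sqrt(alpha_bar_0) is unchanged.
    a0 = alphas_bar_sqrt[0].clone()
    aT = alphas_bar_sqrt[-1].clone()
    alphas_bar_sqrt = (alphas_bar_sqrt - aT) * a0 / (a0 - aT)

    # Recover per-step alphas from the cumulative product, then betas.
    alphas_bar = alphas_bar_sqrt ** 2
    alphas = torch.cat([alphas_bar[0:1], alphas_bar[1:] / alphas_bar[:-1]])
    return 1.0 - alphas
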
@@ -588,7 +589,8 @@ def main(args):
                                       batch_size=max(1,args.batch_size//2),
                                       default_sample_steps=args.sample_steps,
                                       use_xformers=is_xformers_available() and not args.disable_xformers,
-                                      use_penultimate_clip_layer=(args.clip_skip >= 2)
+                                      use_penultimate_clip_layer=(args.clip_skip >= 2),
+                                      guidance_rescale = 0.7 if args.zero_frequency_noise_ratio == -1 else 0
                                       )
 
     """
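
An annotation on the hard-coded 0.7: this is the guidance rescale factor recommended by the same Lin et al. 2023 paper to counter the over-exposure that classifier-free guidance causes on zero-terminal-SNR schedules. It is only applied when the ZFN == -1 backdoor is active; a value of 0 leaves sampling unchanged.
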
@@ -727,7 +729,7 @@ def main(args):
                 text_encoder=text_encoder,
                 tokenizer=tokenizer,
                 vae=vae,
-                diffusers_scheduler_config=reference_scheduler.config
+                diffusers_scheduler_config=inference_scheduler.config,
                 ).to(device)
             sample_generator.generate_samples(inference_pipe, global_step)
 
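
The sample pipeline receives only the scheduler's config dict, not the scheduler object, so it can build a private copy. A minimal sketch of that round trip using the diffusers API (names here are illustrative; create_inference_pipe's internals are not shown in this diff):

from diffusers import DDIMScheduler

# Load the inference scheduler once, then clone it from its config wherever needed.
inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
scheduler_copy = DDIMScheduler.from_config(inference_scheduler.config)
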
@@ -844,12 +846,12 @@ def main(args):
                         last_epoch_saved_time = time.time()
                         logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
                         save_path = make_save_path(epoch, global_step)
-                        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
+                        __save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
 
                     if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
                         logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
                         save_path = make_save_path(epoch, global_step)
-                        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
+                        __save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
 
                     plugin_runner.run_on_step_end(epoch=epoch,
                                                   global_step=global_step,
@@ -888,7 +890,7 @@ def main(args):
         # end of training
         epoch = args.max_epochs
         save_path = make_save_path(epoch, global_step, prepend=("" if args.no_prepend_last else "last-"))
-        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
+        __save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
 
         total_elapsed_time = time.time() - training_start_time
         logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
@@ -898,7 +900,7 @@ def main(args):
     except Exception as ex:
         logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
         save_path = make_save_path(epoch, global_step, prepend="errored-")
-        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
+        __save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
         logging.info(f"{Fore.LIGHTYELLOW_EX}Model saved, re-raising exception and exiting. Exception was:{Style.RESET_ALL}{Fore.LIGHTRED_EX} {ex} {Style.RESET_ALL}")
         raise ex
 
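
Taken together, the four __save_model changes mean every checkpoint (periodic, end-of-epoch, final, and errored) now embeds the untouched DDIM inference scheduler rather than the DDPM training scheduler, whose betas may have been rescaled by the ZTSNR path above.
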
@@ -87,7 +87,8 @@ class SampleGenerator:
                  default_seed: int,
                  default_sample_steps: int,
                  use_xformers: bool,
-                 use_penultimate_clip_layer: bool):
+                 use_penultimate_clip_layer: bool,
+                 guidance_rescale: float = 0):
         self.log_folder = log_folder
         self.log_writer = log_writer
         self.batch_size = batch_size
@@ -96,6 +97,7 @@ class SampleGenerator:
         self.show_progress_bars = False
         self.generate_pretrain_samples = False
         self.use_penultimate_clip_layer = use_penultimate_clip_layer
+        self.guidance_rescale = guidance_rescale
 
         self.default_resolution = default_resolution
         self.default_seed = default_seed
@@ -228,6 +230,7 @@ class SampleGenerator:
                               generator=generators,
                               width=size[0],
                               height=size[1],
+                              guidance_rescale=self.guidance_rescale
                               ).images
 
         for image in images:
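
The guidance_rescale keyword added to the pipe call above is forwarded to the diffusers pipeline, which applies the correction internally when the value is greater than zero. A minimal sketch of the rescaling step as described by Lin et al. 2023, section 3.4 (diffusers ships an equivalent helper; exact internals depend on the installed version):

import torch

def rescale_noise_cfg(noise_cfg: torch.Tensor, noise_pred_text: torch.Tensor,
                      guidance_rescale: float = 0.0) -> torch.Tensor:
    # Match the per-sample std of the guided prediction to the text-conditioned
    # one, then blend, to fix CFG over-exposure on ZTSNR schedules.
    dims = list(range(1, noise_pred_text.ndim))
    std_text = noise_pred_text.std(dim=dims, keepdim=True)
    std_cfg = noise_cfg.std(dim=dims, keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg
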