1. Improved EMA support: sample generation with EMA/non-EMA options, saving checkpoints and diffusers models for both, and ema_decay_target implemented.

2. enable_zero_terminal_snr separated from zero_frequency_noise_ratio.
alexds9 2023-09-06 22:37:10 +03:00
parent 23df727a1f
commit 5bcf9407f0
2 changed files with 220 additions and 99 deletions

train.py

@ -358,17 +358,20 @@ def log_args(log_writer, args):
log_writer.add_text("config", arglog)
def update_ema(model, ema_model, decay, default_device):
# TODO: handle Unet/TE not trained
with torch.no_grad():
original_model_on_same_device = model
original_model_on_proper_device = model
need_to_delete_original = False
if args.ema_decay_device != default_device:
original_model_on_other_device = deepcopy(model)
original_model_on_same_device = original_model_on_other_device.to(args.ema_decay_device, dtype=model.dtype)
original_model_on_proper_device = original_model_on_other_device.to(args.ema_decay_device, dtype=model.dtype)
del original_model_on_other_device
need_to_delete_original = True
params = dict(original_model_on_same_device.named_parameters())
# original_model_on_proper_device = type(model)().to(args.ema_decay_device, dtype=model.dtype)
# original_model_on_proper_device.load_state_dict(model.state_dict())
params = dict(original_model_on_proper_device.named_parameters())
ema_params = dict(ema_model.named_parameters())
for name in ema_params:
@ -376,7 +379,7 @@ def update_ema(model, ema_model, decay, default_device):
ema_params[name].data = ema_params[name] * decay + params[name].data * (1.0 - decay)
if need_to_delete_original:
del(original_model_on_same_device)
del(original_model_on_proper_device)
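The update applied by update_ema above is a plain exponential moving average of the parameters. A minimal self-contained sketch of the same rule (a hypothetical standalone helper, not the project's function, which additionally handles the cross-device copy shown above):

```python
import torch
import torch.nn as nn

@torch.no_grad()
def ema_update(model: nn.Module, ema_model: nn.Module, decay: float) -> None:
    # new_ema = decay * old_ema + (1 - decay) * current; decay close to 1 moves the EMA slowly.
    ema_params = dict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        ema_params[name].data.mul_(decay).add_(param.data.to(ema_params[name].device), alpha=1.0 - decay)

model, ema_model = nn.Linear(4, 4), nn.Linear(4, 4)
ema_model.load_state_dict(model.state_dict())
ema_update(model, ema_model, decay=0.999)
```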
def compute_snr(timesteps, noise_scheduler):
"""
@ -415,6 +418,15 @@ def compute_snr(timesteps, noise_scheduler):
snr[snr == 0] = minimal_value
return snr
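For reference, the --min_snr_gamma option (see the argparse options later in this diff) typically turns the per-timestep SNR from compute_snr into a loss weight, as in https://arxiv.org/abs/2303.09556. A hedged sketch of that weighting; the commit's own loss-weighting code is not shown in this hunk:

```python
import torch

def min_snr_loss_weights(snr: torch.Tensor, gamma: float) -> torch.Tensor:
    # min(SNR, gamma) / SNR: down-weight easy (high-SNR) timesteps, leave the rest at 1.
    return torch.clamp(torch.full_like(snr, gamma) / snr, max=1.0)

snr = torch.tensor([0.1, 1.0, 5.0, 50.0])
print(min_snr_loss_weights(snr, gamma=5.0))  # tensor([1.0000, 1.0000, 1.0000, 0.1000])
```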
def load_train_json_from_file(args):
try:
print(f"Loading training config from {args.config}.")
with open(args.config, 'rt') as f:
read_json = json.load(f)
args.__dict__.update(read_json)
except Exception as config_read:
print(f"Error on loading training config from {args.config}.")
def main(args):
"""
@ -452,40 +464,22 @@ def main(args):
@torch.no_grad()
def __save_model(save_path, tokenizer, scheduler, vae, ed_optimizer, save_ckpt_dir, yaml_name,
save_full_precision=False, save_optimizer_flag=False, save_ckpt=True):
nonlocal unet
nonlocal text_encoder
nonlocal unet_ema
nonlocal text_encoder_ema
"""
Save the model to disk
"""
global global_step
if global_step is None or global_step == 0:
logging.warning(" No model to save, something likely blew up on startup, not saving")
return
logging.info(f" * Saving diffusers model to {save_path}")
if args.ema_decay_rate != None:
pipeline = StableDiffusionPipeline(
vae=vae,
text_encoder=text_encoder_ema,
tokenizer=tokenizer,
unet=unet_ema,
scheduler=scheduler,
safety_checker=None, # save vram
requires_safety_checker=None, # avoid nag
feature_extractor=None, # must be none of no safety checker
)
else:
pipeline = StableDiffusionPipeline(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None, # save vram
requires_safety_checker=None, # avoid nag
feature_extractor=None, # must be none of no safety checker
)
pipeline.save_pretrained(save_path)
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
if save_ckpt:
def save_ckpt_file(diffusers_model_path, sd_ckpt_path):
nonlocal save_ckpt_dir
nonlocal save_full_precision
#nonlocal save_path
nonlocal yaml_name
if save_ckpt_dir is not None:
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
else:
@ -495,17 +489,73 @@ def main(args):
half = not save_full_precision
logging.info(f" * Saving SD model to {sd_ckpt_full}")
converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)
converter(model_path=diffusers_model_path, checkpoint_path=sd_ckpt_full, half=half)
if yaml_name and yaml_name != "v1-inference.yaml":
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(save_path))}.yaml"
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(diffusers_model_path))}.yaml"
logging.info(f" * Saving yaml to {yaml_save_path}")
shutil.copyfile(yaml_name, yaml_save_path)
global global_step
if global_step is None or global_step == 0:
logging.warning(" No model to save, something likely blew up on startup, not saving")
return
if args.ema_decay_rate != None:
pipeline_ema = StableDiffusionPipeline(
vae=vae,
text_encoder=text_encoder_ema,
tokenizer=tokenizer,
unet=unet_ema,
scheduler=scheduler,
safety_checker=None, # save vram
requires_safety_checker=None, # avoid nag
feature_extractor=None, # must be none if no safety checker
)
diffusers_model_path = save_path + "_ema"
logging.info(f" * Saving diffusers EMA model to {diffusers_model_path}")
pipeline_ema.save_pretrained(diffusers_model_path)
if save_ckpt:
sd_ckpt_path_ema = f"{os.path.basename(save_path)}_ema.ckpt"
save_ckpt_file(diffusers_model_path, sd_ckpt_path_ema)
pipeline = StableDiffusionPipeline(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None, # save vram
requires_safety_checker=None, # avoid nag
feature_extractor=None, # must be none if no safety checker
)
diffusers_model_path = save_path
logging.info(f" * Saving diffusers model to {diffusers_model_path}")
pipeline.save_pretrained(diffusers_model_path)
if save_ckpt:
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
save_ckpt_file(diffusers_model_path, sd_ckpt_path)
if save_optimizer_flag:
logging.info(f" Saving optimizer state to {save_path}")
ed_optimizer.save(save_path)
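With EMA enabled, the refactored __save_model above writes two diffusers folders and (optionally) two checkpoints. An illustration of the resulting names, using a made-up save_path:

```python
import os

save_path = "logs/myproject/ckpts/myproject-ep01-gs00500"  # hypothetical example
print(save_path + "_ema")                          # diffusers EMA model folder
print(f"{os.path.basename(save_path)}_ema.ckpt")   # myproject-ep01-gs00500_ema.ckpt
print(save_path)                                   # diffusers (non-EMA) model folder
print(f"{os.path.basename(save_path)}.ckpt")       # myproject-ep01-gs00500.ckpt
```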
optimizer_state_path = None
try:
# check for a local file
@ -530,8 +580,8 @@ def main(args):
unet = pipe.unet
del pipe
if args.zero_frequency_noise_ratio == -1.0:
# use zero terminal SNR, currently backdoor way to enable it by setting ZFN to -1, still in testing
if args.enable_zero_terminal_snr:
# Use zero terminal SNR
from utils.unet_utils import enforce_zero_terminal_snr
temp_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
trained_betas = enforce_zero_terminal_snr(temp_scheduler.betas).numpy().tolist()
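enforce_zero_terminal_snr rescales the beta schedule so the final timestep carries zero signal. A sketch of the standard rescaling from Lin et al. 2023 ("Common Diffusion Noise Schedules and Sample Steps Are Flawed"); the project's utils.unet_utils implementation may differ in detail:

```python
import torch

def rescale_betas_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    alphas_bar_sqrt = (1.0 - betas).cumprod(dim=0).sqrt()

    # Shift and scale sqrt(alpha_bar) so the last value is exactly 0 (terminal SNR = 0).
    a0, aT = alphas_bar_sqrt[0].clone(), alphas_bar_sqrt[-1].clone()
    alphas_bar_sqrt = (alphas_bar_sqrt - aT) * a0 / (a0 - aT)

    # Convert back to per-step betas.
    alphas_bar = alphas_bar_sqrt ** 2
    alphas = torch.cat([alphas_bar[0:1], alphas_bar[1:] / alphas_bar[:-1]])
    return 1.0 - alphas
```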
@ -576,18 +626,28 @@ def main(args):
text_encoder = text_encoder.to(device, dtype=torch.float32)
# TODO: handle Unet/TE not trained
if args.ema_decay_rate != None:
if args.ema_decay_device == device:
unet_ema = deepcopy(unet)
text_encoder_ema = deepcopy(text_encoder)
else:
unet_ema_first = deepcopy(unet)
text_encoder_ema_first = deepcopy(text_encoder)
unet_ema = unet_ema_first.to(args.ema_decay_device, dtype=unet.dtype)
text_encoder_ema = text_encoder_ema_first.to(args.ema_decay_device, dtype=text_encoder.dtype)
del unet_ema_first
del text_encoder_ema_first
use_ema_dacay_training = (args.ema_decay_rate != None) or (args.ema_decay_target != None)
if use_ema_dacay_training:
logging.info(f"EMA decay enabled, creating EMA model.")
with torch.no_grad():
if args.ema_decay_device == device:
unet_ema = deepcopy(unet)
text_encoder_ema = deepcopy(text_encoder)
else:
# unet_ema = type(unet)().to(args.ema_decay_device, dtype=unet.dtype)
# unet_ema.load_state_dict(unet.state_dict())
#
# text_encoder_ema = type(unet)().to(args.ema_decay_device, dtype=text_encoder.dtype)
# text_encoder_ema.load_state_dict(text_encoder.state_dict())
unet_ema_first = deepcopy(unet)
text_encoder_ema_first = deepcopy(text_encoder)
unet_ema = unet_ema_first.to(args.ema_decay_device, dtype=unet.dtype)
text_encoder_ema = text_encoder_ema_first.to(args.ema_decay_device, dtype=text_encoder.dtype)
del unet_ema_first
del text_encoder_ema_first
try:
@ -661,6 +721,20 @@ def main(args):
epoch_len = math.ceil(len(train_batch) / args.batch_size)
if use_ema_dacay_training:
if args.ema_decay_target != None:
# ema_decay_target is a value in range (0,1). ema_decay_rate is calculated from the equation: decay_rate^(total_steps/decay_interval) = decay_target. Setting this parameter enables the feature and overrides ema_decay_rate.
total_number_of_steps: float = epoch_len * args.max_epochs
total_number_of_ema_update: float = total_number_of_steps / args.ema_decay_interval
args.ema_decay_rate = args.ema_decay_target ** (1 / total_number_of_ema_update)
logging.info(f"ema_decay_target is {args.ema_decay_target}, calculated ema_decay_rate will be: {args.ema_decay_rate}.")
logging.info(
f"EMA decay enabled, with ema_decay_rate {args.ema_decay_rate}, ema_decay_interval: {args.ema_decay_interval}, ema_decay_device: {args.ema_decay_device}.")
ed_optimizer = EveryDreamOptimizer(args,
optimizer_config,
text_encoder,
@ -676,7 +750,7 @@ def main(args):
default_sample_steps=args.sample_steps,
use_xformers=is_xformers_available() and not args.disable_xformers,
use_penultimate_clip_layer=(args.clip_skip >= 2),
guidance_rescale = 0.7 if args.zero_frequency_noise_ratio == -1 else 0
guidance_rescale=0.7 if args.enable_zero_terminal_snr else 0
)
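guidance_rescale=0.7 asks the sampling pipeline to rescale classifier-free guidance when zero terminal SNR is enabled, to counter the over-exposure that strong CFG causes on such schedules. A sketch of the rescaling formula (essentially what diffusers applies internally; illustrative, not the pipeline's exact code):

```python
import torch

def rescale_noise_cfg(noise_cfg: torch.Tensor, noise_pred_text: torch.Tensor, guidance_rescale: float = 0.7) -> torch.Tensor:
    # Match the std of the guided prediction to the text-conditioned one, then blend.
    dims = list(range(1, noise_pred_text.ndim))
    std_text = noise_pred_text.std(dim=dims, keepdim=True)
    std_cfg = noise_cfg.std(dim=dims, keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg
```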
"""
@ -766,12 +840,13 @@ def main(args):
del pixel_values
latents = latents[0].sample() * 0.18215
if zero_frequency_noise_ratio > 0.0:
if zero_frequency_noise_ratio != None:
if zero_frequency_noise_ratio < 0:
zero_frequency_noise_ratio = 0
# see https://www.crosslabs.org//blog/diffusion-with-offset-noise
zero_frequency_noise = zero_frequency_noise_ratio * torch.randn(latents.shape[0], latents.shape[1], 1, 1, device=latents.device)
noise = torch.randn_like(latents) + zero_frequency_noise
else:
noise = torch.randn_like(latents)
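The block above implements offset ("zero frequency") noise: one extra scalar of noise per sample and channel, broadcast across the spatial dimensions so the model learns to shift overall brightness. A minimal illustration with made-up latent shapes:

```python
import torch

latents = torch.randn(4, 4, 64, 64)   # (batch, channels, h, w), hypothetical SD latents
zero_frequency_noise_ratio = 0.02

offset = zero_frequency_noise_ratio * torch.randn(latents.shape[0], latents.shape[1], 1, 1)
noise = torch.randn_like(latents) + offset  # shifts the mean of each channel's noise
```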
bsz = latents.shape[0]
@ -826,6 +901,11 @@ def main(args):
return model_pred, target
def generate_samples(global_step: int, batch):
nonlocal unet
nonlocal text_encoder
nonlocal unet_ema
nonlocal text_encoder_ema
with isolate_rng():
prev_sample_steps = sample_generator.sample_steps
sample_generator.reload_config()
@ -834,20 +914,71 @@ def main(args):
print(f" * SampleGenerator config changed, now generating images samples every " +
f"{sample_generator.sample_steps} training steps (next={next_sample_step})")
sample_generator.update_random_captions(batch["captions"])
inference_pipe = sample_generator.create_inference_pipe(unet=unet,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
diffusers_scheduler_config=inference_scheduler.config
).to(device)
sample_generator.generate_samples(inference_pipe, global_step)
del inference_pipe
gc.collect()
models_info = []
if (args.ema_decay_rate is None) or args.ema_decay_sample_raw_training:
models_info.append({"is_ema": False, "swap_required": False})
if (args.ema_decay_rate is not None) and args.ema_decay_sample_ema_model:
models_info.append({"is_ema": True, "swap_required": args.ema_decay_device != device})
for model_info in models_info:
extra_info: str = ""
if model_info["is_ema"]:
current_unet, current_text_encoder = unet_ema, text_encoder_ema
extra_info = "ema_"
else:
current_unet, current_text_encoder = unet, text_encoder
if model_info["swap_required"]:
with torch.no_grad():
unet_unloaded = unet.to(args.ema_decay_device)
del unet
text_encoder_unloaded = text_encoder.to(args.ema_decay_device)
del text_encoder
current_unet = unet_ema.to(device)
del unet_ema
current_text_encoder = text_encoder_ema.to(device)
del text_encoder_ema
gc.collect()
inference_pipe = sample_generator.create_inference_pipe(unet=current_unet,
text_encoder=current_text_encoder,
tokenizer=tokenizer,
vae=vae,
diffusers_scheduler_config=inference_scheduler.config
).to(device)
sample_generator.generate_samples(inference_pipe, global_step, extra_info=extra_info)
# Cleanup
del inference_pipe
if model_info["swap_required"]:
with torch.no_grad():
unet = unet_unloaded.to(device)
del unet_unloaded
text_encoder = text_encoder_unloaded.to(device)
del text_encoder_unloaded
unet_ema = current_unet.to(args.ema_decay_device)
del current_unet
text_encoder_ema = current_text_encoder.to(args.ema_decay_device)
del current_text_encoder
gc.collect()
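The sampling loop above decides which model variants to render. A condensed sketch of the intended selection (flag names follow the new argparse options below; assumes the store_true flags are checked for truthiness):

```python
def models_to_sample(ema_enabled: bool, sample_raw_training: bool, sample_ema_model: bool) -> list[str]:
    selected = []
    if (not ema_enabled) or sample_raw_training:
        selected.append("training")   # raw (non-EMA) weights
    if ema_enabled and sample_ema_model:
        selected.append("ema")        # EMA weights, prefixed "ema_" in sample filenames
    return selected

print(models_to_sample(True, True, True))     # ['training', 'ema']
print(models_to_sample(True, False, True))    # ['ema']
print(models_to_sample(False, False, False))  # ['training']
```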
def make_save_path(epoch, global_step, prepend=""):
return os.path.join(f"{log_folder}/ckpts/{prepend}{args.project_name}-ep{epoch:02}-gs{global_step:05}")
# Pre-train validation to establish a starting point on the loss graph
if validator:
validator.do_validation(global_step=0,
@ -874,14 +1005,7 @@ def main(args):
for epoch in range(args.max_epochs):
if args.load_settings_every_epoch:
try:
print(f"Loading training config from {args.config}.")
with open(args.config, 'rt') as f:
read_json = json.load(f)
args.__dict__.update(read_json)
except Exception as config_read:
print(f"Error on loading training config from {args.config}.")
load_train_json_from_file(args)
plugin_runner.run_on_epoch_start(epoch=epoch,
@ -923,11 +1047,14 @@ def main(args):
if args.ema_decay_rate != None:
if ((global_step + 1) % args.ema_decay_interval) == 0:
debug_start_time = time.time()
# TODO: handle Unet/TE not trained
# TODO: Remove time measurement
update_ema(unet, unet_ema, args.ema_decay_rate, device)
update_ema(text_encoder, text_encoder_ema, args.ema_decay_rate, device)
debug_start_time = time.time() # TODO: Remove time measurement when debug done
if args.disable_unet_training != True:
update_ema(unet, unet_ema, args.ema_decay_rate, default_device=device)
if args.disable_textenc_training != True:
update_ema(text_encoder, text_encoder_ema, args.ema_decay_rate, default_device=device)
debug_end_time = time.time()
debug_elapsed_time = debug_end_time - debug_start_time
print(f"Command update_EMA unet and TE took {debug_elapsed_time:.3f} seconds.")
@ -1053,14 +1180,7 @@ if __name__ == "__main__":
argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
args, argv = argparser.parse_known_args()
if args.config is not None:
print(f"Loading training config from {args.config}.")
with open(args.config, 'rt') as f:
args.__dict__.update(json.load(f))
if len(argv) > 0:
print(f"Config .json loaded but there are additional CLI arguments -- these will override values in {args.config}.")
else:
print("No config file specified, using command line args")
load_train_json_from_file(args)
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP")
@ -1109,15 +1229,16 @@ if __name__ == "__main__":
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.02, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15, set to -1 to use zero terminal SNR noising beta schedule instead")
argparser.add_argument("--ema_decay_rate", type=float, default=None, help="EMA decay rate. EMA model will be update by (1 - ema_decay_rate) this value every interval. Values less than 1 and not so far from 1. Using this parameter will enable the feature.")
argparser.add_argument("--ema_decay_target", type=float, default=None, help="Value in range (0,1). decay_rate will be calculated from equation: decay_rate^(total_steps/decay_interval)=decay_target. Using this parameter will enable the feature and overide decay_target.")
argparser.add_argument("--ema_decay_interval", type=int, default=500, help="How many steps between updating EMA decay. EMA model will be update on every global_steps modulo decay_interval.")
argparser.add_argument("--ema_decay_device", type=str, default='cpu', help="EMA decay device values: cpu, cuda. Using CPU is taking up to 4 seconds vs fraction of a second on CUDA, using CUDA will reserve around 3.2GB Vram for model.")
argparser.add_argument("--ema_decay_sample_raw_model", action="store_true", default=False, help="Will show samples from training model, just like regular training. Can be used with: --ema_decay_sample_ema_model")
argparser.add_argument("--ema_decay_sample_ema_model", action="store_true", default=False, help="Will show samples from EMA model. Can be used with: --ema_decay_sample_raw_model")
argparser.add_argument("--min_snr_gamma", type=int, default=None, help="min-SNR-gamma parameteris the loss function into individual tasks. recommended values: 5, 1, 20. More info: https://arxiv.org/abs/2303.09556")
argparser.add_argument("--load_settings_every_epoch", action="store_true", help="Will load 'train.json' at start of every epoch.")
argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.02, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15")
argparser.add_argument("--enable_zero_terminal_snr", action="store_true", default=None, help="Use zero terminal SNR noising beta schedule")
argparser.add_argument("--ema_decay_rate", type=float, default=None, help="EMA decay rate. EMA model will be updated with (1 - ema_decay_rate) from training, and the ema_decay_rate from previous EMA, every interval. Values less than 1 and not so far from 1. Using this parameter will enable the feature.")
argparser.add_argument("--ema_decay_target", type=float, default=None, help="EMA decay target value in range (0,1). ema_decay_rate will be calculated from equation: decay_rate^(total_steps/decay_interval)=decay_target. Using this parameter will enable the feature and overide decay_target.")
argparser.add_argument("--ema_decay_interval", type=int, default=500, help="How many steps between every EMA decay update. EMA model will be update on every global_steps modulo decay_interval.")
argparser.add_argument("--ema_decay_device", type=str, default='cpu', help="EMA decay device values: cpu, cuda. Using 'cpu' is taking around 4 seconds per update vs fraction of a second on 'cuda'. Using 'cuda' will reserve around 3.2GB VRAM for a model, with 'cpu' RAM will be used.")
argparser.add_argument("--ema_decay_sample_raw_training", action="store_true", default=False, help="Will show samples from trained model, just like regular training. Can be used with: --ema_decay_sample_ema_model")
argparser.add_argument("--ema_decay_sample_ema_model", action="store_true", default=False, help="Will show samples from EMA model. Can be used with: --ema_decay_sample_raw_training")
argparser.add_argument("--min_snr_gamma", type=int, default=None, help="min-SNR-gamma parameteris the loss function into individual tasks. Recommended values: 5, 1, 20. Disabled by default and enabled when used. More info: https://arxiv.org/abs/2303.09556")
argparser.add_argument("--load_settings_every_epoch", action="store_true", default=None, help="Will load 'train.json' at start of every epoch. Disabled by default and enabled when used.")
# load CLI args to overwrite existing config args
args = argparser.parse_args(args=argv, namespace=args)
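A hypothetical set of CLI overrides showing how the new options combine (values are illustrative only): enable EMA via a decay target, update every 500 steps, keep the EMA copy on cpu, and sample both the training and EMA models.

```python
example_argv = [
    "--ema_decay_target", "0.75",
    "--ema_decay_interval", "500",
    "--ema_decay_device", "cpu",
    "--ema_decay_sample_raw_training",
    "--ema_decay_sample_ema_model",
    "--enable_zero_terminal_snr",
]
# args = argparser.parse_args(args=example_argv, namespace=args)
```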

View File

@ -181,7 +181,7 @@ class SampleGenerator:
self.sample_requests = self._make_random_caption_sample_requests()
@torch.no_grad()
def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int):
def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int, extra_info: str = ""):
"""
generates samples at different cfg scales and saves them to disk
"""
@ -269,15 +269,15 @@ class SampleGenerator:
prompt = prompts[prompt_idx]
clean_prompt = clean_filename(prompt)
result.save(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
with open(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
result.save(f"{self.log_folder}/samples/{extra_info}gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
with open(f"{self.log_folder}/samples/{extra_info}gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
f.write(str(batch[prompt_idx]))
tfimage = transforms.ToTensor()(result)
if batch[prompt_idx].wants_random_caption:
self.log_writer.add_image(tag=f"sample_{sample_index}", img_tensor=tfimage, global_step=global_step)
self.log_writer.add_image(tag=f"{extra_info}sample_{sample_index}", img_tensor=tfimage, global_step=global_step)
else:
self.log_writer.add_image(tag=f"sample_{sample_index}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step)
self.log_writer.add_image(tag=f"{extra_info}sample_{sample_index}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step)
sample_index += 1
del result
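The extra_info prefix lets EMA and non-EMA samples coexist in the same samples folder; for example, with hypothetical values:

```python
log_folder, extra_info = "logs/run1", "ema_"
global_step, sample_index, clean_prompt = 500, 0, "a photo of a cat"
print(f"{log_folder}/samples/{extra_info}gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.jpg")
# logs/run1/samples/ema_gs00500-0-a photo of a cat.jpg
```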