Added support for:

1. EMA decay. An EMA copy of the model is maintained and updated every ema_decay_interval steps by blending in (1 - ema_decay_rate) of the current weights (see the sketch after this list); it can be kept on the CPU to save VRAM. When EMA is enabled, only the EMA model is saved.
2. min_snr_gamma - improves convergence speed; more info: https://arxiv.org/abs/2303.09556
3. load_settings_every_epoch - reloads 'train.json' at the start of every epoch.
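For reference, a minimal sketch of the EMA update rule described in item 1, assuming plain parameter tensors (the trainer's actual implementation is the update_ema() function in the diff below):

import torch

@torch.no_grad()
def ema_update(ema_param: torch.Tensor, train_param: torch.Tensor, decay: float) -> None:
    # keep `decay` of the EMA weight, blend in (1 - decay) of the current training weight
    ema_param.mul_(decay).add_(train_param, alpha=1.0 - decay)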
alexds9 2023-09-06 13:38:52 +03:00
parent a7f94e96e0
commit 23df727a1f
1 changed file with 174 additions and 29 deletions

train.py

@@ -61,6 +61,7 @@ from utils.convert_diff_to_ckpt import convert as converter
from utils.isolate_rng import isolate_rng
from utils.check_git import check_git
from optimizer.optimizers import EveryDreamOptimizer
from copy import deepcopy
if torch.cuda.is_available():
from utils.gpu import GPU
@@ -356,6 +357,64 @@ def log_args(log_writer, args):
arglog += f"{arg}={value}, "
log_writer.add_text("config", arglog)
def update_ema(model, ema_model, decay, default_device):
# TODO: handle Unet/TE not trained
with torch.no_grad():
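# If the EMA model lives on a different device than training, first copy the training weights to the EMA device so the blend runs there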
original_model_on_same_device = model
need_to_delete_original = False
if args.ema_decay_device != default_device:
original_model_on_other_device = deepcopy(model)
original_model_on_same_device = original_model_on_other_device.to(args.ema_decay_device, dtype=model.dtype)
del original_model_on_other_device
need_to_delete_original = True
params = dict(original_model_on_same_device.named_parameters())
ema_params = dict(ema_model.named_parameters())
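# Blend each parameter: new_ema = decay * ema + (1 - decay) * current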
for name in ema_params:
#ema_params[name].data.mul_(decay).add_(params[name].data, alpha=1 - decay)
ema_params[name].data = ema_params[name] * decay + params[name].data * (1.0 - decay)
if need_to_delete_original:
del original_model_on_same_device
def compute_snr(timesteps, noise_scheduler):
"""
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
"""
minimal_value = 1e-9
alphas_cumprod = noise_scheduler.alphas_cumprod
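# SNR(t) = alphas_cumprod[t] / (1 - alphas_cumprod[t]), computed below from its square-root factors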
# Use .any() to check if any elements in the tensor are zero
if (alphas_cumprod[:-1] == 0).any():
logging.warning(
f"Alphas cumprod has zero elements! Resetting to {minimal_value}.."
)
alphas_cumprod[:-1][alphas_cumprod[:-1] == 0] = minimal_value
sqrt_alphas_cumprod = alphas_cumprod**0.5
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
# Expand the tensors.
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
timesteps
].float()
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
device=timesteps.device
)[timesteps].float()
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
# Compute SNR, first without epsilon
snr = (alpha / sigma) ** 2
# Replace any zero SNR values with the minimal value to avoid division by zero
if torch.any(snr == 0):
snr[snr == 0] = minimal_value
return snr
def main(args):
"""
@@ -384,13 +443,14 @@ def main(args):
device = 'cpu'
gpu = None
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
if not os.path.exists(log_folder):
os.makedirs(log_folder)
@torch.no_grad()
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, ed_optimizer, save_ckpt_dir, yaml_name,
def __save_model(save_path, tokenizer, scheduler, vae, ed_optimizer, save_ckpt_dir, yaml_name,
save_full_precision=False, save_optimizer_flag=False, save_ckpt=True):
"""
Save the model to disk
@@ -400,16 +460,28 @@ def main(args):
logging.warning(" No model to save, something likely blew up on startup, not saving")
return
logging.info(f" * Saving diffusers model to {save_path}")
pipeline = StableDiffusionPipeline(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None, # save vram
requires_safety_checker=None, # avoid nag
feature_extractor=None, # must be None if no safety checker
)
if args.ema_decay_rate is not None:
pipeline = StableDiffusionPipeline(
vae=vae,
text_encoder=text_encoder_ema,
tokenizer=tokenizer,
unet=unet_ema,
scheduler=scheduler,
safety_checker=None, # save vram
requires_safety_checker=None, # avoid nag
feature_extractor=None, # must be None if no safety checker
)
else:
pipeline = StableDiffusionPipeline(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None, # save vram
requires_safety_checker=None, # avoid nag
feature_extractor=None, # must be None if no safety checker
)
pipeline.save_pretrained(save_path)
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
@@ -458,16 +530,15 @@ def main(args):
unet = pipe.unet
del pipe
# leave the inference scheduler alone
inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
if args.zero_frequency_noise_ratio == -1.0:
# use zero terminal SNR, currently backdoor way to enable it by setting ZFN to -1, still in testing
from utils.unet_utils import enforce_zero_terminal_snr
temp_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
trained_betas = enforce_zero_terminal_snr(temp_scheduler.betas).numpy().tolist()
inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
else:
inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(model_root_folder, subfolder="tokenizer", use_fast=False)
@@ -504,6 +575,21 @@ def main(args):
else:
text_encoder = text_encoder.to(device, dtype=torch.float32)
# TODO: handle Unet/TE not trained
if args.ema_decay_rate is not None:
if args.ema_decay_device == device:
unet_ema = deepcopy(unet)
text_encoder_ema = deepcopy(text_encoder)
else:
unet_ema_first = deepcopy(unet)
text_encoder_ema_first = deepcopy(text_encoder)
unet_ema = unet_ema_first.to(args.ema_decay_device, dtype=unet.dtype)
text_encoder_ema = text_encoder_ema_first.to(args.ema_decay_device, dtype=text_encoder.dtype)
del unet_ema_first
del text_encoder_ema_first
try:
#unet = torch.compile(unet)
#text_encoder = torch.compile(text_encoder)
@@ -622,7 +708,9 @@ def main(args):
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
time.sleep(2) # give opportunity to ctrl-C again to cancel save
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
__save_model(interrupted_checkpoint_path, tokenizer, noise_scheduler, vae,
ed_optimizer, args.save_ckpt_dir, args.save_full_precision, args.save_optimizer,
save_ckpt=not args.no_save_ckpt)
exit(_SIGTERM_EXIT_CODE)
else:
# non-main threads (i.e. dataloader workers) should exit cleanly
@@ -670,7 +758,7 @@ def main(args):
assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
# actual prediction function - shared between train and validate
def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0):
def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0, return_loss=False):
with torch.no_grad():
with autocast(enabled=args.amp):
pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device)
@@ -714,7 +802,28 @@ def main(args):
#print(f"types: {type(noisy_latents)} {type(timesteps)} {type(encoder_hidden_states)}")
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
return model_pred, target
if return_loss:
if args.min_snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
snr = compute_snr(timesteps, noise_scheduler)
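# Min-SNR weighting (https://arxiv.org/abs/2303.09556): weight each sample by min(SNR, gamma) / SNR so low-noise (high-SNR) timesteps do not dominate the loss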
mse_loss_weights = (
torch.stack(
[snr, args.min_snr_gamma * torch.ones_like(timesteps)], dim=1
).min(dim=1)[0]
/ snr
)
mse_loss_weights[snr == 0] = 1.0
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
return model_pred, target, loss
else:
return model_pred, target
def generate_samples(global_step: int, batch):
with isolate_rng():
@@ -729,7 +838,7 @@ def main(args):
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
diffusers_scheduler_config=inference_scheduler.config,
diffusers_scheduler_config=inference_scheduler.config
).to(device)
sample_generator.generate_samples(inference_pipe, global_step)
@@ -755,7 +864,7 @@ def main(args):
else:
logging.info("No plugins specified")
plugins = []
from plugins.plugins import PluginRunner
plugin_runner = PluginRunner(plugins=plugins)
@@ -763,6 +872,18 @@ def main(args):
write_batch_schedule(args, log_folder, train_batch, epoch = 0)
for epoch in range(args.max_epochs):
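# Optionally re-read the JSON config at the start of each epoch so settings can be adjusted mid-training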
if args.load_settings_every_epoch:
try:
print(f"Loading training config from {args.config}.")
with open(args.config, 'rt') as f:
read_json = json.load(f)
args.__dict__.update(read_json)
except Exception as config_read:
print(f"Error loading training config from {args.config}: {config_read}")
plugin_runner.run_on_epoch_start(epoch=epoch,
global_step=global_step,
project_name=args.project_name,
@@ -790,9 +911,7 @@ def main(args):
log_folder=log_folder,
batch=batch)
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
model_pred, target, loss = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio, return_loss=True)
del target, model_pred
@@ -802,6 +921,18 @@ def main(args):
ed_optimizer.step(loss, step, global_step)
if args.ema_decay_rate is not None:
if ((global_step + 1) % args.ema_decay_interval) == 0:
debug_start_time = time.time()
# TODO: handle Unet/TE not trained
# TODO: Remove time measurement
update_ema(unet, unet_ema, args.ema_decay_rate, device)
update_ema(text_encoder, text_encoder_ema, args.ema_decay_rate, device)
debug_end_time = time.time()
debug_elapsed_time = debug_end_time - debug_start_time
print(f"Command update_EMA unet and TE took {debug_elapsed_time:.3f} seconds.")
loss_step = loss.detach().item()
steps_pbar.set_postfix({"loss/step": loss_step}, {"gs": global_step})
@@ -818,7 +949,7 @@ def main(args):
lr_unet = ed_optimizer.get_unet_lr()
lr_textenc = ed_optimizer.get_textenc_lr()
loss_log_step = []
log_writer.add_scalar(tag="hyperparameter/lr unet", scalar_value=lr_unet, global_step=global_step)
log_writer.add_scalar(tag="hyperparameter/lr text encoder", scalar_value=lr_textenc, global_step=global_step)
log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_step, global_step=global_step)
@@ -846,12 +977,16 @@ def main(args):
last_epoch_saved_time = time.time()
logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
save_path = make_save_path(epoch, global_step)
__save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
__save_model(save_path, tokenizer, noise_scheduler, vae, ed_optimizer,
args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer,
save_ckpt=not args.no_save_ckpt)
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
save_path = make_save_path(epoch, global_step)
__save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
__save_model(save_path, tokenizer, noise_scheduler, vae, ed_optimizer,
args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer,
save_ckpt=not args.no_save_ckpt)
plugin_runner.run_on_step_end(epoch=epoch,
global_step=global_step,
@@ -884,13 +1019,14 @@ def main(args):
log_folder=log_folder,
data_root=args.data_root)
gc.collect()
gc.collect()
# end of epoch
# end of training
epoch = args.max_epochs
save_path = make_save_path(epoch, global_step, prepend=("" if args.no_prepend_last else "last-"))
__save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
__save_model(save_path, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir,
yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
total_elapsed_time = time.time() - training_start_time
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
@@ -900,7 +1036,8 @@ def main(args):
except Exception as ex:
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
save_path = make_save_path(epoch, global_step, prepend="errored-")
__save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
__save_model(save_path, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir,
yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
logging.info(f"{Fore.LIGHTYELLOW_EX}Model saved, re-raising exception and exiting. Exception was:{Style.RESET_ALL}{Fore.LIGHTRED_EX} {ex} {Style.RESET_ALL}")
raise ex
@@ -973,8 +1110,16 @@ if __name__ == "__main__":
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.02, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15, set to -1 to use zero terminal SNR noising beta schedule instead")
argparser.add_argument("--ema_decay_rate", type=float, default=None, help="EMA decay rate. EMA model will be update by (1 - ema_decay_rate) this value every interval. Values less than 1 and not so far from 1. Using this parameter will enable the feature.")
argparser.add_argument("--ema_decay_target", type=float, default=None, help="Value in range (0,1). decay_rate will be calculated from equation: decay_rate^(total_steps/decay_interval)=decay_target. Using this parameter will enable the feature and overide decay_target.")
argparser.add_argument("--ema_decay_interval", type=int, default=500, help="How many steps between updating EMA decay. EMA model will be update on every global_steps modulo decay_interval.")
argparser.add_argument("--ema_decay_device", type=str, default='cpu', help="EMA decay device values: cpu, cuda. Using CPU is taking up to 4 seconds vs fraction of a second on CUDA, using CUDA will reserve around 3.2GB Vram for model.")
argparser.add_argument("--ema_decay_sample_raw_model", action="store_true", default=False, help="Will show samples from training model, just like regular training. Can be used with: --ema_decay_sample_ema_model")
argparser.add_argument("--ema_decay_sample_ema_model", action="store_true", default=False, help="Will show samples from EMA model. Can be used with: --ema_decay_sample_raw_model")
argparser.add_argument("--min_snr_gamma", type=int, default=None, help="min-SNR-gamma parameteris the loss function into individual tasks. recommended values: 5, 1, 20. More info: https://arxiv.org/abs/2303.09556")
argparser.add_argument("--load_settings_every_epoch", action="store_true", help="Will load 'train.json' at start of every epoch.")
# load CLI args to overwrite existing config args
args = argparser.parse_args(args=argv, namespace=args)
main(args)