diff --git a/train.json b/train.json index 8791b81..ebd1998 100644 --- a/train.json +++ b/train.json @@ -38,5 +38,6 @@ "wandb": false, "write_schedule": false, "rated_dataset": false, - "rated_dataset_target_dropout_percent": 50 + "rated_dataset_target_dropout_percent": 50, + "zero_frequency_noise_ratio": 0.0 } diff --git a/train.py b/train.py index 8ca8b65..56f02ec 100644 --- a/train.py +++ b/train.py @@ -27,6 +27,7 @@ import gc import random import traceback import shutil +import importlib import torch.nn.functional as F from torch.cuda.amp import autocast, GradScaler @@ -765,7 +766,7 @@ def main(args): assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct" # actual prediction function - shared between train and validate - def get_model_prediction_and_target(image, tokens): + def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0): with torch.no_grad(): with autocast(enabled=args.amp): pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device) @@ -773,7 +774,12 @@ def main(args): del pixel_values latents = latents[0].sample() * 0.18215 - noise = torch.randn_like(latents) + if zero_frequency_noise_ratio > 0.0: + zero_frequency_noise = zero_frequency_noise_ratio * torch.randn(latents.shape[0], latents.shape[1], 1, 1, device=latents.device) + noise = torch.randn_like(latents) + zero_frequency_noise + else: + noise = torch.randn_like(latents) + bsz = latents.shape[0] timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) @@ -839,7 +845,7 @@ def main(args): for step, batch in enumerate(train_dataloader): step_start_time = time.time() - model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"]) + model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio) #del timesteps, encoder_hidden_states, noisy_latents #with autocast(enabled=args.amp): @@ -944,7 +950,7 @@ def main(args): if validator: validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target) - + gc.collect() # end of epoch @@ -1029,6 +1035,7 @@ if __name__ == "__main__": argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)") argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs") argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)") + argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.0, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15") # load CLI args to overwrite existing config args args = argparser.parse_args(args=argv, namespace=args)