Add zero-frequency noise option to improve contrast

This commit is contained in:
Victor Hall 2023-02-15 18:53:08 -05:00
parent 43f7f3c0f1
commit 37cf437a5f
2 changed files with 13 additions and 5 deletions

View File

@@ -38,5 +38,6 @@
"wandb": false,
"write_schedule": false,
"rated_dataset": false,
"rated_dataset_target_dropout_percent": 50
"rated_dataset_target_dropout_percent": 50,
"zero_frequency_noise_ratio": 0.0
}

View File

@@ -27,6 +27,7 @@ import gc
import random
import traceback
import shutil
import importlib
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
@@ -765,7 +766,7 @@ def main(args):
assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
# actual prediction function - shared between train and validate
def get_model_prediction_and_target(image, tokens):
def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0):
with torch.no_grad():
with autocast(enabled=args.amp):
pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device)
@@ -773,7 +774,12 @@ def main(args):
del pixel_values
latents = latents[0].sample() * 0.18215
noise = torch.randn_like(latents)
if zero_frequency_noise_ratio > 0.0:
zero_frequency_noise = zero_frequency_noise_ratio * torch.randn(latents.shape[0], latents.shape[1], 1, 1, device=latents.device)
noise = torch.randn_like(latents) + zero_frequency_noise
else:
noise = torch.randn_like(latents)
bsz = latents.shape[0]
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
@@ -839,7 +845,7 @@ def main(args):
for step, batch in enumerate(train_dataloader):
step_start_time = time.time()
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio)
#del timesteps, encoder_hidden_states, noisy_latents
#with autocast(enabled=args.amp):
@@ -944,7 +950,7 @@ def main(args):
if validator:
validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
gc.collect()
# end of epoch
@@ -1029,6 +1035,7 @@ if __name__ == "__main__":
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.0, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15")
# load CLI args to overwrite existing config args
args = argparser.parse_args(args=argv, namespace=args)