zero frequency noise option to improve contrast

parent 43f7f3c0f1
commit 37cf437a5f
@@ -38,5 +38,6 @@
     "wandb": false,
     "write_schedule": false,
     "rated_dataset": false,
-    "rated_dataset_target_dropout_percent": 50
+    "rated_dataset_target_dropout_percent": 50,
+    "zero_frequency_noise_ratio": 0.0
 }
train.py  (15 changed lines)
@@ -27,6 +27,7 @@ import gc
 import random
 import traceback
 import shutil
+import importlib

 import torch.nn.functional as F
 from torch.cuda.amp import autocast, GradScaler
@@ -765,7 +766,7 @@ def main(args):
     assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"

     # actual prediction function - shared between train and validate
-    def get_model_prediction_and_target(image, tokens):
+    def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0):
         with torch.no_grad():
             with autocast(enabled=args.amp):
                 pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device)
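The new keyword argument only changes how the noise tensor is drawn; the rest of the function follows the usual latent-diffusion epsilon-prediction recipe with the repo's `noise_scheduler` and `unet` objects. The sketch below illustrates that recipe with diffusers-style calls and is an assumption about the surrounding code, not a copy of train.py's body:

    import torch

    def predict_and_target(unet, noise_scheduler, latents, encoder_hidden_states,
                           zero_frequency_noise_ratio=0.0):
        # Illustration only: standard epsilon-prediction step, assuming a diffusers
        # DDPMScheduler-like noise_scheduler and a UNet2DConditionModel-like unet.
        noise = torch.randn_like(latents)
        if zero_frequency_noise_ratio > 0.0:
            # per-sample, per-channel constant offset (the change this commit adds)
            noise = noise + zero_frequency_noise_ratio * torch.randn(
                latents.shape[0], latents.shape[1], 1, 1, device=latents.device)
        bsz = latents.shape[0]
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                                  (bsz,), device=latents.device)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
        # with epsilon prediction the target is the sampled noise; a v-prediction model
        # would use noise_scheduler.get_velocity(latents, noise, timesteps) instead
        return model_pred, noise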
@@ -773,7 +774,12 @@ def main(args):
             del pixel_values
             latents = latents[0].sample() * 0.18215

-            noise = torch.randn_like(latents)
+            if zero_frequency_noise_ratio > 0.0:
+                zero_frequency_noise = zero_frequency_noise_ratio * torch.randn(latents.shape[0], latents.shape[1], 1, 1, device=latents.device)
+                noise = torch.randn_like(latents) + zero_frequency_noise
+            else:
+                noise = torch.randn_like(latents)
+
             bsz = latents.shape[0]

             timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
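Why a constant offset helps contrast: `torch.randn_like` noise has a per-channel mean very close to zero for every image in the batch, so the model is rarely trained on latents whose overall level is shifted, which is the commonly cited reason diffusion models struggle to produce very dark or very bright outputs. The added term makes that per-channel mean itself a random variable with standard deviation equal to the ratio. A small self-contained illustration of the statistics (shapes are arbitrary demo values):

    import torch

    latents = torch.zeros(8, 4, 64, 64)         # dummy NCHW latent batch
    plain = torch.randn_like(latents)            # ordinary per-pixel Gaussian noise
    offset = 0.1 * torch.randn(8, 4, 1, 1)       # zero-frequency component, ratio = 0.1
    shifted = plain + offset                     # broadcasts over H and W

    # spread of the per-(sample, channel) means: the offset widens it from ~0.016 to ~0.1
    print(plain.mean(dim=(2, 3)).std())          # ≈ 1/sqrt(64*64) ≈ 0.016
    print(shifted.mean(dim=(2, 3)).std())        # ≈ 0.1, dominated by the offset term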
@@ -839,7 +845,7 @@ def main(args):
         for step, batch in enumerate(train_dataloader):
             step_start_time = time.time()

-            model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
+            model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio)

             #del timesteps, encoder_hidden_states, noisy_latents
             #with autocast(enabled=args.amp):
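The prediction/target pair is then fed to the loss and optimizer step. That code is not part of this diff; under the usual setup (MSE on the noise prediction, mixed precision via the GradScaler imported above) a step looks roughly like the placeholder below, where `optimizer` and `scaler` are assumed names:

    import torch.nn.functional as F
    from torch.cuda.amp import GradScaler, autocast

    def training_step(model_pred, target, optimizer, scaler: GradScaler, use_amp: bool = True):
        # Placeholder sketch, not copied from train.py: MSE between the UNet's noise
        # prediction and the (offset-)noise target, then a scaled backward pass.
        with autocast(enabled=use_amp):
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        return loss.detach()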
@@ -944,7 +950,7 @@ def main(args):

         if validator:
             validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)

         gc.collect()
         # end of epoch

@@ -1029,6 +1035,7 @@ if __name__ == "__main__":
     argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
     argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
     argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
+    argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.0, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15")

     # load CLI args to overwrite existing config args
     args = argparser.parse_args(args=argv, namespace=args)
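Both knobs land in the same place: the JSON key above seeds the argument namespace, and `parse_args(args=argv, namespace=args)` lets an explicit CLI flag win, as the "load CLI args to overwrite existing config args" comment indicates. A rough sketch of that interplay; the loader below is an assumption for illustration, not the project's exact code:

    import argparse
    import json

    def load_args(argparser, config_path, argv):
        # Values from the JSON config (e.g. "zero_frequency_noise_ratio": 0.0) populate
        # the namespace first; argparse only applies defaults to attributes that are not
        # already set, so a flag such as --zero_frequency_noise_ratio overrides the file.
        with open(config_path, "r") as f:
            args = argparse.Namespace(**json.load(f))
        return argparser.parse_args(args=argv, namespace=args)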