Merge branch 'main' into feat_rolling_save

Victor Hall 2023-09-20 16:32:37 -04:00 committed by GitHub
commit 09aa13c3dd
4 changed files with 15 additions and 16 deletions


@@ -230,10 +230,10 @@ In this mode, the EMA model will be saved alongside the regular checkpoint from
For more information, consult the [research paper](https://arxiv.org/abs/2101.08482) or continue reading the tuning notes below.
**Parameters:**
- `--ema_decay_rate`: Sets the EMA decay rate, which controls how much the EMA model is updated from training at each update. Values should be close to, but not exceed, 1. Setting this parameter enables the EMA decay feature.
- `--ema_decay_target`: Sets the EMA decay target value, within the (0,1) range. The `ema_decay_rate` is computed from the relation `decay_rate ^ (total_steps / decay_interval) = decay_target`. Setting this parameter overrides `ema_decay_rate` and enables the EMA feature.
- `--ema_strength_target`: Sets the EMA strength target value, within the (0,1) range. The `ema_decay_rate` is computed from the relation `decay_rate ^ (total_steps / decay_interval) = decay_target`; see the sketch after this list. Setting this parameter overrides `ema_decay_rate` and enables the EMA feature. See [ema_strength_target](#ema_strength_target) for more information.
- `--ema_update_interval`: Sets the interval, in steps, between EMA updates. The update occurs at each optimizer step, so if you use grad_accum the actual update interval is multiplied by your grad_accum value.
- `--ema_device`: Choose between `cpu` and `cuda` for EMA. Opting for `cpu` takes around 4 seconds per update and uses approximately 3.2GB of RAM, while `cuda` is much faster but requires a similar amount of VRAM.
- `--ema_sample_raw_training`: Enable to generate samples from the regular trained model, as in conventional training. These samples are not produced by default when EMA decay is enabled.
- `--ema_sample_nonema_model`: Enable to generate samples from the non-EMA trained model, as in conventional training. These samples are not produced by default when EMA decay is enabled.
- `--ema_sample_ema_model`: Enable to generate samples from the EMA model. With EMA decay enabled, the EMA model is used for sample generation by default unless this is disabled.
- `--ema_resume_model`: Specify the EMA decay checkpoint to resume from; works like `--resume_ckpt`, but loads the EMA model. Using `findlast` will load only the EMA version, not the regular training checkpoint.
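
For intuition, the relation above can be turned directly into a decay rate. Below is a minimal sketch assuming the textbook exponential-moving-average formulation; the helper names (`decay_rate_from_target`, `ema_update`) are illustrative and are not flags or functions of the trainer itself:

```python
def decay_rate_from_target(decay_target: float, total_steps: int, update_interval: int) -> float:
    """Solve decay_rate ** (total_steps / update_interval) == decay_target for decay_rate."""
    num_updates = total_steps / update_interval
    return decay_target ** (1.0 / num_updates)

def ema_update(ema_weights, model_weights, decay_rate: float):
    """Textbook EMA blend: ema = decay * ema + (1 - decay) * model (illustrative only)."""
    return [decay_rate * e + (1.0 - decay_rate) * w for e, w in zip(ema_weights, model_weights)]

# Example: a 0.5 strength target over 10,000 steps with an update every 50 steps
rate = decay_rate_from_target(0.5, total_steps=10_000, update_interval=50)
print(f"ema_decay_rate ~= {rate:.6f}")  # roughly 0.99654, i.e. close to 1 as recommended
```

Smaller targets (closer to 0) pull the computed decay rate further below 1, so the EMA model tracks the training weights more aggressively.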


@@ -4,7 +4,7 @@
"clip_grad_norm": null,
"clip_skip": 0,
"cond_dropout": 0.04,
"data_root": "X:\\my_project_data\\project_abc",
"data_root": "/mnt/q/training_samples/ff7r/man",
"disable_amp": false,
"disable_textenc_training": false,
"disable_xformers": false,
@@ -19,7 +19,7 @@
"lr_decay_steps": 0,
"lr_scheduler": "constant",
"lr_warmup_steps": null,
"max_epochs": 30,
"max_epochs": 1,
"notebook": false,
"optimizer_config": "optimizer.json",
"project_name": "project_abc",
@@ -45,10 +45,10 @@
"load_settings_every_epoch": false,
"min_snr_gamma": null,
"ema_decay_rate": null,
"ema_decay_target": null,
"ema_strength_target": null,
"ema_update_interval": null,
"ema_device": null,
"ema_sample_raw_training": false,
"ema_sample_nonema_model": false,
"ema_sample_ema_model": false,
"ema_resume_model" : null
}
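
For reference, a hypothetical example of what the EMA-related fields above might look like with the feature turned on (values are illustrative assumptions, not recommendations), expressed as the Python dict such a JSON config deserializes to:

```python
import json

# Illustrative values only: enable EMA via a strength target, update every
# 50 optimizer steps, and keep the EMA weights on CPU.
ema_settings = {
    "ema_decay_rate": None,        # leave unset when driving EMA from a target instead
    "ema_strength_target": 0.5,    # must lie in (0, 1)
    "ema_update_interval": 50,
    "ema_device": "cpu",           # "cpu" (~3.2GB RAM) or "cuda" (similar VRAM)
    "ema_sample_nonema_model": False,
    "ema_sample_ema_model": True,
    "ema_resume_model": None,
}

print(json.dumps(ema_settings, indent=2))  # prints these fields in their JSON form
```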


@@ -473,7 +473,6 @@ def log_args(log_writer, args):
log_writer.add_text("config", arglog)
def update_ema(model, ema_model, decay, default_device, ema_device):
with torch.no_grad():
original_model_on_proper_device = model
need_to_delete_original = False
@@ -984,7 +983,7 @@ def main(args):
models_info = []
if (args.ema_decay_rate is None) or args.ema_sample_raw_training:
if (args.ema_decay_rate is None) or args.ema_sample_nonema_model:
models_info.append({"is_ema": False, "swap_required": False})
if (args.ema_decay_rate is not None) and args.ema_sample_ema_model:
@@ -1054,8 +1053,6 @@ def main(args):
return os.path.join(log_folder, "ckpts", basename)
# Pre-train validation to establish a starting point on the loss graph
if validator:
validator.do_validation(global_step=0,
@@ -1108,6 +1105,7 @@ def main(args):
data_root=args.data_root
)
loss_epoch = []
epoch_start_time = time.time()
images_per_sec_log_step = []
@@ -1123,6 +1121,7 @@
for step, batch in enumerate(train_dataloader):
step_start_time = time.time()
plugin_runner.run_on_step_start(epoch=epoch,
local_step=step,
global_step=global_step,


@@ -44,10 +44,10 @@
"load_settings_every_epoch": false,
"min_snr_gamma": null,
"ema_decay_rate": null,
"ema_decay_target": null,
"ema_decay_interval": null,
"ema_decay_device": null,
"ema_decay_sample_raw_training": false,
"ema_decay_sample_ema_model": false,
"ema_decay_resume_model" : null
"ema_strength_target": null,
"ema_update_interval": null,
"ema_device": null,
"ema_sample_nonema_model": false,
"ema_sample_ema_model": false,
"ema_resume_model" : null
}
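
Since this hunk renames the EMA config keys, here is a small sketch for migrating an older JSON config to the new names. The old-to-new mapping is inferred from the removed and added lines above; whether the trainer itself performs any such migration is not shown here, so treat this as a convenience assumption:

```python
import json

# Old-name -> new-name mapping, inferred from the removed/added keys in this hunk.
RENAMES = {
    "ema_decay_target": "ema_strength_target",
    "ema_decay_interval": "ema_update_interval",
    "ema_decay_device": "ema_device",
    "ema_decay_sample_raw_training": "ema_sample_nonema_model",
    "ema_decay_sample_ema_model": "ema_sample_ema_model",
    "ema_decay_resume_model": "ema_resume_model",
}

def migrate_config(path: str) -> None:
    """Rewrite a config file in place, moving values from old EMA keys to the new ones."""
    with open(path) as f:
        cfg = json.load(f)
    for old, new in RENAMES.items():
        if old in cfg:
            cfg[new] = cfg.pop(old)
    with open(path, "w") as f:
        json.dump(cfg, f, indent=2)
```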