commit 09aa13c3dd
Merge branch 'main' into feat_rolling_save

@@ -230,10 +230,10 @@ In this mode, the EMA model will be saved alongside the regular checkpoint from
 For more information, consult the [research paper](https://arxiv.org/abs/2101.08482) or continue reading the tuning notes below.

 **Parameters:**

 - `--ema_decay_rate`: Sets the EMA decay rate, which controls how strongly the EMA model is pulled toward the training model at each update. Values should be close to, but never exceed, 1. Setting this parameter enables the EMA feature.
-- `--ema_decay_target`: Sets the EMA decay target value in the (0,1) range. The `ema_decay_rate` is computed from the relation `decay_rate ** (total_steps / decay_interval) == decay_target`. Setting this parameter overrides `ema_decay_rate` and enables the EMA feature.
+- `--ema_strength_target`: Sets the EMA strength target value in the (0,1) range. The `ema_decay_rate` is computed from the relation `decay_rate ** (total_steps / decay_interval) == strength_target`; a worked example follows this list. Setting this parameter overrides `ema_decay_rate` and enables the EMA feature. See [ema_strength_target](#ema_strength_target) for more information.
 - `--ema_update_interval`: Sets the interval, in optimizer steps, between EMA updates. If you use gradient accumulation, the effective interval is multiplied by your grad_accum value.
 - `--ema_device`: Chooses between `cpu` and `cuda` for the EMA model. `cpu` takes around 4 seconds per update and uses approximately 3.2GB of RAM, while `cuda` is much faster but requires a similar amount of VRAM.
-- `--ema_sample_raw_training`: Enable to show samples from the trained model, mirroring conventional training. They are not shown by default when EMA decay is enabled.
+- `--ema_sample_nonema_model`: Enable to show samples from the non-EMA trained model, mirroring conventional training. They are not shown by default when EMA decay is enabled.
 - `--ema_sample_ema_model`: Enable to show samples from the EMA model. The EMA model is used for sample generation by default when EMA decay is enabled, unless disabled.
 - `--ema_resume_model`: Specifies the EMA decay checkpoint to resume from, working like `--resume_ckpt` but loading the EMA model. Using `findlast` loads only the EMA version, not the regular training checkpoint.
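For illustration, here is a minimal sketch that solves the relation above for `ema_decay_rate`. The helper name and the concrete numbers are assumptions for this example (and `decay_interval` here stands for the `--ema_update_interval` value), not code from the repository:

```python
def compute_ema_decay_rate(strength_target: float, total_steps: int, decay_interval: int) -> float:
    # Solve decay_rate ** (total_steps / decay_interval) == strength_target
    # for decay_rate by raising the target to the (decay_interval / total_steps) power.
    return strength_target ** (decay_interval / total_steps)

# Example: a 0.5 strength target over 10,000 total steps, updating every
# 50 steps, yields a per-update decay rate of roughly 0.99654.
print(compute_ema_decay_rate(0.5, 10_000, 50))
```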
@@ -4,7 +4,7 @@
     "clip_grad_norm": null,
     "clip_skip": 0,
     "cond_dropout": 0.04,
-    "data_root": "X:\\my_project_data\\project_abc",
+    "data_root": "/mnt/q/training_samples/ff7r/man",
     "disable_amp": false,
     "disable_textenc_training": false,
     "disable_xformers": false,
@@ -19,7 +19,7 @@
     "lr_decay_steps": 0,
     "lr_scheduler": "constant",
     "lr_warmup_steps": null,
-    "max_epochs": 30,
+    "max_epochs": 1,
     "notebook": false,
     "optimizer_config": "optimizer.json",
     "project_name": "project_abc",
@@ -45,10 +45,10 @@
     "load_settings_every_epoch": false,
     "min_snr_gamma": null,
     "ema_decay_rate": null,
-    "ema_decay_target": null,
+    "ema_strength_target": null,
     "ema_update_interval": null,
     "ema_device": null,
-    "ema_sample_raw_training": false,
+    "ema_sample_nonema_model": false,
     "ema_sample_ema_model": false,
     "ema_resume_model": null
 }
train.py (7 changed lines)
@@ -473,7 +473,6 @@ def log_args(log_writer, args):
     log_writer.add_text("config", arglog)

 def update_ema(model, ema_model, decay, default_device, ema_device):
-
     with torch.no_grad():
         original_model_on_proper_device = model
         need_to_delete_original = False
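The hunk above shows only the head of `update_ema`. For orientation, here is a minimal sketch of the standard EMA weight update such a function performs; this illustrates the technique, not the repository's exact implementation, and the device-swapping bookkeeping hinted at by `original_model_on_proper_device` is omitted:

```python
import torch

def update_ema_sketch(model: torch.nn.Module, ema_model: torch.nn.Module, decay: float) -> None:
    # Standard exponential moving average of weights: pull each EMA
    # parameter toward its training counterpart by a factor of (1 - decay).
    with torch.no_grad():
        for ema_p, p in zip(ema_model.parameters(), model.parameters()):
            ema_p.mul_(decay).add_(p.to(ema_p.device), alpha=1.0 - decay)
```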
@@ -984,7 +983,7 @@ def main(args):

     models_info = []

-    if (args.ema_decay_rate is None) or args.ema_sample_raw_training:
+    if (args.ema_decay_rate is None) or args.ema_sample_nonema_model:
         models_info.append({"is_ema": False, "swap_required": False})

     if (args.ema_decay_rate is not None) and args.ema_sample_ema_model:
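Spelled out, the selection logic in this hunk reduces to a small truth table. The sketch below mirrors the flags in the diff; the fields of the EMA entry are an assumption, since the hunk is truncated before that `append`:

```python
def models_to_sample(ema_decay_rate, sample_nonema_model, sample_ema_model):
    models_info = []
    # The plain training model is sampled whenever EMA is off,
    # or when it is explicitly requested alongside EMA.
    if (ema_decay_rate is None) or sample_nonema_model:
        models_info.append({"is_ema": False, "swap_required": False})
    # The EMA model is sampled only when EMA is on and requested
    # (assumed dict contents; the diff cuts off before this line).
    if (ema_decay_rate is not None) and sample_ema_model:
        models_info.append({"is_ema": True, "swap_required": True})
    return models_info
```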
@@ -1054,8 +1053,6 @@ def main(args):
         return os.path.join(log_folder, "ckpts", basename)

-
-

     # Pre-train validation to establish a starting point on the loss graph
     if validator:
         validator.do_validation(global_step=0,
@@ -1108,6 +1105,7 @@ def main(args):
             data_root=args.data_root
         )

+
         loss_epoch = []
         epoch_start_time = time.time()
         images_per_sec_log_step = []
@@ -1123,6 +1121,7 @@ def main(args):
         for step, batch in enumerate(train_dataloader):

             step_start_time = time.time()
+
             plugin_runner.run_on_step_start(epoch=epoch,
                                             local_step=step,
                                             global_step=global_step,
@@ -44,10 +44,10 @@
     "load_settings_every_epoch": false,
     "min_snr_gamma": null,
     "ema_decay_rate": null,
-    "ema_decay_target": null,
-    "ema_decay_interval": null,
-    "ema_decay_device": null,
-    "ema_decay_sample_raw_training": false,
-    "ema_decay_sample_ema_model": false,
-    "ema_decay_resume_model" : null
+    "ema_strength_target": null,
+    "ema_update_interval": null,
+    "ema_device": null,
+    "ema_sample_nonema_model": false,
+    "ema_sample_ema_model": false,
+    "ema_resume_model": null
 }
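Because this commit renames several EMA config keys (visible in the hunk above), an older config file can be brought up to date mechanically. A hypothetical migration helper, not part of the repository, with the old-to-new mapping taken from the renames in this diff:

```python
import json

# Old key -> new key, per the renames visible in this commit.
KEY_RENAMES = {
    "ema_decay_target": "ema_strength_target",
    "ema_decay_interval": "ema_update_interval",
    "ema_decay_device": "ema_device",
    "ema_decay_sample_raw_training": "ema_sample_nonema_model",
    "ema_decay_sample_ema_model": "ema_sample_ema_model",
    "ema_decay_resume_model": "ema_resume_model",
}

def migrate_config(path: str) -> None:
    # Rewrite a pre-rename train config in place, preserving values.
    with open(path) as f:
        cfg = json.load(f)
    for old, new in KEY_RENAMES.items():
        if old in cfg:
            cfg[new] = cfg.pop(old)
    with open(path, "w") as f:
        json.dump(cfg, f, indent=4)
```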