diff --git a/doc/ADVANCED_TWEAKING.md b/doc/ADVANCED_TWEAKING.md
index 894a894..dab9a47 100644
--- a/doc/ADVANCED_TWEAKING.md
+++ b/doc/ADVANCED_TWEAKING.md
@@ -230,10 +230,10 @@ In this mode, the EMA model will be saved alongside the regular checkpoint from
 For more information, consult the [research paper](https://arxiv.org/abs/2101.08482) or continue reading the tuning notes below.
 **Parameters:**
 - `--ema_decay_rate`: Determines the EMA decay rate. It defines how much the EMA model is updated from training at each update. Values should be close to 1 but not exceed it. Activating this parameter triggers the EMA decay feature.
-- `--ema_decay_target`: Set the EMA decay target value within the (0,1) range. The `ema_decay_rate` is computed based on the relation: decay_rate to the power of (total_steps/decay_interval) equals decay_target. Enabling this parameter will override `ema_decay_rate` and will enable EMA feature.
+- `--ema_strength_target`: Set the EMA strength target value within the (0,1) range. The `ema_decay_rate` is computed based on the relation: decay_rate to the power of (total_steps/decay_interval) equals decay_target. Enabling this parameter will override `ema_decay_rate` and will enable the EMA feature. See [ema_strength_target](#ema_strength_target) for more information.
 - `--ema_update_interval`: Set the interval in steps between EMA updates. The update occurs at each optimizer step. If you use grad_accum, actual update interval will be multipled by your grad_accum value.
 - `--ema_device`: Choose between `cpu` and `cuda` for EMA. Opting for 'cpu' takes around 4 seconds per update and uses approximately 3.2GB RAM, while 'cuda' is much faster but requires a similar amount of VRAM.
-- `--ema_sample_raw_training`: Activate to display samples from the trained model, mirroring conventional training. They will not be presented by default with EMA decay enabled.
+- `--ema_sample_nonema_model`: Activate to display samples from the non-EMA trained model, mirroring conventional training. They will not be presented by default with EMA decay enabled.
 - `--ema_sample_ema_model`: Turn on to exhibit samples from the EMA model. EMA models will be used for samples generations by default with EMA decay enabled, unless disabled.
 - `--ema_resume_model`: Indicate the EMA decay checkpoint to continue from, working like `--resume_ckpt` but will load EMA model. Using `findlast` will only load EMA version and not regular training.
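Note on the doc hunk above: the documented relation decay_rate ** (total_steps / decay_interval) == strength_target can be rearranged to give the per-update decay rate directly. The sketch below is illustrative only; the helper name and the example numbers are hypothetical and are not part of this patch.

```python
# Hypothetical helper, not part of the patch: derives a decay rate from the
# relation documented above, decay_rate ** (total_steps / update_interval) == strength_target.
def compute_ema_decay_rate(strength_target: float, total_steps: int, update_interval: int) -> float:
    num_updates = total_steps / update_interval
    return strength_target ** (1.0 / num_updates)

# Example: a strength target of 0.5 over 10,000 steps with an update every 10 steps
# yields a per-update decay rate just under 1.
print(compute_ema_decay_rate(0.5, total_steps=10_000, update_interval=10))  # ~0.99931
```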
diff --git a/train.json b/train.json
index f74e100..755ee76 100644
--- a/train.json
+++ b/train.json
@@ -4,7 +4,7 @@
     "clip_grad_norm": null,
     "clip_skip": 0,
     "cond_dropout": 0.04,
-    "data_root": "X:\\my_project_data\\project_abc",
+    "data_root": "/mnt/q/training_samples/ff7r/man",
     "disable_amp": false,
     "disable_textenc_training": false,
     "disable_xformers": false,
@@ -19,7 +19,7 @@
     "lr_decay_steps": 0,
     "lr_scheduler": "constant",
     "lr_warmup_steps": null,
-    "max_epochs": 30,
+    "max_epochs": 1,
     "notebook": false,
     "optimizer_config": "optimizer.json",
     "project_name": "project_abc",
@@ -45,10 +45,10 @@
     "load_settings_every_epoch": false,
     "min_snr_gamma": null,
     "ema_decay_rate": null,
-    "ema_decay_target": null,
+    "ema_strength_target": null,
     "ema_update_interval": null,
     "ema_device": null,
-    "ema_sample_raw_training": false,
+    "ema_sample_nonema_model": false,
     "ema_sample_ema_model": false,
     "ema_resume_model" : null
 }
diff --git a/train.py b/train.py
index 59d7920..58a2e3f 100644
--- a/train.py
+++ b/train.py
@@ -473,7 +473,6 @@ def log_args(log_writer, args):
     log_writer.add_text("config", arglog)

 def update_ema(model, ema_model, decay, default_device, ema_device):
-
     with torch.no_grad():
         original_model_on_proper_device = model
         need_to_delete_original = False
@@ -984,7 +983,7 @@ def main(args):

         models_info = []

-        if (args.ema_decay_rate is None) or args.ema_sample_raw_training:
+        if (args.ema_decay_rate is None) or args.ema_sample_nonema_model:
             models_info.append({"is_ema": False, "swap_required": False})

         if (args.ema_decay_rate is not None) and args.ema_sample_ema_model:
@@ -1054,8 +1053,6 @@ def main(args):
         return os.path.join(log_folder, "ckpts", basename)


-
-
     # Pre-train validation to establish a starting point on the loss graph
     if validator:
         validator.do_validation(global_step=0,
@@ -1108,6 +1105,7 @@ def main(args):
                 data_root=args.data_root
             )

+
         loss_epoch = []
         epoch_start_time = time.time()
         images_per_sec_log_step = []
@@ -1123,6 +1121,7 @@

         for step, batch in enumerate(train_dataloader):
             step_start_time = time.time()
+
             plugin_runner.run_on_step_start(epoch=epoch,
                                             local_step=step,
                                             global_step=global_step,
diff --git a/trainSD21.json b/trainSD21.json
index 9764522..56d8fd1 100644
--- a/trainSD21.json
+++ b/trainSD21.json
@@ -44,10 +44,10 @@
     "load_settings_every_epoch": false,
     "min_snr_gamma": null,
     "ema_decay_rate": null,
-    "ema_decay_target": null,
-    "ema_decay_interval": null,
-    "ema_decay_device": null,
-    "ema_decay_sample_raw_training": false,
-    "ema_decay_sample_ema_model": false,
-    "ema_decay_resume_model" : null
+    "ema_strength_target": null,
+    "ema_update_interval": null,
+    "ema_device": null,
+    "ema_sample_nonema_model": false,
+    "ema_sample_ema_model": false,
+    "ema_resume_model" : null
 }
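For reference alongside the `update_ema` signature touched above: the conventional EMA update keeps `decay` of the averaged weight and mixes in `1 - decay` of the current training weight at each update. The snippet below is a simplified sketch of that rule in PyTorch, not the actual body of `update_ema` in train.py, which also handles device placement per `--ema_device`.

```python
import torch

@torch.no_grad()
def ema_update_sketch(model: torch.nn.Module, ema_model: torch.nn.Module, decay: float) -> None:
    # Simplified illustration of the standard EMA rule; the real update_ema()
    # additionally moves weights between the EMA device and the training device.
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.mul_(decay).add_(param.to(ema_param.device), alpha=1.0 - decay)
```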