diff --git a/train.py b/train.py
index 2fdd944..af9cbc6 100644
--- a/train.py
+++ b/train.py
@@ -369,7 +369,6 @@ def log_args(log_writer, args):
     log_writer.add_text("config", arglog)
 
 def update_ema(model, ema_model, decay, default_device, ema_device):
-
     with torch.no_grad():
         original_model_on_proper_device = model
         need_to_delete_original = False
@@ -970,7 +969,7 @@ def main(args):
 
     models_info = []
 
-    if (args.ema_decay_rate is None) or args.ema_sample_raw_training:
+    if (args.ema_decay_rate is None) or args.ema_sample_nonema_model:
         models_info.append({"is_ema": False, "swap_required": False})
 
     if (args.ema_decay_rate is not None) and args.ema_sample_ema_model:
@@ -1035,8 +1034,6 @@ def main(args):
 
         return os.path.join(f"{log_folder}/ckpts/{prepend}{args.project_name}-ep{epoch:02}-gs{global_step:05}")
 
-
-
     # Pre-train validation to establish a starting point on the loss graph
     if validator:
         validator.do_validation(global_step=0,
@@ -1065,7 +1062,6 @@ def main(args):
 
         if args.load_settings_every_epoch:
             load_train_json_from_file(args)
-
         plugin_runner.run_on_epoch_start(epoch=epoch,
                                          global_step=global_step,
                                          project_name=args.project_name,
@@ -1087,6 +1083,7 @@ def main(args):
 
         for step, batch in enumerate(train_dataloader):
             step_start_time = time.time()
+
             plugin_runner.run_on_step_start(epoch=epoch,
                                             global_step=global_step,
                                             project_name=args.project_name,