ema update
This commit is contained in:
parent
303c8312e3
commit
2dff3aa8d1
7
train.py
7
train.py
|
@ -369,7 +369,6 @@ def log_args(log_writer, args):
|
|||
log_writer.add_text("config", arglog)
|
||||
|
||||
def update_ema(model, ema_model, decay, default_device, ema_device):
|
||||
|
||||
with torch.no_grad():
|
||||
original_model_on_proper_device = model
|
||||
need_to_delete_original = False
|
||||
|
@ -970,7 +969,7 @@ def main(args):
|
|||
|
||||
models_info = []
|
||||
|
||||
if (args.ema_decay_rate is None) or args.ema_sample_raw_training:
|
||||
if (args.ema_decay_rate is None) or args.ema_sample_nonema_model:
|
||||
models_info.append({"is_ema": False, "swap_required": False})
|
||||
|
||||
if (args.ema_decay_rate is not None) and args.ema_sample_ema_model:
|
||||
|
@ -1035,8 +1034,6 @@ def main(args):
|
|||
return os.path.join(f"{log_folder}/ckpts/{prepend}{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||
|
||||
|
||||
|
||||
|
||||
# Pre-train validation to establish a starting point on the loss graph
|
||||
if validator:
|
||||
validator.do_validation(global_step=0,
|
||||
|
@ -1065,7 +1062,6 @@ def main(args):
|
|||
if args.load_settings_every_epoch:
|
||||
load_train_json_from_file(args)
|
||||
|
||||
|
||||
plugin_runner.run_on_epoch_start(epoch=epoch,
|
||||
global_step=global_step,
|
||||
project_name=args.project_name,
|
||||
|
@ -1087,6 +1083,7 @@ def main(args):
|
|||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
step_start_time = time.time()
|
||||
|
||||
plugin_runner.run_on_step_start(epoch=epoch,
|
||||
global_step=global_step,
|
||||
project_name=args.project_name,
|
||||
|
|
Loading…
Reference in New Issue