diff --git a/train.py b/train.py index 669c644..21d0466 100644 --- a/train.py +++ b/train.py @@ -917,10 +917,10 @@ def main(args): models_info = [] - if (args.ema_decay_rate is None) or (args.ema_decay_sample_raw_training is not None): + if (args.ema_decay_rate is None) or args.ema_decay_sample_raw_training: models_info.append({"is_ema": False, "swap_required": False}) - if (args.ema_decay_rate is not None) and (args.ema_decay_sample_ema_model is not None): + if (args.ema_decay_rate is not None) and args.ema_decay_sample_ema_model: models_info.append({"is_ema": True, "swap_required": args.ema_decay_device != device}) for model_info in models_info: