diff --git a/train.py b/train.py index 140ae60..9ae9119 100644 --- a/train.py +++ b/train.py @@ -294,7 +294,6 @@ def main(args): else: from tqdm.auto import tqdm - logging.info(f" Seed: {args.seed}") seed = args.seed if args.seed != -1 else random.randint(0, 2**30) logging.info(f" Seed: {seed}") set_seed(seed) @@ -897,6 +896,12 @@ def update_old_args(t_args): if not hasattr(t_args, "mixed_precision"): print(f" Config json is missing 'mixed_precision' flag") t_args.__dict__["mixed_precision"] = "fp32" + if not hasattr(t_args, "rated_dataset"): + print(f" Config json is missing 'rated_dataset' flag") + t_args.__dict__["rated_dataset"] = False + if not hasattr(t_args, "rated_dataset_target_dropout_percent"): + print(f" Config json is missing 'rated_dataset_target_dropout_percent' flag") + t_args.__dict__["rated_dataset_target_dropout_percent"] = 50 if __name__ == "__main__":