diff --git a/train.py b/train.py index 9f47336..b6fe28a 100644 --- a/train.py +++ b/train.py @@ -267,6 +267,9 @@ def setup_args(args): args.lr = args.lr * (total_batch_size**0.55) logging.info(f"{Fore.CYAN} * Scaling learning rate {tmp_lr} by {total_batch_size**0.5}, new value: {args.lr}{Style.RESET_ALL}") + if not os.path.exists(args.save_ckpt_dir): + os.makedirs(args.save_ckpt_dir) + return args def main(args):