diff --git a/train.json b/train.json index eb03c68..5024002 100644 --- a/train.json +++ b/train.json @@ -5,6 +5,7 @@ "clip_skip": 0, "cond_dropout": 0.04, "data_root": "X:\\my_project_data\\project_abc", + "disable_amp": false, "disable_textenc_training": false, "disable_xformers": false, "flip_p": 0.0, diff --git a/train.py b/train.py index ac2eae1..5c9b0ac 100644 --- a/train.py +++ b/train.py @@ -58,6 +58,7 @@ from data.image_train_item import ImageTrainItem from utils.huggingface_downloader import try_download_model_from_hf from utils.convert_diff_to_ckpt import convert as converter from utils.isolate_rng import isolate_rng +from utils.check_git import check_git if torch.cuda.is_available(): from utils.gpu import GPU @@ -981,6 +982,7 @@ def main(args): if __name__ == "__main__": + check_git() supported_resolutions = aspects.get_supported_resolutions() argparser = argparse.ArgumentParser(description="EveryDream2 Training options") argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from") @@ -996,7 +998,7 @@ if __name__ == "__main__": print("No config file specified, using command line args") argparser = argparse.ArgumentParser(description="EveryDream2 Training options") - argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP") + #argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP") argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)") argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20") argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?") @@ -1019,7 +1021,7 @@ if __name__ == "__main__": argparser.add_argument("--lr_scheduler", type=str, default="constant", 
def check_git():
    """Best-effort check that the local git branch is in sync with its remote.

    Prints a one-line status (behind / ahead / up to date) at start-up.
    Any failure — git not installed, not inside a repository, detached
    HEAD, or no matching ``origin/<branch>`` — is swallowed silently so
    this convenience check can never block training.
    """
    import subprocess

    try:
        # Current branch name; non-zero exit on detached HEAD or outside a repo.
        result = subprocess.run(
            ["git", "symbolic-ref", "--short", "HEAD"],
            capture_output=True, text=True,
        )
        if result.returncode != 0:
            return
        branch = result.stdout.strip()

        # `--left-right --count origin/B...B` prints two numbers:
        #   left  = commits only on origin/<branch>  -> we are BEHIND by that many
        #   right = commits only on local <branch>   -> we are AHEAD by that many
        # (The original unpacked these as `ahead, behind`, swapping the messages.)
        result = subprocess.run(
            ["git", "rev-list", "--left-right", "--count", f"origin/{branch}...{branch}"],
            capture_output=True, text=True,
        )
        if result.returncode != 0:
            return
        behind, ahead = map(int, result.stdout.split())
    except (OSError, ValueError):
        # git binary missing, or unexpected output shape — skip the check.
        return

    if behind > 0:
        print(f"** Your branch '{branch}' is {behind} commit(s) behind the remote. Consider running 'git pull'.")
    elif ahead > 0:
        print(f"** Your branch '{branch}' is {ahead} commit(s) ahead of the remote, consider a pull request.")
    else:
        print(f"** Your branch '{branch}' is up to date with the remote")