check if your git commit is out of date

This commit is contained in:
Victor Hall 2023-04-29 18:15:25 -04:00
parent aad00eab2e
commit 29a19fd8b1
3 changed files with 20 additions and 2 deletions

View File

@ -5,6 +5,7 @@
"clip_skip": 0,
"cond_dropout": 0.04,
"data_root": "X:\\my_project_data\\project_abc",
"disable_amp": false,
"disable_textenc_training": false,
"disable_xformers": false,
"flip_p": 0.0,

View File

@ -58,6 +58,7 @@ from data.image_train_item import ImageTrainItem
from utils.huggingface_downloader import try_download_model_from_hf
from utils.convert_diff_to_ckpt import convert as converter
from utils.isolate_rng import isolate_rng
from utils.check_git import check_git
if torch.cuda.is_available():
from utils.gpu import GPU
@ -981,6 +982,7 @@ def main(args):
if __name__ == "__main__":
check_git()
supported_resolutions = aspects.get_supported_resolutions()
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
@ -996,7 +998,7 @@ if __name__ == "__main__":
print("No config file specified, using command line args")
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP")
#argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP")
argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20")
argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
@ -1019,7 +1021,7 @@ if __name__ == "__main__":
argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
#argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
argparser.add_argument("--optimizer_config", default="optimizer.json", help="Path to a JSON configuration file for the optimizer. Default is 'optimizer.json'")
argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)

15
utils/check_git.py Normal file
View File

@ -0,0 +1,15 @@
def check_git():
import subprocess
result = subprocess.run(["git", "symbolic-ref", "--short", "HEAD"], capture_output=True, text=True)
branch = result.stdout.strip()
result = subprocess.run(["git", "rev-list", "--left-right", "--count", f"origin/{branch}...{branch}"], capture_output=True, text=True)
ahead, behind = map(int, result.stdout.split())
if behind > 0:
print(f"** Your branch '{branch}' is {behind} commit(s) behind the remote. Consider running 'git pull'.")
elif ahead > 0:
print(f"** Your branch '{branch}' is {ahead} commit(s) ahead the remote, consider a pull request.")
else:
print(f"** Your branch '{branch}' is up to date with the remote")