update windows to torch 2.1 and add attention type option

Victor Hall 2023-11-24 16:13:01 -05:00
parent 772d6bc439
commit a284c52dee
4 changed files with 27 additions and 51 deletions

View File

@@ -28,8 +28,8 @@
     },
     "text_encoder_overrides": {
         "optimizer": null,
-        "lr": null,
-        "lr_scheduler": null,
+        "lr": 5e-7,
+        "lr_scheduler": "cosine",
         "lr_decay_steps": null,
         "lr_warmup_steps": null,
         "betas": null,

View File

@@ -1,5 +1,6 @@
 {
-    "batch_size": 10,
+    "attn_type": "sdp",
+    "batch_size": 8,
     "ckpt_every_n_minutes": null,
     "clip_grad_norm": null,
     "clip_skip": 0,
@@ -7,19 +8,13 @@
     "data_root": "/mnt/q/training_samples/ff7r/man",
     "disable_amp": false,
     "disable_textenc_training": false,
-    "disable_xformers": false,
     "flip_p": 0.0,
     "gpuid": 0,
     "gradient_checkpointing": true,
     "grad_accum": 1,
     "logdir": "logs",
     "log_step": 25,
-    "lowvram": false,
-    "lr": 1.5e-06,
-    "lr_decay_steps": 0,
-    "lr_scheduler": "constant",
-    "lr_warmup_steps": null,
-    "max_epochs": 1,
+    "max_epochs": 40,
     "notebook": false,
     "optimizer_config": "optimizer.json",
     "project_name": "project_abc",

View File

@@ -299,19 +299,6 @@ def append_epoch_log(global_step: int, epoch_pbar, gpu, log_writer, **logs):
     if logs is not None:
         epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
 
-def set_args_12gb(args):
-    logging.info(" Setting args to 12GB mode")
-    if not args.gradient_checkpointing:
-        logging.info(" - Overiding gradient checkpointing to True")
-        args.gradient_checkpointing = True
-    if args.batch_size > 2:
-        logging.info(" - Overiding batch size to max 2")
-        args.batch_size = 2
-        args.grad_accum = 1
-    if args.resolution > 512:
-        logging.info(" - Overiding resolution to max 512")
-        args.resolution = 512
-
 def find_last_checkpoint(logdir, is_ema=False):
     """
     Finds the last checkpoint in the logdir, recursively
@@ -365,9 +352,6 @@ def setup_args(args):
         args.ema_resume_model = find_last_checkpoint(args.logdir, is_ema=True)
 
-    if args.lowvram:
-        set_args_12gb(args)
-
     if not args.shuffle_tags:
         args.shuffle_tags = False
@@ -376,9 +360,6 @@ def setup_args(args):
     args.clip_skip = max(min(4, args.clip_skip), 0)
 
-    if args.useadam8bit:
-        logging.warning(f"{Fore.LIGHTYELLOW_EX} Useadam8bit arg is deprecated, use optimizer.json instead, which defaults to useadam8bit anyway{Style.RESET_ALL}")
-
     if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
         logging.info(f"{Fore.LIGHTCYAN_EX} No checkpoint saving specified, defaulting to every 20 minutes.{Style.RESET_ALL}")
         args.ckpt_every_n_minutes = 20
@@ -699,21 +680,20 @@ def main(args):
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
-    if not args.disable_xformers:
+    if args.attn_type == "xformers":
         if (args.amp and is_sd1attn) or (not is_sd1attn):
             try:
                 unet.enable_xformers_memory_efficient_attention()
                 logging.info("Enabled xformers")
             except Exception as ex:
-                logging.warning("failed to load xformers, using attention slicing instead")
-                unet.set_attention_slice("auto")
+                logging.warning("failed to load xformers, using default SDP attention instead")
                 pass
-        elif (not args.amp and is_sd1attn):
-            logging.info("AMP is disabled but model is SD1.X, using attention slicing instead of xformers")
-            unet.set_attention_slice("auto")
-    else:
-        logging.info("xformers disabled via arg, using attention slicing instead")
+        elif (args.disable_amp and is_sd1attn):
+            logging.info("AMP is disabled but model is SD1.X, xformers is incompatible so using default attention")
+    elif args.attn_type == "slice":
         unet.set_attention_slice("auto")
+    else:
+        logging.info("* Using SDP attention *")
 
     vae = vae.to(device, dtype=torch.float16 if args.amp else torch.float32)
     unet = unet.to(device, dtype=torch.float32)
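
For context on the new default: "sdp" relies on PyTorch 2.x fused scaled dot-product attention, which diffusers uses for its default attention processor, so the branch above needs no extra setup call. A minimal sketch of the underlying call; shapes and values are illustrative only:

import torch
import torch.nn.functional as F

# Toy tensors shaped (batch, heads, sequence, head_dim).
q = torch.randn(1, 8, 77, 64)
k = torch.randn(1, 8, 77, 64)
v = torch.randn(1, 8, 77, 64)

# Fused scaled dot-product attention, available since torch 2.0; it picks a
# flash / memory-efficient kernel automatically where supported.
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 77, 64])
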
@@ -846,7 +826,7 @@ def main(args):
         config_file_path=args.sample_prompts,
         batch_size=max(1,args.batch_size//2),
         default_sample_steps=args.sample_steps,
-        use_xformers=is_xformers_available() and not args.disable_xformers,
+        use_xformers=args.attn_type == "xformers",
         use_penultimate_clip_layer=(args.clip_skip >= 2),
         guidance_rescale=0.7 if args.enable_zero_terminal_snr else 0
     )
@@ -1337,30 +1317,29 @@ if __name__ == "__main__":
     argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
     argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP")
+    argparser.add_argument("--attn_type", type=str, default="sdp", help="Attention mechanism to use", choices=["xformers", "sdp", "slice"])
     argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
     argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20")
     argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
     argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3, 4])
     argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
     argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are")
-    argparser.add_argument("--disable_amp", action="store_true", default=False, help="disables training of text encoder (def: False)")
+    argparser.add_argument("--disable_amp", action="store_true", default=False, help="disables automatic mixed precision (def: False)")
     argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False)")
     argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED")
-    argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
     argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
     argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1), use nvidia-smi to find your GPU ids")
     argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)")
     argparser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation factor (def: 1), (ex, 2)")
     argparser.add_argument("--logdir", type=str, default="logs", help="folder to save logs to (def: logs)")
     argparser.add_argument("--log_step", type=int, default=25, help="How often to log training stats, def: 25, recommend default!")
-    argparser.add_argument("--lowvram", action="store_true", default=False, help="automatically overrides various args to support 12GB gpu")
     argparser.add_argument("--lr", type=float, default=None, help="Learning rate, if using scheduler is maximum LR at top of curve")
     argparser.add_argument("--lr_decay_steps", type=int, default=0, help="Steps to reach minimum LR, default: automatically set")
     argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
     argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
     argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
     argparser.add_argument("--no_prepend_last", action="store_true", help="Do not prepend 'last-' to the final checkpoint filename")
-    argparser.add_argument("--no_save_ckpt", action="store_true", help="Save only diffusers files, no .ckpts" )
+    argparser.add_argument("--no_save_ckpt", action="store_true", help="Save only diffusers files, not .safetensors files (save disk space if you do not need LDM-style checkpoints)" )
     argparser.add_argument("--optimizer_config", default="optimizer.json", help="Path to a JSON configuration file for the optimizer. Default is 'optimizer.json'")
     argparser.add_argument('--plugins', nargs='+', help='Names of plugins to use')
     argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
@@ -1376,9 +1355,8 @@ if __name__ == "__main__":
     argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
     argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
     argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
-    argparser.add_argument("--train_sampler", type=str, default="ddpm", help="sampler used for training, (default: ddpm)", choices=["ddpm", "pndm", "ddim"])
-    argparser.add_argument("--keep_tags", type=int, default=0, help="Number of tags to keep when shuffle, def: 0 (shuffle all)")
-    argparser.add_argument("--useadam8bit", action="store_true", default=False, help="deprecated, use --optimizer_config and optimizer.json instead")
+    argparser.add_argument("--train_sampler", type=str, default="ddpm", help="noise sampler used for training, (default: ddpm)", choices=["ddpm", "pndm", "ddim"])
+    argparser.add_argument("--keep_tags", type=int, default=0, help="Number of tags to keep when shuffle, used to randomly select subset of tags when shuffling is enabled, def: 0 (shuffle all)")
     argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
     argparser.add_argument("--validation_config", default=None, help="Path to a JSON configuration file for the validator. Default is no validation.")
     argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
@@ -1386,7 +1364,7 @@ if __name__ == "__main__":
     argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
     argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.02, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15")
     argparser.add_argument("--enable_zero_terminal_snr", action="store_true", default=None, help="Use zero terminal SNR noising beta schedule")
-    argparser.add_argument("--load_settings_every_epoch", action="store_true", default=None, help="Will load 'train.json' at start of every epoch. Disabled by default and enabled when used.")
+    argparser.add_argument("--load_settings_every_epoch", action="store_true", default=None, help="Enable reloading of 'train.json' at start of every epoch.")
     argparser.add_argument("--min_snr_gamma", type=int, default=None, help="min-SNR-gamma parameter is the loss function into individual tasks. Recommended values: 5, 1, 20. Disabled by default and enabled when used. More info: https://arxiv.org/abs/2303.09556")
     argparser.add_argument("--ema_decay_rate", type=float, default=None, help="EMA decay rate. EMA model will be updated with (1 - ema_rate) from training, and the ema_rate from previous EMA, every interval. Values less than 1 and not so far from 1. Using this parameter will enable the feature.")
     argparser.add_argument("--ema_strength_target", type=float, default=None, help="EMA decay target value in range (0,1). emarate will be calculated from equation: 'ema_decay_rate=ema_strength_target^(total_steps/ema_update_interval)'. Using this parameter will enable the ema feature and overide ema_decay_rate.")

View File

@@ -3,9 +3,9 @@ call "venv\Scripts\activate.bat"
 echo should be in venv here
 cd .
 python -m pip install --upgrade pip
-pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --extra-index-url "https://download.pytorch.org/whl/cu118"
+pip install torch==2.1.0+cu121 torchvision==0.16.0+cu121 --extra-index-url "https://download.pytorch.org/whl/cu121"
 pip install -U transformers==4.35.0
-pip install -U diffusers[torch]==0.21.4
+pip install -U diffusers[torch]==0.23.1
 pip install pynvml==11.4.1
 pip install -U pip install -U https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl
 pip install ftfy==6.1.1
@@ -14,17 +14,20 @@ pip install tensorboard>=2.11.0
 pip install protobuf==3.20.1
 pip install wandb==0.15.3
 pip install pyre-extensions==0.0.29
-pip install -U xformers==0.0.20
+pip install -U xformers==0.0.22.post7
 pip install pytorch-lightning==1.6.5
 pip install OmegaConf==2.2.3
-pip install numpy==1.23.5
+pip install numpy>=1.23.5
 pip install lion-pytorch
 pip install compel~=1.1.3
 pip install dadaptation
 pip install safetensors
-pip install open-flamingo==2.0.0
 pip install prodigyopt
 pip install torchsde
+pip install --no-deps open-flamingo==2.0.1
+pip install einops
+pip install einops-exts
+pip install open-clip-torch
 python utils/get_yamls.py
 GOTO :eof
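
After running the updated setup script, a quick sanity check can confirm the torch 2.1 install and that the SDP attention path is available. A sketch, assuming the venv created above is active:

import torch
import torch.nn.functional as F

print("torch:", torch.__version__)            # expected: 2.1.0+cu121 after this script
print("cuda available:", torch.cuda.is_available())
print("sdp attention available:", hasattr(F, "scaled_dot_product_attention"))

try:
    import xformers  # only needed when training with --attn_type xformers
    print("xformers:", xformers.__version__)
except ImportError:
    print("xformers not installed")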