update windows to torch 2.1 and add attention type option

commit a284c52dee (parent 772d6bc439)
@@ -28,8 +28,8 @@
     },
     "text_encoder_overrides": {
         "optimizer": null,
-        "lr": null,
-        "lr_scheduler": null,
+        "lr": 5e-7,
+        "lr_scheduler": "cosine",
         "lr_decay_steps": null,
         "lr_warmup_steps": null,
         "betas": null,
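The hunk above gives the text encoder its own learning rate (5e-7) and a cosine scheduler in the optimizer config instead of inheriting the base values. A minimal sketch of how such an override block can be resolved, assuming null means "inherit from the base optimizer section" (the "base" key and the helper below are illustrative, not the project's actual loader):

    import json

    def resolve_text_encoder_settings(path="optimizer.json"):
        # Hypothetical helper: merge text_encoder_overrides onto a base section,
        # treating JSON null (Python None) as "inherit the base value".
        with open(path) as f:
            cfg = json.load(f)
        settings = dict(cfg.get("base", {}))          # "base" key is an assumption
        for key, value in cfg.get("text_encoder_overrides", {}).items():
            if value is not None:
                settings[key] = value
        return settings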
train.json (11 lines changed)
@@ -1,5 +1,6 @@
 {
-    "batch_size": 10,
+    "attn_type": "sdp",
+    "batch_size": 8,
     "ckpt_every_n_minutes": null,
     "clip_grad_norm": null,
     "clip_skip": 0,
@@ -7,19 +8,13 @@
     "data_root": "/mnt/q/training_samples/ff7r/man",
     "disable_amp": false,
     "disable_textenc_training": false,
-    "disable_xformers": false,
     "flip_p": 0.0,
     "gpuid": 0,
     "gradient_checkpointing": true,
     "grad_accum": 1,
     "logdir": "logs",
     "log_step": 25,
-    "lowvram": false,
-    "lr": 1.5e-06,
-    "lr_decay_steps": 0,
-    "lr_scheduler": "constant",
-    "lr_warmup_steps": null,
-    "max_epochs": 1,
+    "max_epochs": 40,
     "notebook": false,
     "optimizer_config": "optimizer.json",
     "project_name": "project_abc",
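The example train.json now carries the new "attn_type" key ("sdp" here) and drops "disable_xformers", "lowvram" and the lr-related keys, matching the options removed from train.py in this same commit. A small sanity-check sketch for the edited config (the check itself is illustrative; the file name, key and default mirror the diff):

    import json

    VALID_ATTN_TYPES = {"xformers", "sdp", "slice"}   # same choices as --attn_type below

    with open("train.json") as f:
        cfg = json.load(f)

    attn = cfg.get("attn_type", "sdp")                # "sdp" is the default
    if attn not in VALID_ATTN_TYPES:
        raise ValueError(f"attn_type must be one of {sorted(VALID_ATTN_TYPES)}, got {attn!r}")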
train.py (50 lines changed)
@@ -299,19 +299,6 @@ def append_epoch_log(global_step: int, epoch_pbar, gpu, log_writer, **logs):
     if logs is not None:
         epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
 
-def set_args_12gb(args):
-    logging.info(" Setting args to 12GB mode")
-    if not args.gradient_checkpointing:
-        logging.info(" - Overiding gradient checkpointing to True")
-        args.gradient_checkpointing = True
-    if args.batch_size > 2:
-        logging.info(" - Overiding batch size to max 2")
-        args.batch_size = 2
-        args.grad_accum = 1
-    if args.resolution > 512:
-        logging.info(" - Overiding resolution to max 512")
-        args.resolution = 512
-
 def find_last_checkpoint(logdir, is_ema=False):
     """
     Finds the last checkpoint in the logdir, recursively
@@ -365,9 +352,6 @@ def setup_args(args):
 
         args.ema_resume_model = find_last_checkpoint(args.logdir, is_ema=True)
 
-    if args.lowvram:
-        set_args_12gb(args)
-
     if not args.shuffle_tags:
         args.shuffle_tags = False
 
@@ -376,9 +360,6 @@
 
     args.clip_skip = max(min(4, args.clip_skip), 0)
 
-    if args.useadam8bit:
-        logging.warning(f"{Fore.LIGHTYELLOW_EX} Useadam8bit arg is deprecated, use optimizer.json instead, which defaults to useadam8bit anyway{Style.RESET_ALL}")
-
     if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
         logging.info(f"{Fore.LIGHTCYAN_EX} No checkpoint saving specified, defaulting to every 20 minutes.{Style.RESET_ALL}")
         args.ckpt_every_n_minutes = 20
@@ -699,21 +680,20 @@ def main(args):
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
-    if not args.disable_xformers:
+    if args.attn_type == "xformers":
         if (args.amp and is_sd1attn) or (not is_sd1attn):
             try:
                 unet.enable_xformers_memory_efficient_attention()
                 logging.info("Enabled xformers")
             except Exception as ex:
-                logging.warning("failed to load xformers, using attention slicing instead")
-                unet.set_attention_slice("auto")
+                logging.warning("failed to load xformers, using default SDP attention instead")
                 pass
-        elif (not args.amp and is_sd1attn):
-            logging.info("AMP is disabled but model is SD1.X, using attention slicing instead of xformers")
-            unet.set_attention_slice("auto")
-    else:
-        logging.info("xformers disabled via arg, using attention slicing instead")
+        elif (args.disable_amp and is_sd1attn):
+            logging.info("AMP is disabled but model is SD1.X, xformers is incompatible so using default attention")
+    elif args.attn_type == "slice":
         unet.set_attention_slice("auto")
+    else:
+        logging.info("* Using SDP attention *")
 
     vae = vae.to(device, dtype=torch.float16 if args.amp else torch.float32)
     unet = unet.to(device, dtype=torch.float32)
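With the xformers fallback and the final else branch now deferring to PyTorch 2.x scaled-dot-product attention, the "sdp" path only needs to log: recent diffusers versions already default to the SDP attention processor on torch >= 2.0. A hedged sketch of forcing it explicitly, which should normally be redundant:

    from diffusers import UNet2DConditionModel
    from diffusers.models.attention_processor import AttnProcessor2_0

    def force_sdp_attention(unet: UNet2DConditionModel) -> None:
        # Illustrative only: on torch >= 2.0, recent diffusers already uses
        # AttnProcessor2_0 (torch.nn.functional.scaled_dot_product_attention) by default.
        unet.set_attn_processor(AttnProcessor2_0())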
@@ -846,7 +826,7 @@ def main(args):
             config_file_path=args.sample_prompts,
             batch_size=max(1,args.batch_size//2),
             default_sample_steps=args.sample_steps,
-            use_xformers=is_xformers_available() and not args.disable_xformers,
+            use_xformers=args.attn_type == "xformers",
             use_penultimate_clip_layer=(args.clip_skip >= 2),
             guidance_rescale=0.7 if args.enable_zero_terminal_snr else 0
         )
@@ -1337,30 +1317,29 @@ if __name__ == "__main__":
 
     argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
     argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP")
+    argparser.add_argument("--attn_type", type=str, default="sdp", help="Attention mechanism to use", choices=["xformers", "sdp", "slice"])
     argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
     argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20")
     argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
     argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3, 4])
     argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
     argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are")
-    argparser.add_argument("--disable_amp", action="store_true", default=False, help="disables training of text encoder (def: False)")
+    argparser.add_argument("--disable_amp", action="store_true", default=False, help="disables automatic mixed precision (def: False)")
     argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False)")
     argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED")
-    argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
     argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
     argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1), use nvidia-smi to find your GPU ids")
     argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)")
     argparser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation factor (def: 1), (ex, 2)")
     argparser.add_argument("--logdir", type=str, default="logs", help="folder to save logs to (def: logs)")
     argparser.add_argument("--log_step", type=int, default=25, help="How often to log training stats, def: 25, recommend default!")
-    argparser.add_argument("--lowvram", action="store_true", default=False, help="automatically overrides various args to support 12GB gpu")
     argparser.add_argument("--lr", type=float, default=None, help="Learning rate, if using scheduler is maximum LR at top of curve")
     argparser.add_argument("--lr_decay_steps", type=int, default=0, help="Steps to reach minimum LR, default: automatically set")
     argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
     argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
     argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
     argparser.add_argument("--no_prepend_last", action="store_true", help="Do not prepend 'last-' to the final checkpoint filename")
-    argparser.add_argument("--no_save_ckpt", action="store_true", help="Save only diffusers files, no .ckpts" )
+    argparser.add_argument("--no_save_ckpt", action="store_true", help="Save only diffusers files, not .safetensors files (save disk space if you do not need LDM-style checkpoints)" )
     argparser.add_argument("--optimizer_config", default="optimizer.json", help="Path to a JSON configuration file for the optimizer. Default is 'optimizer.json'")
     argparser.add_argument('--plugins', nargs='+', help='Names of plugins to use')
     argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
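Because --disable_xformers and --lowvram are removed, launch scripts that still pass them will now fail argparse parsing; --attn_type replaces the xformers toggle. A standalone sketch of the new option's behaviour, mirroring the add_argument call above:

    import argparse

    parser = argparse.ArgumentParser(description="EveryDream2 Training options")
    parser.add_argument("--attn_type", type=str, default="sdp",
                        help="Attention mechanism to use", choices=["xformers", "sdp", "slice"])

    print(parser.parse_args([]).attn_type)                        # -> sdp (default)
    print(parser.parse_args(["--attn_type", "slice"]).attn_type)  # -> slice
    # parser.parse_args(["--disable_xformers"]) now exits with "unrecognized arguments"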
@@ -1376,9 +1355,8 @@ if __name__ == "__main__":
     argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
     argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
     argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
-    argparser.add_argument("--train_sampler", type=str, default="ddpm", help="sampler used for training, (default: ddpm)", choices=["ddpm", "pndm", "ddim"])
-    argparser.add_argument("--keep_tags", type=int, default=0, help="Number of tags to keep when shuffle, def: 0 (shuffle all)")
-    argparser.add_argument("--useadam8bit", action="store_true", default=False, help="deprecated, use --optimizer_config and optimizer.json instead")
+    argparser.add_argument("--train_sampler", type=str, default="ddpm", help="noise sampler used for training, (default: ddpm)", choices=["ddpm", "pndm", "ddim"])
+    argparser.add_argument("--keep_tags", type=int, default=0, help="Number of tags to keep when shuffle, used to randomly select subset of tags when shuffling is enabled, def: 0 (shuffle all)")
     argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
     argparser.add_argument("--validation_config", default=None, help="Path to a JSON configuration file for the validator. Default is no validation.")
     argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
@@ -1386,7 +1364,7 @@
     argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
     argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.02, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15")
     argparser.add_argument("--enable_zero_terminal_snr", action="store_true", default=None, help="Use zero terminal SNR noising beta schedule")
-    argparser.add_argument("--load_settings_every_epoch", action="store_true", default=None, help="Will load 'train.json' at start of every epoch. Disabled by default and enabled when used.")
+    argparser.add_argument("--load_settings_every_epoch", action="store_true", default=None, help="Enable reloading of 'train.json' at start of every epoch.")
    argparser.add_argument("--min_snr_gamma", type=int, default=None, help="min-SNR-gamma parameter is the loss function into individual tasks. Recommended values: 5, 1, 20. Disabled by default and enabled when used. More info: https://arxiv.org/abs/2303.09556")
     argparser.add_argument("--ema_decay_rate", type=float, default=None, help="EMA decay rate. EMA model will be updated with (1 - ema_rate) from training, and the ema_rate from previous EMA, every interval. Values less than 1 and not so far from 1. Using this parameter will enable the feature.")
     argparser.add_argument("--ema_strength_target", type=float, default=None, help="EMA decay target value in range (0,1). emarate will be calculated from equation: 'ema_decay_rate=ema_strength_target^(total_steps/ema_update_interval)'. Using this parameter will enable the ema feature and overide ema_decay_rate.")
Windows setup script
@@ -3,9 +3,9 @@ call "venv\Scripts\activate.bat"
 echo should be in venv here
 cd .
 python -m pip install --upgrade pip
-pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --extra-index-url "https://download.pytorch.org/whl/cu118"
+pip install torch==2.1.0+cu121 torchvision==0.16.0+cu121 --extra-index-url "https://download.pytorch.org/whl/cu121"
 pip install -U transformers==4.35.0
-pip install -U diffusers[torch]==0.21.4
+pip install -U diffusers[torch]==0.23.1
 pip install pynvml==11.4.1
 pip install -U pip install -U https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl
 pip install ftfy==6.1.1
@@ -14,17 +14,20 @@ pip install tensorboard>=2.11.0
 pip install protobuf==3.20.1
 pip install wandb==0.15.3
 pip install pyre-extensions==0.0.29
-pip install -U xformers==0.0.20
+pip install -U xformers==0.0.22.post7
 pip install pytorch-lightning==1.6.5
 pip install OmegaConf==2.2.3
-pip install numpy==1.23.5
+pip install numpy>=1.23.5
 pip install lion-pytorch
 pip install compel~=1.1.3
 pip install dadaptation
 pip install safetensors
-pip install open-flamingo==2.0.0
 pip install prodigyopt
 pip install torchsde
+pip install --no-deps open-flamingo==2.0.1
+pip install einops
+pip install einops-exts
+pip install open-clip-torch
 python utils/get_yamls.py
 GOTO :eof
 
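After running the updated script, a quick way to confirm the torch 2.1 + cu121 stack landed is to print the installed versions; the expected strings below are simply what the pins above should produce, and the SDP check is only meaningful on a CUDA build:

    import torch, diffusers, xformers

    print(torch.__version__)        # expected: 2.1.0+cu121
    print(torch.version.cuda)       # expected: 12.1
    print(diffusers.__version__)    # expected: 0.23.1
    print(xformers.__version__)     # expected: 0.0.22.post7
    print(torch.cuda.is_available())
    # torch 2.x exposes the SDP kernels used by the new default "sdp" attention path
    print(torch.backends.cuda.flash_sdp_enabled())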