From 3e803a831337aedb712dc039e6195e3784a00214 Mon Sep 17 00:00:00 2001 From: Victor Hall Date: Tue, 17 Jan 2023 12:44:18 -0500 Subject: [PATCH] deprecate ed1_mode, autodetect --- README.md | 4 ++++ train.py | 39 ++++++++++++++++----------------------- utils/patch_unet.py | 12 +++++------- 3 files changed, 25 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index 784020f..35ee21f 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,10 @@ Covers install, setup of base models, startning training, basic tweaking, and lo Behind the scenes look at how the trainer handles multiaspect and crop jitter +### Tools repo + +Make sure to check out the [tools repo](https://github.com/victorchall/EveryDream), it has a grab bag of scripts to help with your data curation prior to training. It has automatic bulk BLIP captioning for BLIP, script to web scrape based on Laion data files, script to rename generic pronouns to proper names or append artist tags to your captions, etc. + ## Docs [Setup and installation](doc/SETUP.md) diff --git a/train.py b/train.py index 872a406..e9f02ec 100644 --- a/train.py +++ b/train.py @@ -92,14 +92,14 @@ def convert_to_hf(ckpt_path): else: logging.info(f"Found cached checkpoint at {hf_cache}") - patch_unet(hf_cache, args.ed1_mode, args.lowvram) - return hf_cache + is_sd1attn = patch_unet(hf_cache) + return hf_cache, is_sd1attn elif os.path.isdir(hf_cache): - patch_unet(hf_cache, args.ed1_mode, args.lowvram) - return hf_cache + is_sd1attn = patch_unet(hf_cache) + return hf_cache, is_sd1attn else: - patch_unet(ckpt_path, args.ed1_mode, args.lowvram) - return ckpt_path + is_sd1attn = patch_unet(ckpt_path) + return ckpt_path, is_sd1attn def setup_local_logger(args): """ @@ -230,10 +230,6 @@ def setup_args(args): # find the last checkpoint in the logdir args.resume_ckpt = find_last_checkpoint(args.logdir) - if args.ed1_mode and not args.amp and not args.disable_xformers: - args.disable_xformers = True - logging.info(" ED1 mode without amp: Overiding disable_xformers to True") - if args.lowvram: set_args_12gb(args) @@ -398,7 +394,7 @@ def main(args): """ logging.info(f"Generating samples gs:{gs}, for {prompts}") seed = args.seed if args.seed != -1 else random.randint(0, 2**30) - gen = torch.Generator(device="cuda").manual_seed(seed) + gen = torch.Generator(device=device).manual_seed(seed) i = 0 for prompt in prompts: @@ -442,10 +438,10 @@ def main(args): del images try: - hf_ckpt_path = convert_to_hf(args.resume_ckpt) + hf_ckpt_path, is_sd1attn = convert_to_hf(args.resume_ckpt) text_encoder = CLIPTextModel.from_pretrained(hf_ckpt_path, subfolder="text_encoder") vae = AutoencoderKL.from_pretrained(hf_ckpt_path, subfolder="vae") - unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet", upcast_attention=not args.ed1_mode) + unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet", upcast_attention=not is_sd1attn) sample_scheduler = DDIMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler") noise_scheduler = DDPMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler") tokenizer = CLIPTokenizer.from_pretrained(hf_ckpt_path, subfolder="tokenizer", use_fast=False) @@ -455,11 +451,8 @@ def main(args): if args.gradient_checkpointing: unet.enable_gradient_checkpointing() text_encoder.gradient_checkpointing_enable() - - if args.ed1_mode and not args.lowvram: - unet.set_attention_slice(4) - if not args.disable_xformers and is_xformers_available(): + if not args.disable_xformers and (args.amp and is_sd1attn) or (not is_sd1attn): try: unet.enable_xformers_memory_efficient_attention() logging.info("Enabled xformers") @@ -660,8 +653,8 @@ def main(args): enabled=args.amp, #enabled=True, init_scale=2**17.5, - growth_factor=1.5, - backoff_factor=1.0/1.5, + growth_factor=1.8, + backoff_factor=1.0/1.8, growth_interval=50, ) logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)") @@ -831,11 +824,12 @@ def main(args): global_step += 1 if global_step == 500: - scaler.set_growth_factor(1.35) - scaler.set_backoff_factor(1/1.35) + scaler.set_growth_factor(1.4) + scaler.set_backoff_factor(1/1.4) if global_step == 1000: scaler.set_growth_factor(1.2) scaler.set_backoff_factor(1/1.2) + scaler.set_growth_interval(100) # end of step steps_pbar.close() @@ -926,7 +920,6 @@ if __name__ == "__main__": argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED") argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)") argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!") - argparser.add_argument("--ed1_mode", action="store_true", default=False, help="Disables xformers and reduces attention heads to 8 (SD1.x style)") argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1)") argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)") argparser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation factor (def: 1), (ex, 2)") @@ -957,6 +950,6 @@ if __name__ == "__main__": argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs") argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)") - args = argparser.parse_args() + args, _ = argparser.parse_known_args() main(args) diff --git a/utils/patch_unet.py b/utils/patch_unet.py index 4f66747..cc00349 100644 --- a/utils/patch_unet.py +++ b/utils/patch_unet.py @@ -17,7 +17,7 @@ import os import json import logging -def patch_unet(ckpt_path, force_sd1attn: bool = False, low_vram: bool = False): +def patch_unet(ckpt_path): """ Patch the UNet to use updated attention heads for xformers support in FP32 """ @@ -25,11 +25,9 @@ def patch_unet(ckpt_path, force_sd1attn: bool = False, low_vram: bool = False): with open(unet_cfg_path, "r") as f: unet_cfg = json.load(f) - if force_sd1attn: - unet_cfg["attention_head_dim"] = [8, 8, 8, 8] - else: # SD 2 attn for xformers - unet_cfg["attention_head_dim"] = [5, 10, 20, 20] + is_sd1attn = unet_cfg["attention_head_dim"] == [8, 8, 8, 8] + is_sd1attn = unet_cfg["attention_head_dim"] == 8 or is_sd1attn logging.info(f" unet attention_head_dim: {unet_cfg['attention_head_dim']}") - with open(unet_cfg_path, "w") as f: - json.dump(unet_cfg, f, indent=2) + + return is_sd1attn