deprecate ed1_mode, autodetect

Victor Hall 2023-01-17 12:44:18 -05:00
parent ed025d27b6
commit 3e803a8313
3 changed files with 25 additions and 30 deletions

@@ -17,6 +17,10 @@ Covers install, setup of base models, starting training, basic tweaking, and lo
Behind the scenes look at how the trainer handles multiaspect and crop jitter
### Tools repo
Make sure to check out the [tools repo](https://github.com/victorchall/EveryDream); it has a grab bag of scripts to help with your data curation prior to training, including automatic bulk BLIP captioning, a script to web scrape based on Laion data files, a script to rename generic pronouns to proper names or append artist tags to your captions, and more.
## Docs
[Setup and installation](doc/SETUP.md)

@@ -92,14 +92,14 @@ def convert_to_hf(ckpt_path):
else:
logging.info(f"Found cached checkpoint at {hf_cache}")
patch_unet(hf_cache, args.ed1_mode, args.lowvram)
return hf_cache
is_sd1attn = patch_unet(hf_cache)
return hf_cache, is_sd1attn
elif os.path.isdir(hf_cache):
patch_unet(hf_cache, args.ed1_mode, args.lowvram)
return hf_cache
is_sd1attn = patch_unet(hf_cache)
return hf_cache, is_sd1attn
else:
patch_unet(ckpt_path, args.ed1_mode, args.lowvram)
return ckpt_path
is_sd1attn = patch_unet(ckpt_path)
return ckpt_path, is_sd1attn
def setup_local_logger(args):
"""
@@ -230,10 +230,6 @@ def setup_args(args):
# find the last checkpoint in the logdir
args.resume_ckpt = find_last_checkpoint(args.logdir)
if args.ed1_mode and not args.amp and not args.disable_xformers:
args.disable_xformers = True
logging.info(" ED1 mode without amp: Overiding disable_xformers to True")
if args.lowvram:
set_args_12gb(args)
@@ -398,7 +394,7 @@ def main(args):
"""
logging.info(f"Generating samples gs:{gs}, for {prompts}")
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
gen = torch.Generator(device="cuda").manual_seed(seed)
gen = torch.Generator(device=device).manual_seed(seed)
i = 0
for prompt in prompts:
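
The sample generator is now seeded on whatever device the trainer selected rather than a hard-coded "cuda" string. A small standalone sketch of that pattern; the seed handling mirrors the diff, while the device selection and the noise tensor are only illustrative assumptions:

```python
import random
import torch

# pick a device the way a --gpuid style option might; this selection logic
# is an assumption for the sketch, not the trainer's own.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

seed = -1  # -1 means "choose a random seed", as args.seed does in the diff
seed = seed if seed != -1 else random.randint(0, 2**30)

# the generator must live on the same device as the tensors it seeds
gen = torch.Generator(device=device).manual_seed(seed)
noise = torch.randn((1, 4, 64, 64), generator=gen, device=device)
```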
@@ -442,10 +438,10 @@ def main(args):
del images
try:
hf_ckpt_path = convert_to_hf(args.resume_ckpt)
hf_ckpt_path, is_sd1attn = convert_to_hf(args.resume_ckpt)
text_encoder = CLIPTextModel.from_pretrained(hf_ckpt_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(hf_ckpt_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet", upcast_attention=not args.ed1_mode)
unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet", upcast_attention=not is_sd1attn)
sample_scheduler = DDIMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
noise_scheduler = DDPMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(hf_ckpt_path, subfolder="tokenizer", use_fast=False)
@@ -455,11 +451,8 @@ def main(args):
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()
if args.ed1_mode and not args.lowvram:
unet.set_attention_slice(4)
if not args.disable_xformers and is_xformers_available():
if not args.disable_xformers and (args.amp and is_sd1attn) or (not is_sd1attn):
try:
unet.enable_xformers_memory_efficient_attention()
logging.info("Enabled xformers")
@@ -660,8 +653,8 @@ def main(args):
enabled=args.amp,
#enabled=True,
init_scale=2**17.5,
growth_factor=1.5,
backoff_factor=1.0/1.5,
growth_factor=1.8,
backoff_factor=1.0/1.8,
growth_interval=50,
)
logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")
@@ -831,11 +824,12 @@ def main(args):
global_step += 1
if global_step == 500:
scaler.set_growth_factor(1.35)
scaler.set_backoff_factor(1/1.35)
scaler.set_growth_factor(1.4)
scaler.set_backoff_factor(1/1.4)
if global_step == 1000:
scaler.set_growth_factor(1.2)
scaler.set_backoff_factor(1/1.2)
scaler.set_growth_interval(100)
# end of step
steps_pbar.close()
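
As training stabilizes, the schedule eases the scaler back toward gentler settings at fixed step marks. A small helper sketching that step-keyed relaxation; global_step is assumed to be the trainer's running step counter:

```python
import torch

def relax_grad_scaler(scaler: torch.cuda.amp.GradScaler, global_step: int) -> None:
    """Loosen the aggressive early loss-scale settings as training settles,
    using the step marks and values from the diff."""
    if global_step == 500:
        scaler.set_growth_factor(1.4)
        scaler.set_backoff_factor(1 / 1.4)
    elif global_step == 1000:
        scaler.set_growth_factor(1.2)
        scaler.set_backoff_factor(1 / 1.2)
        scaler.set_growth_interval(100)  # also grow less often from here on
```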
@@ -926,7 +920,6 @@ if __name__ == "__main__":
argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED")
argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
argparser.add_argument("--ed1_mode", action="store_true", default=False, help="Disables xformers and reduces attention heads to 8 (SD1.x style)")
argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1)")
argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)")
argparser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation factor (def: 1), (ex, 2)")
@@ -957,6 +950,6 @@ if __name__ == "__main__":
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
args = argparser.parse_args()
args, _ = argparser.parse_known_args()
main(args)
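
Switching to parse_known_args lets existing launch scripts keep passing the retired --ed1_mode flag without crashing: unknown flags are returned separately and simply discarded. A minimal illustration, with --flip_p standing in for any still-supported option:

```python
import argparse

argparser = argparse.ArgumentParser()
argparser.add_argument("--flip_p", type=float, default=0.0)

# a stale command line that still contains the removed flag
args, unknown = argparser.parse_known_args(["--flip_p", "0.5", "--ed1_mode"])
print(args.flip_p)  # 0.5
print(unknown)      # ['--ed1_mode']
```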

@@ -17,7 +17,7 @@ import os
import json
import logging
def patch_unet(ckpt_path, force_sd1attn: bool = False, low_vram: bool = False):
def patch_unet(ckpt_path):
"""
Patch the UNet to use updated attention heads for xformers support in FP32
"""
@@ -25,11 +25,9 @@ def patch_unet(ckpt_path, force_sd1attn: bool = False, low_vram: bool = False):
with open(unet_cfg_path, "r") as f:
unet_cfg = json.load(f)
if force_sd1attn:
unet_cfg["attention_head_dim"] = [8, 8, 8, 8]
else: # SD 2 attn for xformers
unet_cfg["attention_head_dim"] = [5, 10, 20, 20]
is_sd1attn = unet_cfg["attention_head_dim"] == [8, 8, 8, 8]
is_sd1attn = unet_cfg["attention_head_dim"] == 8 or is_sd1attn
logging.info(f" unet attention_head_dim: {unet_cfg['attention_head_dim']}")
with open(unet_cfg_path, "w") as f:
json.dump(unet_cfg, f, indent=2)
return is_sd1attn
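
Instead of forcing a head layout via --ed1_mode, patch_unet now inspects the config it is patching and reports which attention style the checkpoint uses. A hypothetical standalone version of that detection, assuming the usual diffusers folder layout with the UNet config at unet/config.json:

```python
import json
import os

def detect_sd1_attention(diffusers_dir: str) -> bool:
    """Return True if the UNet config reports SD1.x-style attention heads
    (attention_head_dim of 8, scalar or per-block) rather than the SD2.x
    layout of [5, 10, 20, 20]."""
    cfg_path = os.path.join(diffusers_dir, "unet", "config.json")
    with open(cfg_path, "r") as f:
        cfg = json.load(f)
    head_dim = cfg["attention_head_dim"]
    return head_dim == 8 or head_dim == [8, 8, 8, 8]
```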