deprecate ed1_mode, autodetect
parent: ed025d27b6
commit: 3e803a8313
@@ -17,6 +17,10 @@ Covers install, setup of base models, starting training, basic tweaking, and lo
Behind the scenes look at how the trainer handles multiaspect and crop jitter

### Tools repo

Make sure to check out the [tools repo](https://github.com/victorchall/EveryDream); it has a grab bag of scripts to help with your data curation prior to training: automatic bulk BLIP captioning, a script to web scrape based on Laion data files, a script to rename generic pronouns to proper names or append artist tags to your captions, and more.

## Docs

[Setup and installation](doc/SETUP.md)
train.py
@@ -92,14 +92,14 @@ def convert_to_hf(ckpt_path):
         else:
             logging.info(f"Found cached checkpoint at {hf_cache}")
 
-        patch_unet(hf_cache, args.ed1_mode, args.lowvram)
-        return hf_cache
+        is_sd1attn = patch_unet(hf_cache)
+        return hf_cache, is_sd1attn
     elif os.path.isdir(hf_cache):
-        patch_unet(hf_cache, args.ed1_mode, args.lowvram)
-        return hf_cache
+        is_sd1attn = patch_unet(hf_cache)
+        return hf_cache, is_sd1attn
     else:
-        patch_unet(ckpt_path, args.ed1_mode, args.lowvram)
-        return ckpt_path
+        is_sd1attn = patch_unet(ckpt_path)
+        return ckpt_path, is_sd1attn
 
 def setup_local_logger(args):
     """
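convert_to_hf now hands back a (path, is_sd1attn) pair instead of a bare path, with the SD1-attention flag coming from patch_unet rather than the old --ed1_mode argument. A toy mirror of that branching, with a stubbed patch_unet and made-up cache handling, just to show the new calling convention:

```python
import os

def patch_unet_stub(path):
    # Stand-in for the real patch_unet(); pretends every checkpoint is SD1.x.
    return True

def convert_to_hf_sketch(ckpt_path, hf_cache):
    # Toy version of the branching in train.py's convert_to_hf after this commit.
    if os.path.isfile(ckpt_path):
        # (conversion / cache handling elided)
        is_sd1attn = patch_unet_stub(hf_cache)
        return hf_cache, is_sd1attn
    elif os.path.isdir(hf_cache):
        is_sd1attn = patch_unet_stub(hf_cache)
        return hf_cache, is_sd1attn
    else:
        is_sd1attn = patch_unet_stub(ckpt_path)
        return ckpt_path, is_sd1attn

hf_path, is_sd1attn = convert_to_hf_sketch("model.ckpt", "hf_cache_dir")
```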
@@ -230,10 +230,6 @@ def setup_args(args):
     # find the last checkpoint in the logdir
     args.resume_ckpt = find_last_checkpoint(args.logdir)
 
-    if args.ed1_mode and not args.amp and not args.disable_xformers:
-        args.disable_xformers = True
-        logging.info(" ED1 mode without amp: Overiding disable_xformers to True")
-
     if args.lowvram:
         set_args_12gb(args)
@@ -398,7 +394,7 @@ def main(args):
        """
        logging.info(f"Generating samples gs:{gs}, for {prompts}")
        seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
-       gen = torch.Generator(device="cuda").manual_seed(seed)
+       gen = torch.Generator(device=device).manual_seed(seed)
 
        i = 0
        for prompt in prompts:
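The sample generator is now created on whatever device the trainer selected (e.g. via --gpuid) instead of a hard-coded "cuda". A minimal sketch of that pattern, assuming a hypothetical device choice and a placeholder seed:

```python
import torch

# Hypothetical device selection; train.py derives its device from --gpuid.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

seed = 555  # placeholder; train.py uses args.seed or a random 30-bit value
gen = torch.Generator(device=device).manual_seed(seed)

# The generator can then seed sampling, e.g. noise for a diffusion step.
noise = torch.randn(1, 4, 64, 64, generator=gen, device=device)
```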
@@ -442,10 +438,10 @@ def main(args):
        del images
 
    try:
-       hf_ckpt_path = convert_to_hf(args.resume_ckpt)
+       hf_ckpt_path, is_sd1attn = convert_to_hf(args.resume_ckpt)
        text_encoder = CLIPTextModel.from_pretrained(hf_ckpt_path, subfolder="text_encoder")
        vae = AutoencoderKL.from_pretrained(hf_ckpt_path, subfolder="vae")
-       unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet", upcast_attention=not args.ed1_mode)
+       unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet", upcast_attention=not is_sd1attn)
        sample_scheduler = DDIMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
        noise_scheduler = DDPMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
        tokenizer = CLIPTokenizer.from_pretrained(hf_ckpt_path, subfolder="tokenizer", use_fast=False)
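upcast_attention is now keyed off the detected attention layout instead of --ed1_mode: checkpoints without SD1-style heads (i.e. SD2.x) get their attention upcast, which is the usual workaround for fp16 instability in SD2 attention. A sketch of that load call under the same assumption (the path and the is_sd1attn value are placeholders):

```python
from diffusers import UNet2DConditionModel

hf_ckpt_path = "/path/to/converted_ckpt"  # placeholder directory in diffusers layout
is_sd1attn = False                        # e.g. an SD2.x checkpoint was detected

# upcast_attention overrides the value stored in the unet's config.json at load time.
unet = UNet2DConditionModel.from_pretrained(
    hf_ckpt_path,
    subfolder="unet",
    upcast_attention=not is_sd1attn,
)
```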
@@ -455,11 +451,8 @@ def main(args):
    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        text_encoder.gradient_checkpointing_enable()
 
-   if args.ed1_mode and not args.lowvram:
-       unet.set_attention_slice(4)
-
-   if not args.disable_xformers and is_xformers_available():
+   if not args.disable_xformers and (args.amp and is_sd1attn) or (not is_sd1attn):
        try:
            unet.enable_xformers_memory_efficient_attention()
            logging.info("Enabled xformers")
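Since and binds tighter than or in Python, the new guard groups as shown below: xformers is attempted for SD1-style attention only when AMP is enabled (and not explicitly disabled), and is always attempted for non-SD1 attention as written. A small self-contained sketch of that grouping, with placeholder values standing in for the trainer's args:

```python
class Args:  # placeholder for the trainer's argparse namespace
    disable_xformers = False
    amp = True

args = Args()
is_sd1attn = True  # e.g. an SD1.x checkpoint was detected

# Explicit parenthesization of the committed condition:
use_xformers = (not args.disable_xformers and args.amp and is_sd1attn) or (not is_sd1attn)
print(use_xformers)  # True: SD1 attention with AMP -> try to enable xformers
```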
@@ -660,8 +653,8 @@ def main(args):
        enabled=args.amp,
        #enabled=True,
        init_scale=2**17.5,
-       growth_factor=1.5,
-       backoff_factor=1.0/1.5,
+       growth_factor=1.8,
+       backoff_factor=1.0/1.8,
        growth_interval=50,
    )
    logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")
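These are all keyword arguments of torch.cuda.amp.GradScaler. A standalone sketch with the values from this hunk (the trainer passes args.amp for enabled; enabled=True here is just for illustration):

```python
import torch

scaler = torch.cuda.amp.GradScaler(
    enabled=True,            # train.py passes args.amp here
    init_scale=2**17.5,      # start high and let backoff find a stable loss scale
    growth_factor=1.8,       # multiply the scale by 1.8 after enough clean steps
    backoff_factor=1.0/1.8,  # shrink symmetrically when an inf/nan gradient appears
    growth_interval=50,      # number of consecutive clean steps before growing
)
print(scaler.get_scale())
```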
@@ -831,11 +824,12 @@ def main(args):
            global_step += 1
 
            if global_step == 500:
-               scaler.set_growth_factor(1.35)
-               scaler.set_backoff_factor(1/1.35)
+               scaler.set_growth_factor(1.4)
+               scaler.set_backoff_factor(1/1.4)
            if global_step == 1000:
                scaler.set_growth_factor(1.2)
                scaler.set_backoff_factor(1/1.2)
+               scaler.set_growth_interval(100)
            # end of step
 
        steps_pbar.close()
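set_growth_factor, set_backoff_factor and set_growth_interval are public GradScaler setters, so the same taper could also live in a small helper keyed on the step counter; a sketch with the thresholds and values from this hunk:

```python
def taper_grad_scaler(scaler, global_step):
    # Relax the loss-scale dynamics as training stabilizes (values from this commit).
    if global_step == 500:
        scaler.set_growth_factor(1.4)
        scaler.set_backoff_factor(1 / 1.4)
    elif global_step == 1000:
        scaler.set_growth_factor(1.2)
        scaler.set_backoff_factor(1 / 1.2)
        scaler.set_growth_interval(100)
```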
@@ -926,7 +920,6 @@ if __name__ == "__main__":
    argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED")
    argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
    argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
-   argparser.add_argument("--ed1_mode", action="store_true", default=False, help="Disables xformers and reduces attention heads to 8 (SD1.x style)")
    argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1)")
    argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)")
    argparser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation factor (def: 1), (ex, 2)")
@@ -957,6 +950,6 @@ if __name__ == "__main__":
    argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
    argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
 
-   args = argparser.parse_args()
+   args, _ = argparser.parse_known_args()
 
    main(args)
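Switching to parse_known_args() means the parser silently ignores flags it no longer defines instead of exiting, which presumably keeps existing launch scripts that still pass --ed1_mode working after the flag's removal. A minimal illustration:

```python
import argparse

argparser = argparse.ArgumentParser()
argparser.add_argument("--amp", action="store_true")

# parse_args() would abort with "unrecognized arguments: --ed1_mode";
# parse_known_args() returns the leftovers instead of erroring.
args, unknown = argparser.parse_known_args(["--amp", "--ed1_mode"])
print(args.amp)   # True
print(unknown)    # ['--ed1_mode']
```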
@@ -17,7 +17,7 @@ import os
 import json
 import logging
 
-def patch_unet(ckpt_path, force_sd1attn: bool = False, low_vram: bool = False):
+def patch_unet(ckpt_path):
     """
     Patch the UNet to use updated attention heads for xformers support in FP32
     """
@@ -25,11 +25,9 @@ def patch_unet(ckpt_path):
     with open(unet_cfg_path, "r") as f:
         unet_cfg = json.load(f)
 
-    if force_sd1attn:
-        unet_cfg["attention_head_dim"] = [8, 8, 8, 8]
-    else: # SD 2 attn for xformers
-        unet_cfg["attention_head_dim"] = [5, 10, 20, 20]
+    is_sd1attn = unet_cfg["attention_head_dim"] == [8, 8, 8, 8]
+    is_sd1attn = unet_cfg["attention_head_dim"] == 8 or is_sd1attn
 
     logging.info(f" unet attention_head_dim: {unet_cfg['attention_head_dim']}")
     with open(unet_cfg_path, "w") as f:
         json.dump(unet_cfg, f, indent=2)
 
+    return is_sd1attn
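So instead of forcing an attention layout, patch_unet now just reads the converted checkpoint's UNet config and reports whether it looks like SD1.x: an attention_head_dim of 8 (or [8, 8, 8, 8]) means SD1-style heads, while SD2.x models ship [5, 10, 20, 20], as the removed branch shows. A standalone sketch of the same check, assuming the standard diffusers folder layout and a placeholder path:

```python
import json
import os

def detect_sd1_attention(diffusers_ckpt_dir: str) -> bool:
    # True if the UNet config declares SD1.x-style attention heads.
    unet_cfg_path = os.path.join(diffusers_ckpt_dir, "unet", "config.json")
    with open(unet_cfg_path, "r") as f:
        unet_cfg = json.load(f)
    head_dim = unet_cfg["attention_head_dim"]
    return head_dim == 8 or head_dim == [8, 8, 8, 8]

# Placeholder path for illustration:
# print(detect_sd1_attention("/path/to/converted_ckpt"))
```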