diff --git a/chain.bat b/chain.bat index a4ff929..a27fd0a 100644 --- a/chain.bat +++ b/chain.bat @@ -1,3 +1,5 @@ python train.py --config chain0.json python train.py --config chain1.json -python train.py --config chain2.json \ No newline at end of file +python train.py --config chain2.json + +pause \ No newline at end of file diff --git a/train.py b/train.py index 234aa8c..24347dc 100644 --- a/train.py +++ b/train.py @@ -436,6 +436,9 @@ def main(args): unet.enable_gradient_checkpointing() text_encoder.gradient_checkpointing_enable() + if args.ed1_mode: + unet.set_attention_slice(4) + if not args.disable_xformers and is_xformers_available(): try: unet.enable_xformers_memory_efficient_attention() diff --git a/utils/patch_unet.py b/utils/patch_unet.py index eb07747..7520e85 100644 --- a/utils/patch_unet.py +++ b/utils/patch_unet.py @@ -27,7 +27,7 @@ def patch_unet(ckpt_path, force_sd1attn: bool = False): if force_sd1attn: - unet_cfg["attention_head_dim"] = [5, 8, 8, 8] + unet_cfg["attention_head_dim"] = [8, 8, 8, 8] else: unet_cfg["attention_head_dim"] = [5, 10, 20, 20]