update ed1 mode

Victor Hall, 2023-01-09 13:44:51 -05:00
parent 0b2edc6b65
commit bf869db2e2
3 changed files with 7 additions and 2 deletions

@@ -1,3 +1,5 @@
 python train.py --config chain0.json
 python train.py --config chain1.json
 python train.py --config chain2.json
+pause

@@ -436,6 +436,9 @@ def main(args):
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
+    if args.ed1_mode:
+        unet.set_attention_slice(4)
     if not args.disable_xformers and is_xformers_available():
         try:
             unet.enable_xformers_memory_efficient_attention()

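The new --ed1_mode branch turns on attention slicing via unet.set_attention_slice(4), which computes attention in smaller slices to reduce peak VRAM at a small speed cost. Below is a minimal, standalone sketch of the same calls on a diffusers UNet2DConditionModel; the checkpoint id is an assumption for illustration and is not referenced by this commit.

from diffusers import UNet2DConditionModel

# Assumed SD1.x checkpoint in diffusers format; any model with a "unet" subfolder works.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
unet.enable_gradient_checkpointing()  # trade recompute for lower activation memory
unet.set_attention_slice(4)           # slice the attention computation to cut peak VRAM
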

@@ -27,7 +27,7 @@ def patch_unet(ckpt_path, force_sd1attn: bool = False):
     if force_sd1attn:
-        unet_cfg["attention_head_dim"] = [5, 8, 8, 8]
+        unet_cfg["attention_head_dim"] = [8, 8, 8, 8]
     else:
         unet_cfg["attention_head_dim"] = [5, 10, 20, 20]
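
For context, SD1.x UNet configs use a uniform attention_head_dim of 8 for every block, while SD2.x uses [5, 10, 20, 20]; the fix makes the force_sd1attn path emit the SD1.x layout instead of a mixed one. A hedged sketch of how such a patched config dict could be turned back into a UNet with diffusers follows; the load_config/from_config usage and the checkpoint path are assumptions for illustration, not code from patch_unet.

from diffusers import UNet2DConditionModel

# Hypothetical path; load_config returns the UNet's config as a mapping.
unet_cfg = dict(UNet2DConditionModel.load_config("path/to/model", subfolder="unet"))
unet_cfg["attention_head_dim"] = [8, 8, 8, 8]       # SD1.x-style heads (force_sd1attn)
# unet_cfg["attention_head_dim"] = [5, 10, 20, 20]  # SD2.x layout (default branch)
unet = UNet2DConditionModel.from_config(unet_cfg)   # rebuild the UNet with the patched attention config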