Fix Gradient Checkpointing

This commit is contained in:
cafeai 2022-12-03 21:24:38 +09:00
parent 31dd4f6433
commit 29e7df519b
1 changed files with 2 additions and 3 deletions

View File

@ -741,9 +741,8 @@ def main():
if args.gradient_checkpointing: if args.gradient_checkpointing:
unet.enable_gradient_checkpointing() unet.enable_gradient_checkpointing()
if args.train_text_encoder:
if args.train_text_encoder: text_encoder.gradient_checkpointing_enable()
text_encoder.gradient_checkpointing_enable()
if args.use_xformers: if args.use_xformers:
unet.set_use_memory_efficient_attention_xformers(True) unet.set_use_memory_efficient_attention_xformers(True)