Fix Gradient Checkpointing

This commit is contained in:
cafeai 2022-12-03 21:24:38 +09:00
parent 31dd4f6433
commit 29e7df519b
1 changed files with 2 additions and 3 deletions

View File

@ -741,7 +741,6 @@ def main():
if args.gradient_checkpointing: if args.gradient_checkpointing:
unet.enable_gradient_checkpointing() unet.enable_gradient_checkpointing()
if args.train_text_encoder: if args.train_text_encoder:
text_encoder.gradient_checkpointing_enable() text_encoder.gradient_checkpointing_enable()