diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py index 6c62ba2..0e08746 100644 --- a/trainer/diffusers_trainer.py +++ b/trainer/diffusers_trainer.py @@ -741,9 +741,8 @@ def main(): if args.gradient_checkpointing: unet.enable_gradient_checkpointing() - - if args.train_text_encoder: - text_encoder.gradient_checkpointing_enable() + if args.train_text_encoder: + text_encoder.gradient_checkpointing_enable() if args.use_xformers: unet.set_use_memory_efficient_attention_xformers(True)