From 29e7df519b4abe1457ffd90da63e8160af927ee2 Mon Sep 17 00:00:00 2001 From: cafeai <116491182+cafeai@users.noreply.github.com> Date: Sat, 3 Dec 2022 21:24:38 +0900 Subject: [PATCH] Fix Gradient Checkpointing --- trainer/diffusers_trainer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py index 6c62ba2..0e08746 100644 --- a/trainer/diffusers_trainer.py +++ b/trainer/diffusers_trainer.py @@ -741,9 +741,8 @@ def main(): if args.gradient_checkpointing: unet.enable_gradient_checkpointing() - - if args.train_text_encoder: - text_encoder.gradient_checkpointing_enable() + if args.train_text_encoder: + text_encoder.gradient_checkpointing_enable() if args.use_xformers: unet.set_use_memory_efficient_attention_xformers(True)