From 9fc6ae7a09e69d4ac7e4a2ad3bf822d8407e6301 Mon Sep 17 00:00:00 2001
From: Damian Stewart
Date: Tue, 16 Jan 2024 10:23:52 +1300
Subject: [PATCH] prevent OOM with disabled unet when gradient checkpointing is
 enabled

unet needs to be in train() mode for gradient checkpointing to work
---
 train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train.py b/train.py
index 749c5c7..1312b46 100644
--- a/train.py
+++ b/train.py
@@ -891,7 +891,7 @@ def main(args):
 
     train_dataloader = build_torch_dataloader(train_batch, batch_size=args.batch_size)
 
-    unet.train() if not args.disable_unet_training else unet.eval()
+    unet.train() if (args.gradient_checkpointing or not args.disable_unet_training) else unet.eval()
     text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()
 
     logging.info(f" unet device: {unet.device}, precision: {unet.dtype}, training: {unet.training}")