Prevent OOM with disabled unet when gradient checkpointing is enabled

unet needs to be in train() mode for gradient checkpointing to work
This commit is contained in:
Damian Stewart 2024-01-16 10:23:52 +13:00 committed by GitHub
parent e08d5ded98
commit 9fc6ae7a09
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 1 addition and 1 deletion

View File

@ -891,7 +891,7 @@ def main(args):
train_dataloader = build_torch_dataloader(train_batch, batch_size=args.batch_size)
unet.train() if not args.disable_unet_training else unet.eval()
unet.train() if (args.gradient_checkpointing or not args.disable_unet_training) else unet.eval()
text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()
logging.info(f" unet device: {unet.device}, precision: {unet.dtype}, training: {unet.training}")