Merge pull request #246 from damian0815/patch-2
prevent OOM with disabled unet when gradient checkpointing is enabled
This commit is contained in:
commit
18198715b7
2
train.py
2
train.py
|
@ -891,7 +891,7 @@ def main(args):
|
|||
|
||||
train_dataloader = build_torch_dataloader(train_batch, batch_size=args.batch_size)
|
||||
|
||||
unet.train() if not args.disable_unet_training else unet.eval()
|
||||
unet.train() if (args.gradient_checkpointing or not args.disable_unet_training) else unet.eval()
|
||||
text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()
|
||||
|
||||
logging.info(f" unet device: {unet.device}, precision: {unet.dtype}, training: {unet.training}")
|
||||
|
|
Loading…
Reference in New Issue