Fix Gradient Checkpointing
This commit is contained in:
parent
31dd4f6433
commit
29e7df519b
|
@ -741,7 +741,6 @@ def main():
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
unet.enable_gradient_checkpointing()
|
unet.enable_gradient_checkpointing()
|
||||||
|
|
||||||
if args.train_text_encoder:
|
if args.train_text_encoder:
|
||||||
text_encoder.gradient_checkpointing_enable()
|
text_encoder.gradient_checkpointing_enable()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue