Fix Gradient Checkpointing
This commit is contained in:
parent
31dd4f6433
commit
29e7df519b
|
@ -741,9 +741,8 @@ def main():
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
unet.enable_gradient_checkpointing()
|
unet.enable_gradient_checkpointing()
|
||||||
|
if args.train_text_encoder:
|
||||||
if args.train_text_encoder:
|
text_encoder.gradient_checkpointing_enable()
|
||||||
text_encoder.gradient_checkpointing_enable()
|
|
||||||
|
|
||||||
if args.use_xformers:
|
if args.use_xformers:
|
||||||
unet.set_use_memory_efficient_attention_xformers(True)
|
unet.set_use_memory_efficient_attention_xformers(True)
|
||||||
|
|
Loading…
Reference in New Issue