Fix Gradient Checkpointing
This commit is contained in:
parent
31dd4f6433
commit
29e7df519b
|
@ -741,9 +741,8 @@ def main():
|
|||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
if args.train_text_encoder:
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
|
||||
if args.use_xformers:
|
||||
unet.set_use_memory_efficient_attention_xformers(True)
|
||||
|
|
Loading…
Reference in New Issue