diff --git a/diffusers_trainer.py b/diffusers_trainer.py index 00c076d..2a8441d 100644 --- a/diffusers_trainer.py +++ b/diffusers_trainer.py @@ -82,6 +82,7 @@ parser.add_argument('--image_log_scheduler', type=str, default="PNDMScheduler", parser.add_argument('--clip_penultimate', type=bool, default=False, help='Use penultimate CLIP layer for text embedding') parser.add_argument('--output_bucket_info', type=bool, default=False, help='Outputs bucket information and exits') parser.add_argument('--resize', type=bool, default=False, help="Resizes dataset's images to the appropriate bucket dimensions.") +parser.add_argument('--use_xformers', type=bool, default=False, help='Use memory efficient attention') args = parser.parse_args() def setup(): @@ -560,6 +561,9 @@ def main(): if args.gradient_checkpointing: unet.enable_gradient_checkpointing() + + if args.use_xformers: + unet.set_use_memory_efficient_attention_xformers(True) if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails. try: