From 0c29d1e84d304aa692d7ed7c9d401838329ce5fb Mon Sep 17 00:00:00 2001 From: laksjdjf Date: Wed, 9 Nov 2022 10:34:49 +0900 Subject: [PATCH 1/2] Update diffusers_trainer.py --- diffusers_trainer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/diffusers_trainer.py b/diffusers_trainer.py index b51c317..10eef02 100644 --- a/diffusers_trainer.py +++ b/diffusers_trainer.py @@ -81,6 +81,7 @@ parser.add_argument('--image_log_inference_steps', type=int, default=50, help='N parser.add_argument('--image_log_scheduler', type=str, default="PNDMScheduler", help='Number of inference steps to use to log images.') parser.add_argument('--clip_penultimate', type=bool, default=False, help='Use penultimate CLIP layer for text embedding') parser.add_argument('--output_bucket_info', type=bool, default=False, help='Outputs bucket information and exits') +parser.add_argument('--use_xformers', action='store_true',help='Use memory efficient attention') args = parser.parse_args() def setup(): @@ -549,6 +550,9 @@ def main(): if args.gradient_checkpointing: unet.enable_gradient_checkpointing() + + if args.use_xformers: + unet.set_use_memory_efficient_attention_xformers(True) if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails. try: From 353200b039a59b90f5152a97ff71d903c780c40b Mon Sep 17 00:00:00 2001 From: laksjdjf Date: Wed, 9 Nov 2022 23:56:49 +0000 Subject: [PATCH 2/2] Update diffusers_trainer.py --- diffusers_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffusers_trainer.py b/diffusers_trainer.py index 10eef02..9b02928 100644 --- a/diffusers_trainer.py +++ b/diffusers_trainer.py @@ -81,7 +81,7 @@ parser.add_argument('--image_log_inference_steps', type=int, default=50, help='N parser.add_argument('--image_log_scheduler', type=str, default="PNDMScheduler", help='Number of inference steps to use to log images.') parser.add_argument('--clip_penultimate', type=bool, default=False, help='Use penultimate CLIP layer for text embedding') parser.add_argument('--output_bucket_info', type=bool, default=False, help='Outputs bucket information and exits') -parser.add_argument('--use_xformers', action='store_true',help='Use memory efficient attention') +parser.add_argument('--use_xformers', type=bool, default=False, help='Use memory efficient attention') args = parser.parse_args() def setup():