Merge branch 'main' into main

This commit is contained in:
Anthony Mercurio 2022-11-10 09:22:37 -07:00 committed by GitHub
commit 4e89309d6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 0 deletions

View File

@ -82,6 +82,7 @@ parser.add_argument('--image_log_scheduler', type=str, default="PNDMScheduler",
parser.add_argument('--clip_penultimate', type=bool, default=False, help='Use penultimate CLIP layer for text embedding') parser.add_argument('--clip_penultimate', type=bool, default=False, help='Use penultimate CLIP layer for text embedding')
parser.add_argument('--output_bucket_info', type=bool, default=False, help='Outputs bucket information and exits') parser.add_argument('--output_bucket_info', type=bool, default=False, help='Outputs bucket information and exits')
parser.add_argument('--resize', type=bool, default=False, help="Resizes dataset's images to the appropriate bucket dimensions.") parser.add_argument('--resize', type=bool, default=False, help="Resizes dataset's images to the appropriate bucket dimensions.")
parser.add_argument('--use_xformers', type=bool, default=False, help='Use memory efficient attention')
args = parser.parse_args() args = parser.parse_args()
def setup(): def setup():
@ -561,6 +562,9 @@ def main():
if args.gradient_checkpointing: if args.gradient_checkpointing:
unet.enable_gradient_checkpointing() unet.enable_gradient_checkpointing()
if args.use_xformers:
unet.set_use_memory_efficient_attention_xformers(True)
if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails. if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
try: try:
import bitsandbytes as bnb import bitsandbytes as bnb