Merge branch 'main' into main
This commit is contained in:
commit
4e89309d6b
|
@ -82,6 +82,7 @@ parser.add_argument('--image_log_scheduler', type=str, default="PNDMScheduler",
|
||||||
parser.add_argument('--clip_penultimate', type=bool, default=False, help='Use penultimate CLIP layer for text embedding')
|
parser.add_argument('--clip_penultimate', type=bool, default=False, help='Use penultimate CLIP layer for text embedding')
|
||||||
parser.add_argument('--output_bucket_info', type=bool, default=False, help='Outputs bucket information and exits')
|
parser.add_argument('--output_bucket_info', type=bool, default=False, help='Outputs bucket information and exits')
|
||||||
parser.add_argument('--resize', type=bool, default=False, help="Resizes dataset's images to the appropriate bucket dimensions.")
|
parser.add_argument('--resize', type=bool, default=False, help="Resizes dataset's images to the appropriate bucket dimensions.")
|
||||||
|
parser.add_argument('--use_xformers', type=bool, default=False, help='Use memory efficient attention')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
def setup():
|
def setup():
|
||||||
|
@ -561,6 +562,9 @@ def main():
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
unet.enable_gradient_checkpointing()
|
unet.enable_gradient_checkpointing()
|
||||||
|
|
||||||
|
if args.use_xformers:
|
||||||
|
unet.set_use_memory_efficient_attention_xformers(True)
|
||||||
|
|
||||||
if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
|
if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
|
||||||
try:
|
try:
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
|
|
Loading…
Reference in New Issue