option to disable xformers

This commit is contained in:
Victor Hall 2023-01-02 17:33:31 -05:00
parent 726eecc958
commit 51ee7253bb
1 changed files with 17 additions and 21 deletions

View File

@ -202,7 +202,9 @@ def main(args):
set_seed(seed) set_seed(seed)
gpu = GPU() gpu = GPU()
device = torch.device(f"cuda:{args.gpuid}") device = torch.device(f"cuda:{args.gpuid}")
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
args.clip_skip = max(min(4, args.clip_skip), 0) args.clip_skip = max(min(4, args.clip_skip), 0)
if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None: if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
@ -216,11 +218,11 @@ def main(args):
args.save_every_n_epochs = _VERY_LARGE_NUMBER args.save_every_n_epochs = _VERY_LARGE_NUMBER
if args.save_every_n_epochs < _VERY_LARGE_NUMBER and args.ckpt_every_n_minutes < _VERY_LARGE_NUMBER: if args.save_every_n_epochs < _VERY_LARGE_NUMBER and args.ckpt_every_n_minutes < _VERY_LARGE_NUMBER:
logging.warning(f"{Fore.LIGHTYELLOW_EX}Both save_every_n_epochs and ckpt_every_n_minutes are set, this will potentially spam a lot of checkpoints{Style.RESET_ALL}") logging.warning(f"{Fore.LIGHTYELLOW_EX}** Both save_every_n_epochs and ckpt_every_n_minutes are set, this will potentially spam a lot of checkpoints{Style.RESET_ALL}")
logging.warning(f"{Fore.LIGHTYELLOW_EX}save_every_n_epochs: {args.save_every_n_epochs}, ckpt_every_n_minutes: {args.ckpt_every_n_minutes}{Style.RESET_ALL}") logging.warning(f"{Fore.LIGHTYELLOW_EX}** save_every_n_epochs: {args.save_every_n_epochs}, ckpt_every_n_minutes: {args.ckpt_every_n_minutes}{Style.RESET_ALL}")
if args.cond_dropout > 0.26: if args.cond_dropout > 0.26:
logging.warning(f"{Fore.LIGHTYELLOW_EX}cond_dropout is set fairly high: {args.cond_dropout}, make sure this was intended{Style.RESET_ALL}") logging.warning(f"{Fore.LIGHTYELLOW_EX}** cond_dropout is set fairly high: {args.cond_dropout}, make sure this was intended{Style.RESET_ALL}")
total_batch_size = args.batch_size * args.grad_accum total_batch_size = args.batch_size * args.grad_accum
@ -287,12 +289,7 @@ def main(args):
requires_safety_checker=None, # avoid nag requires_safety_checker=None, # avoid nag
feature_extractor=None, # must be none of no safety checker feature_extractor=None, # must be none of no safety checker
) )
if is_xformers_available():
try:
pipe.enable_xformers_memory_efficient_attention()
except Exception as ex:
logging.warning("failed to load xformers, continuing without it")
pass
return pipe return pipe
def __generate_sample(pipe: StableDiffusionPipeline, prompt : str, cfg: float, resolution: int, gen): def __generate_sample(pipe: StableDiffusionPipeline, prompt : str, cfg: float, resolution: int, gen):
@ -392,15 +389,15 @@ def main(args):
unet.enable_gradient_checkpointing() unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable() text_encoder.gradient_checkpointing_enable()
if is_xformers_available(): if not args.disable_xformers and is_xformers_available():
try: try:
unet.enable_xformers_memory_efficient_attention() unet.enable_xformers_memory_efficient_attention()
print(" Enabled xformers") except Exception as ex:
except Exception as e: logging.warning("failed to load xformers, continuing without it")
logging.warning( pass
"Could not enable memory efficient attention. Make sure xformers is installed" else:
f" correctly and a GPU is available: {e}" unet.disable_xformers_memory_efficient_attention()
) logging.info("xformers not available or disabled")
default_lr = 3e-6 default_lr = 3e-6
curr_lr = args.lr if args.lr is not None else default_lr curr_lr = args.lr if args.lr is not None else default_lr
@ -412,7 +409,6 @@ def main(args):
if args.disable_textenc_training: if args.disable_textenc_training:
logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}") logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
params_to_train = itertools.chain(unet.parameters()) params_to_train = itertools.chain(unet.parameters())
text_encoder.eval()
else: else:
logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}") logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}")
params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters()) params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
@ -576,7 +572,7 @@ def main(args):
) )
unet.train() unet.train()
text_encoder.train() text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()
logging.info(f" unet device: {unet.device}, precision: {unet.dtype}, training: {unet.training}") logging.info(f" unet device: {unet.device}, precision: {unet.dtype}, training: {unet.training}")
logging.info(f" text_encoder device: {text_encoder.device}, precision: {text_encoder.dtype}, training: {text_encoder.training}") logging.info(f" text_encoder device: {text_encoder.device}, precision: {text_encoder.dtype}, training: {text_encoder.training}")
@ -705,8 +701,8 @@ def main(args):
__generate_test_samples(pipe=pipe, prompts=prompts, log_writer=log_writer, log_folder=log_folder, gs=global_step, random_captions=True, resolution=args.resolution) __generate_test_samples(pipe=pipe, prompts=prompts, log_writer=log_writer, log_folder=log_folder, gs=global_step, random_captions=True, resolution=args.resolution)
del pipe del pipe
torch.cuda.empty_cache()
gc.collect() gc.collect()
torch.cuda.empty_cache()
min_since_last_ckpt = (time.time() - last_epoch_saved_time) / 60 min_since_last_ckpt = (time.time() - last_epoch_saved_time) / 60
@ -788,7 +784,7 @@ if __name__ == "__main__":
argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1)") argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1)")
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)") argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)") argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)")
argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
argparser.add_argument("--lowvram", action="store_true", default=False, help="automatically overrides various args to support 12GB gpu") argparser.add_argument("--lowvram", action="store_true", default=False, help="automatically overrides various args to support 12GB gpu")
args = argparser.parse_args() args = argparser.parse_args()