From 51ee7253bbcdec32d2037e81948864c95ee9693e Mon Sep 17 00:00:00 2001 From: Victor Hall Date: Mon, 2 Jan 2023 17:33:31 -0500 Subject: [PATCH] option to disable xformers --- train.py | 38 +++++++++++++++++--------------------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/train.py b/train.py index 56d43f9..1c62d33 100644 --- a/train.py +++ b/train.py @@ -202,7 +202,9 @@ def main(args): set_seed(seed) gpu = GPU() device = torch.device(f"cuda:{args.gpuid}") + torch.backends.cudnn.benchmark = False + args.clip_skip = max(min(4, args.clip_skip), 0) if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None: @@ -216,11 +218,11 @@ def main(args): args.save_every_n_epochs = _VERY_LARGE_NUMBER if args.save_every_n_epochs < _VERY_LARGE_NUMBER and args.ckpt_every_n_minutes < _VERY_LARGE_NUMBER: - logging.warning(f"{Fore.LIGHTYELLOW_EX}Both save_every_n_epochs and ckpt_every_n_minutes are set, this will potentially spam a lot of checkpoints{Style.RESET_ALL}") - logging.warning(f"{Fore.LIGHTYELLOW_EX}save_every_n_epochs: {args.save_every_n_epochs}, ckpt_every_n_minutes: {args.ckpt_every_n_minutes}{Style.RESET_ALL}") + logging.warning(f"{Fore.LIGHTYELLOW_EX}** Both save_every_n_epochs and ckpt_every_n_minutes are set, this will potentially spam a lot of checkpoints{Style.RESET_ALL}") + logging.warning(f"{Fore.LIGHTYELLOW_EX}** save_every_n_epochs: {args.save_every_n_epochs}, ckpt_every_n_minutes: {args.ckpt_every_n_minutes}{Style.RESET_ALL}") if args.cond_dropout > 0.26: - logging.warning(f"{Fore.LIGHTYELLOW_EX}cond_dropout is set fairly high: {args.cond_dropout}, make sure this was intended{Style.RESET_ALL}") + logging.warning(f"{Fore.LIGHTYELLOW_EX}** cond_dropout is set fairly high: {args.cond_dropout}, make sure this was intended{Style.RESET_ALL}") total_batch_size = args.batch_size * args.grad_accum @@ -287,12 +289,7 @@ def main(args): requires_safety_checker=None, # avoid nag feature_extractor=None, # must be none of no safety checker ) - if is_xformers_available(): - try: - pipe.enable_xformers_memory_efficient_attention() - except Exception as ex: - logging.warning("failed to load xformers, continuing without it") - pass + return pipe def __generate_sample(pipe: StableDiffusionPipeline, prompt : str, cfg: float, resolution: int, gen): @@ -392,15 +389,15 @@ def main(args): unet.enable_gradient_checkpointing() text_encoder.gradient_checkpointing_enable() - if is_xformers_available(): + if not args.disable_xformers and is_xformers_available(): try: unet.enable_xformers_memory_efficient_attention() - print(" Enabled xformers") - except Exception as e: - logging.warning( - "Could not enable memory efficient attention. Make sure xformers is installed" - f" correctly and a GPU is available: {e}" - ) + except Exception as ex: + logging.warning("failed to load xformers, continuing without it") + pass + else: + unet.disable_xformers_memory_efficient_attention() + logging.info("xformers not available or disabled") default_lr = 3e-6 curr_lr = args.lr if args.lr is not None else default_lr @@ -412,7 +409,6 @@ def main(args): if args.disable_textenc_training: logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}") params_to_train = itertools.chain(unet.parameters()) - text_encoder.eval() else: logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}") params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters()) @@ -576,7 +572,7 @@ def main(args): ) unet.train() - text_encoder.train() + text_encoder.train() if not args.disable_textenc_training else text_encoder.eval() logging.info(f" unet device: {unet.device}, precision: {unet.dtype}, training: {unet.training}") logging.info(f" text_encoder device: {text_encoder.device}, precision: {text_encoder.dtype}, training: {text_encoder.training}") @@ -705,10 +701,10 @@ def main(args): __generate_test_samples(pipe=pipe, prompts=prompts, log_writer=log_writer, log_folder=log_folder, gs=global_step, random_captions=True, resolution=args.resolution) del pipe - torch.cuda.empty_cache() gc.collect() + torch.cuda.empty_cache() - min_since_last_ckpt = (time.time() - last_epoch_saved_time) / 60 + min_since_last_ckpt = (time.time() - last_epoch_saved_time) / 60 if args.ckpt_every_n_minutes is not None and (min_since_last_ckpt > args.ckpt_every_n_minutes): last_epoch_saved_time = time.time() @@ -788,7 +784,7 @@ if __name__ == "__main__": argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1)") argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)") argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)") - + argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)") argparser.add_argument("--lowvram", action="store_true", default=False, help="automatically overrides various args to support 12GB gpu") args = argparser.parse_args()