option to disable xformers

parent 726eecc958
commit 51ee7253bb

train.py: 38 changed lines

@@ -202,7 +202,9 @@ def main(args):
     set_seed(seed)
     gpu = GPU()
     device = torch.device(f"cuda:{args.gpuid}")

     torch.backends.cudnn.benchmark = False

     args.clip_skip = max(min(4, args.clip_skip), 0)

     if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:

@@ -216,11 +218,11 @@ def main(args):
         args.save_every_n_epochs = _VERY_LARGE_NUMBER

     if args.save_every_n_epochs < _VERY_LARGE_NUMBER and args.ckpt_every_n_minutes < _VERY_LARGE_NUMBER:
-        logging.warning(f"{Fore.LIGHTYELLOW_EX}Both save_every_n_epochs and ckpt_every_n_minutes are set, this will potentially spam a lot of checkpoints{Style.RESET_ALL}")
-        logging.warning(f"{Fore.LIGHTYELLOW_EX}save_every_n_epochs: {args.save_every_n_epochs}, ckpt_every_n_minutes: {args.ckpt_every_n_minutes}{Style.RESET_ALL}")
+        logging.warning(f"{Fore.LIGHTYELLOW_EX}** Both save_every_n_epochs and ckpt_every_n_minutes are set, this will potentially spam a lot of checkpoints{Style.RESET_ALL}")
+        logging.warning(f"{Fore.LIGHTYELLOW_EX}** save_every_n_epochs: {args.save_every_n_epochs}, ckpt_every_n_minutes: {args.ckpt_every_n_minutes}{Style.RESET_ALL}")

     if args.cond_dropout > 0.26:
-        logging.warning(f"{Fore.LIGHTYELLOW_EX}cond_dropout is set fairly high: {args.cond_dropout}, make sure this was intended{Style.RESET_ALL}")
+        logging.warning(f"{Fore.LIGHTYELLOW_EX}** cond_dropout is set fairly high: {args.cond_dropout}, make sure this was intended{Style.RESET_ALL}")

     total_batch_size = args.batch_size * args.grad_accum

@@ -287,12 +289,7 @@ def main(args):
         requires_safety_checker=None, # avoid nag
         feature_extractor=None, # must be none of no safety checker
     )
-    if is_xformers_available():
-        try:
-            pipe.enable_xformers_memory_efficient_attention()
-        except Exception as ex:
-            logging.warning("failed to load xformers, continuing without it")
-            pass
     return pipe

 def __generate_sample(pipe: StableDiffusionPipeline, prompt : str, cfg: float, resolution: int, gen):

@@ -392,15 +389,15 @@ def main(args):
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()

-    if is_xformers_available():
+    if not args.disable_xformers and is_xformers_available():
         try:
             unet.enable_xformers_memory_efficient_attention()
-            print(" Enabled xformers")
-        except Exception as e:
-            logging.warning(
-                "Could not enable memory efficient attention. Make sure xformers is installed"
-                f" correctly and a GPU is available: {e}"
-            )
+        except Exception as ex:
+            logging.warning("failed to load xformers, continuing without it")
+            pass
+    else:
+        unet.disable_xformers_memory_efficient_attention()
+        logging.info("xformers not available or disabled")

     default_lr = 3e-6
     curr_lr = args.lr if args.lr is not None else default_lr
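
Note (not part of the commit): the new gating logic, restated as one self-contained sketch. The function name configure_attention and the argparse scaffolding are illustrative only; the flag name and the enable/disable calls come from the diff, and the is_xformers_available import path may differ between diffusers versions.

import argparse
import logging

from diffusers.utils import is_xformers_available

argparser = argparse.ArgumentParser()
argparser.add_argument("--disable_xformers", action="store_true", default=False,
                       help="disable xformers, may reduce performance (def: False)")

def configure_attention(unet, args):
    # The CLI flag wins over runtime detection: even with xformers
    # installed, --disable_xformers forces the stock attention path.
    if not args.disable_xformers and is_xformers_available():
        try:
            unet.enable_xformers_memory_efficient_attention()
        except Exception:
            logging.warning("failed to load xformers, continuing without it")
    else:
        unet.disable_xformers_memory_efficient_attention()
        logging.info("xformers not available or disabled")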

@@ -412,7 +409,6 @@ def main(args):
     if args.disable_textenc_training:
         logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
         params_to_train = itertools.chain(unet.parameters())
-        text_encoder.eval()
     else:
         logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}")
         params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())

@@ -576,7 +572,7 @@ def main(args):
     )

     unet.train()
-    text_encoder.train()
+    text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()

     logging.info(f" unet device: {unet.device}, precision: {unet.dtype}, training: {unet.training}")
     logging.info(f" text_encoder device: {text_encoder.device}, precision: {text_encoder.dtype}, training: {text_encoder.training}")
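
Note (not part of the commit): the changed text_encoder line drives side effects from a conditional expression. Written as a plain if/else with the same names the diff uses, it is equivalent to:

if args.disable_textenc_training:
    text_encoder.eval()   # eval mode (e.g. dropout off) while the encoder is not being trained
else:
    text_encoder.train()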

@@ -705,10 +701,10 @@ def main(args):
                 __generate_test_samples(pipe=pipe, prompts=prompts, log_writer=log_writer, log_folder=log_folder, gs=global_step, random_captions=True, resolution=args.resolution)

                 del pipe
-                torch.cuda.empty_cache()
                 gc.collect()
+                torch.cuda.empty_cache()

             min_since_last_ckpt = (time.time() - last_epoch_saved_time) / 60

             if args.ckpt_every_n_minutes is not None and (min_since_last_ckpt > args.ckpt_every_n_minutes):
                 last_epoch_saved_time = time.time()
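
Note (not part of the commit): the swap of gc.collect() and torch.cuda.empty_cache() reads as a deliberate fix, since empty_cache() can only return allocator blocks whose tensors have already been freed. A minimal sketch, with release_pipeline as an illustrative name:

import gc
import torch

def release_pipeline(pipe):
    # Drop the reference, let the garbage collector reclaim the tensors
    # (including any caught in reference cycles), and only then ask the
    # CUDA caching allocator to hand freed blocks back to the driver.
    del pipe
    gc.collect()
    torch.cuda.empty_cache()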

@@ -788,7 +784,7 @@ if __name__ == "__main__":
     argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1)")
     argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
     argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)")
+    argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
     argparser.add_argument("--lowvram", action="store_true", default=False, help="automatically overrides various args to support 12GB gpu")

     args = argparser.parse_args()
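
Usage note (not part of the commit): with the new flag, a run that forces the stock attention path looks something like "python train.py --disable_xformers" plus the usual training arguments; omitting the flag keeps the prior auto-detect behavior.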