save full precision

Victor Hall 2023-01-12 19:18:02 -05:00
parent 94ac152340
commit 3fc12965bb
1 changed file with 24 additions and 10 deletions

@@ -45,7 +45,6 @@ from accelerate.utils import set_seed
import wandb
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
import keyboard
@@ -278,7 +277,12 @@ def main(args):
"""
log_time = setup_local_logger(args)
args = setup_args(args)
if args.notebook:
from tqdm.notebook import tqdm
else:
from tqdm.auto import tqdm
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
set_seed(seed)
gpu = GPU()
@@ -292,7 +296,7 @@ def main(args):
os.makedirs(log_folder)
@torch.no_grad()
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir):
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, save_full_precision=False):
"""
Save the model to disk
"""
@@ -317,9 +321,11 @@ def main(args):
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
else:
sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
half = not save_full_precision
logging.info(f" * Saving SD model to {sd_ckpt_full}")
converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=True)
converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)
# optimizer_path = os.path.join(save_path, "optimizer.pt")
# if self.save_optimizer_flag:
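Note: `half = not save_full_precision` is the only place the new flag changes behaviour in this function; the converter writes fp16 weights unless `--save_full_precision` is passed. The sketch below is illustrative only (a hypothetical `save_checkpoint` helper, not the converter used above) and just shows what the `half` switch amounts to:

```python
import torch

def save_checkpoint(state_dict, path, half=True):
    # Illustrative only: the real conversion is done by the diffusers-to-SD
    # converter called above. Casting floating-point tensors to fp16 roughly
    # halves the checkpoint size; half=False keeps full fp32 precision.
    if half:
        state_dict = {k: (v.half() if torch.is_floating_point(v) else v)
                      for k, v in state_dict.items()}
    torch.save(state_dict, path)

# A tiny stand-in "model" saved both ways:
sd = {"weight": torch.randn(4, 4), "step": torch.tensor(100)}
save_checkpoint(sd, "ckpt_fp16.pt", half=True)   # default: half = not save_full_precision
save_checkpoint(sd, "ckpt_fp32.pt", half=False)  # with --save_full_precision
```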
@@ -575,7 +581,7 @@ def main(args):
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
time.sleep(2) # give opportunity to ctrl-C again to cancel save
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir)
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
exit(_SIGTERM_EXIT_CODE)
signal.signal(signal.SIGINT, sigterm_handler)
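The interrupt path itself is unchanged apart from the extra argument: ordinary `signal` handling that saves and exits with a dedicated code. A minimal standalone sketch of that shape, with stand-ins for the trainer's own names (the `_SIGTERM_EXIT_CODE` value and the save call are placeholders):

```python
import signal
import sys
import time

_SIGTERM_EXIT_CODE = 130  # placeholder value; the trainer defines its own

def sigterm_handler(signum, frame):
    # Brief pause so a second Ctrl-C can still abort the save.
    time.sleep(2)
    print("CTRL-C received, saving checkpoint before exit")  # stand-in for __save_model(...)
    sys.exit(_SIGTERM_EXIT_CODE)

signal.signal(signal.SIGINT, sigterm_handler)
```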
@@ -771,7 +777,7 @@ def main(args):
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
torch.cuda.empty_cache()
if keyboard.is_pressed("ctrl+alt+page up") or ((global_step + 1) % args.sample_steps == 0):
if (not args.notebook and keyboard.is_pressed("ctrl+alt+page up")) or ((global_step + 1) % args.sample_steps == 0):
pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=sample_scheduler, vae=vae)
pipe = pipe.to(device)
@@ -793,12 +799,12 @@ def main(args):
last_epoch_saved_time = time.time()
logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 1 and epoch < args.max_epochs - 1:
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
del batch
global_step += 1
@@ -819,7 +825,7 @@ def main(args):
# end of training
save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
total_elapsed_time = time.time() - training_start_time
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
@@ -829,7 +835,7 @@ def main(args):
except Exception as ex:
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
raise ex
logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
@@ -844,6 +850,12 @@ def update_old_args(t_args):
if not hasattr(t_args, "shuffle_tags"):
print(f" Config json is missing 'shuffle_tags'")
t_args.__dict__["shuffle_tags"] = False
if not hasattr(t_args, "save_full_precision"):
print(f" Config json is missing 'save_full_precision'")
t_args.__dict__["save_full_precision"] = False
if not hasattr(t_args, "notebook"):
print(f" Config json is missing 'notebook'")
t_args.__dict__["notebook"] = False
if __name__ == "__main__":
supported_resolutions = [256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
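The two new `hasattr` blocks follow the existing backward-compatibility pattern in `update_old_args`: a config JSON written before a flag existed gets that flag backfilled with its argparse default. A compact sketch of the same idea (the example config content is made up):

```python
import argparse
import json

def update_old_args(t_args: argparse.Namespace) -> argparse.Namespace:
    # Backfill flags that older config JSONs predate, mirroring the
    # hasattr checks above; defaults match the argparse definitions.
    defaults = {"shuffle_tags": False, "save_full_precision": False, "notebook": False}
    for key, value in defaults.items():
        if not hasattr(t_args, key):
            print(f" Config json is missing '{key}'")
            t_args.__dict__[key] = value
    return t_args

# Example: a config saved before these flags existed.
old_cfg = argparse.Namespace(**json.loads('{"project_name": "demo"}'))
update_old_args(old_cfg)
print(old_cfg.save_full_precision, old_cfg.notebook)  # False False
```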
@@ -898,6 +910,8 @@ if __name__ == "__main__":
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32")
argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
args = argparser.parse_args()
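Both new arguments are plain opt-in boolean flags. A quick standalone parse check using only the definitions added in this commit (every other trainer argument is omitted, and the script name in the comment is an assumption):

```python
import argparse

argparser = argparse.ArgumentParser()
argparser.add_argument("--save_full_precision", action="store_true", default=False,
                       help="save ckpts at full FP32")
argparser.add_argument("--notebook", action="store_true", default=False,
                       help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")

# e.g. python train.py --save_full_precision --notebook  (train.py is assumed)
args = argparser.parse_args(["--save_full_precision"])
print(args.save_full_precision, args.notebook)  # True False
```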