From debcdd25060266eec0b4247016746f4d72fe0389 Mon Sep 17 00:00:00 2001 From: nawnie <106923464+nawnie@users.noreply.github.com> Date: Sun, 22 Jan 2023 00:02:35 -0600 Subject: [PATCH] Update train.py Revert --- train.py | 112 ++++++++++++++++++++++++++++++++----------------------- 1 file changed, 65 insertions(+), 47 deletions(-) diff --git a/train.py b/train.py index b7243de..0883363 100644 --- a/train.py +++ b/train.py @@ -22,9 +22,10 @@ import logging import time import gc import random +import shutil import torch.nn.functional as F -from torch.cuda.amp import autocast +from torch.cuda.amp import autocast, GradScaler import torchvision.transforms as transforms from colorama import Fore, Style, Cursor @@ -46,12 +47,10 @@ from accelerate.utils import set_seed import wandb from torch.utils.tensorboard import SummaryWriter -import keyboard - from data.every_dream import EveryDreamBatch from utils.convert_diff_to_ckpt import convert as converter from utils.gpu import GPU -forstepTime = time.time() + _SIGTERM_EXIT_CODE = 130 _VERY_LARGE_NUMBER = 1e9 @@ -87,19 +86,19 @@ def convert_to_hf(ckpt_path): import utils.convert_original_stable_diffusion_to_diffusers as convert convert.convert(ckpt_path, f"ckpt_cache/{ckpt_path}") except: - logging.info("Please manually convert the checkpoint to Diffusers format, see readme.") + logging.info("Please manually convert the checkpoint to Diffusers format (one time setup), see readme.") exit() else: logging.info(f"Found cached checkpoint at {hf_cache}") - is_sd1attn = patch_unet(hf_cache) - return hf_cache, is_sd1attn + is_sd1attn, yaml = patch_unet(hf_cache) + return hf_cache, is_sd1attn, yaml elif os.path.isdir(hf_cache): - is_sd1attn = patch_unet(hf_cache) - return hf_cache, is_sd1attn + is_sd1attn, yaml = patch_unet(hf_cache) + return hf_cache, is_sd1attn, yaml else: - is_sd1attn = patch_unet(ckpt_path) - return ckpt_path, is_sd1attn + is_sd1attn, yaml = patch_unet(ckpt_path) + return ckpt_path, is_sd1attn, yaml def setup_local_logger(args): """ @@ -174,7 +173,6 @@ def append_epoch_log(global_step: int, epoch_pbar, gpu, log_writer, **logs): if logs is not None: epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}") - print(f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step} | Elapsed : {time.time() - forstepTime}s") def set_args_12gb(args): @@ -276,6 +274,28 @@ def setup_args(args): return args +def update_grad_scaler(scaler: GradScaler, global_step, epoch, step): + if global_step == 250 or (epoch >= 2 and step == 1): + factor = 1.8 + scaler.set_growth_factor(factor) + scaler.set_backoff_factor(1/factor) + scaler.set_growth_interval(50) + if global_step == 500 or (epoch >= 4 and step == 1): + factor = 1.6 + scaler.set_growth_factor(factor) + scaler.set_backoff_factor(1/factor) + scaler.set_growth_interval(50) + if global_step == 1000 or (epoch >= 8 and step == 1): + factor = 1.3 + scaler.set_growth_factor(factor) + scaler.set_backoff_factor(1/factor) + scaler.set_growth_interval(100) + if global_step == 3000 or (epoch >= 15 and step == 1): + factor = 1.15 + scaler.set_growth_factor(factor) + scaler.set_backoff_factor(1/factor) + scaler.set_growth_interval(100) + def main(args): """ Main entry point @@ -297,12 +317,12 @@ def main(args): torch.backends.cudnn.benchmark = True log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}") - logging.info(f"Logging to {log_folder}") + if not os.path.exists(log_folder): os.makedirs(log_folder) @torch.no_grad() - def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, save_full_precision=False): + def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, yaml_name, save_full_precision=False): """ Save the model to disk """ @@ -323,17 +343,24 @@ def main(args): ) pipeline.save_pretrained(save_path) sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt" + if save_ckpt_dir is not None: sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path) else: sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path) + save_ckpt_dir = os.curdir half = not save_full_precision logging.info(f" * Saving SD model to {sd_ckpt_full}") converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half) - # optimizer_path = os.path.join(save_path, "optimizer.pt") + if yaml_name: + yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(save_path))}.yaml" + logging.info(f" * Saving yaml to {yaml_save_path}") + shutil.copyfile(yaml_name, yaml_save_path) + + # optimizer_path = os.path.join(save_path, "optimizer.pt") # if self.save_optimizer_flag: # logging.info(f" Saving optimizer state to {save_path}") # self.save_optimizer(self.ctx.optimizer, optimizer_path) @@ -439,7 +466,7 @@ def main(args): del images try: - hf_ckpt_path, is_sd1attn = convert_to_hf(args.resume_ckpt) + hf_ckpt_path, is_sd1attn, yaml = convert_to_hf(args.resume_ckpt) text_encoder = CLIPTextModel.from_pretrained(hf_ckpt_path, subfolder="text_encoder") vae = AutoencoderKL.from_pretrained(hf_ckpt_path, subfolder="vae") unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet", upcast_attention=not is_sd1attn) @@ -453,13 +480,14 @@ def main(args): unet.enable_gradient_checkpointing() text_encoder.gradient_checkpointing_enable() - if not args.disable_xformers and (args.amp and is_sd1attn) or (not is_sd1attn): - try: - unet.enable_xformers_memory_efficient_attention() - logging.info("Enabled xformers") - except Exception as ex: - logging.warning("failed to load xformers, continuing without it") - pass + if not args.disable_xformers: + if (args.amp and is_sd1attn) or (not is_sd1attn): + try: + unet.enable_xformers_memory_efficient_attention() + logging.info("Enabled xformers") + except Exception as ex: + logging.warning("failed to load xformers, continuing without it") + pass else: logging.info("xformers not available or disabled") @@ -487,7 +515,7 @@ def main(args): betas = (0.9, 0.999) epsilon = 1e-8 if args.amp: - epsilon = 1e-8 + epsilon = 2e-8 weight_decay = 0.01 if args.useadam8bit: @@ -563,7 +591,6 @@ def main(args): log_args(log_writer, args) - """ Train the model @@ -650,13 +677,12 @@ def main(args): #scaler = torch.cuda.amp.GradScaler() - scaler = torch.cuda.amp.GradScaler( + scaler = GradScaler( enabled=args.amp, - #enabled=True, init_scale=2**17.5, - growth_factor=1.8, - backoff_factor=1.0/1.8, - growth_interval=50, + growth_factor=2, + backoff_factor=1.0/2, + growth_interval=25, ) logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)") @@ -677,6 +703,8 @@ def main(args): append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer) loss_log_step = [] + + assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct" try: for epoch in range(args.max_epochs): @@ -727,16 +755,13 @@ def main(args): with autocast(enabled=args.amp): model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - del timesteps, encoder_hidden_states, noisy_latents + #del timesteps, encoder_hidden_states, noisy_latents #with autocast(enabled=args.amp): loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") del target, model_pred - #if args.amp: scaler.scale(loss).backward() - #else: - # loss.backward() if args.clip_grad_norm is not None: if not args.disable_unet_training: @@ -792,7 +817,7 @@ def main(args): append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs) torch.cuda.empty_cache() - if (not args.notebook and keyboard.is_pressed("ctrl+alt+page up")) or ((global_step + 1) % args.sample_steps == 0): + if (global_step + 1) % args.sample_steps == 0: pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=sample_scheduler, vae=vae) pipe = pipe.to(device) @@ -814,23 +839,16 @@ def main(args): last_epoch_saved_time = time.time() logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}") save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}") - __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision) + __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision) if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 1 and epoch < args.max_epochs - 1: logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}") save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}") - __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision) + __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision) del batch global_step += 1 - - if global_step == 500: - scaler.set_growth_factor(1.4) - scaler.set_backoff_factor(1/1.4) - if global_step == 1000: - scaler.set_growth_factor(1.2) - scaler.set_backoff_factor(1/1.2) - scaler.set_growth_interval(100) + update_grad_scaler(scaler, global_step, epoch, step) if args.amp else None # end of step steps_pbar.close() @@ -850,7 +868,7 @@ def main(args): # end of training save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}") - __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision) + __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision) total_elapsed_time = time.time() - training_start_time logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}") @@ -860,7 +878,7 @@ def main(args): except Exception as ex: logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}") save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}") - __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision) + __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision) raise ex logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")