diff --git a/train.py b/train.py index e9f02ec..fed93d3 100644 --- a/train.py +++ b/train.py @@ -22,9 +22,10 @@ import logging import time import gc import random +import shutil import torch.nn.functional as F -from torch.cuda.amp import autocast +from torch.cuda.amp import autocast, GradScaler import torchvision.transforms as transforms from colorama import Fore, Style, Cursor @@ -92,14 +93,14 @@ def convert_to_hf(ckpt_path): else: logging.info(f"Found cached checkpoint at {hf_cache}") - is_sd1attn = patch_unet(hf_cache) - return hf_cache, is_sd1attn + is_sd1attn, yaml = patch_unet(hf_cache) + return hf_cache, is_sd1attn, yaml elif os.path.isdir(hf_cache): - is_sd1attn = patch_unet(hf_cache) - return hf_cache, is_sd1attn + is_sd1attn, yaml = patch_unet(hf_cache) + return hf_cache, is_sd1attn, yaml else: - is_sd1attn = patch_unet(ckpt_path) - return ckpt_path, is_sd1attn + is_sd1attn, yaml = patch_unet(ckpt_path) + return ckpt_path, is_sd1attn, yaml def setup_local_logger(args): """ @@ -275,6 +276,28 @@ def setup_args(args): return args +def update_grad_scaler(scaler: GradScaler, global_step, epoch, step): + if global_step == 250 or (epoch >= 2 and step == 1): + factor = 1.8 + scaler.set_growth_factor(factor) + scaler.set_backoff_factor(1/factor) + scaler.set_growth_interval(50) + if global_step == 500 or (epoch >= 4 and step == 1): + factor = 1.6 + scaler.set_growth_factor(factor) + scaler.set_backoff_factor(1/factor) + scaler.set_growth_interval(50) + if global_step == 1000 or (epoch >= 8 and step == 1): + factor = 1.3 + scaler.set_growth_factor(factor) + scaler.set_backoff_factor(1/factor) + scaler.set_growth_interval(100) + if global_step == 3000 or (epoch >= 15 and step == 1): + factor = 1.15 + scaler.set_growth_factor(factor) + scaler.set_backoff_factor(1/factor) + scaler.set_growth_interval(100) + def main(args): """ Main entry point @@ -296,12 +319,12 @@ def main(args): torch.backends.cudnn.benchmark = True log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}") - logging.info(f"Logging to {log_folder}") + if not os.path.exists(log_folder): os.makedirs(log_folder) @torch.no_grad() - def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, save_full_precision=False): + def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, yaml_name, save_full_precision=False): """ Save the model to disk """ @@ -322,6 +345,7 @@ def main(args): ) pipeline.save_pretrained(save_path) sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt" + if save_ckpt_dir is not None: sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path) else: @@ -331,8 +355,13 @@ def main(args): logging.info(f" * Saving SD model to {sd_ckpt_full}") converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half) - # optimizer_path = os.path.join(save_path, "optimizer.pt") + if yaml: + yaml_save_path = f"{os.path.basename(save_path)}.yaml" + logging.info(f" * Saving yaml to {yaml_save_path}") + shutil.copyfile(yaml, yaml_save_path) + + # optimizer_path = os.path.join(save_path, "optimizer.pt") # if self.save_optimizer_flag: # logging.info(f" Saving optimizer state to {save_path}") # self.save_optimizer(self.ctx.optimizer, optimizer_path) @@ -438,13 +467,14 @@ def main(args): del images try: - hf_ckpt_path, is_sd1attn = convert_to_hf(args.resume_ckpt) + hf_ckpt_path, is_sd1attn, yaml = convert_to_hf(args.resume_ckpt) text_encoder = CLIPTextModel.from_pretrained(hf_ckpt_path, subfolder="text_encoder") vae = AutoencoderKL.from_pretrained(hf_ckpt_path, subfolder="vae") unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet", upcast_attention=not is_sd1attn) sample_scheduler = DDIMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler") noise_scheduler = DDPMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler") tokenizer = CLIPTokenizer.from_pretrained(hf_ckpt_path, subfolder="tokenizer", use_fast=False) + logging.info(f" Inferred yaml: {yaml}, attention head type: {'sd1' if is_sd1attn else 'sd2'}") except: logging.ERROR(" * Failed to load checkpoint *") @@ -486,7 +516,7 @@ def main(args): betas = (0.9, 0.999) epsilon = 1e-8 if args.amp: - epsilon = 1e-8 + epsilon = 2e-8 weight_decay = 0.01 if args.useadam8bit: @@ -562,7 +592,6 @@ def main(args): log_args(log_writer, args) - """ Train the model @@ -649,13 +678,12 @@ def main(args): #scaler = torch.cuda.amp.GradScaler() - scaler = torch.cuda.amp.GradScaler( + scaler = GradScaler( enabled=args.amp, - #enabled=True, init_scale=2**17.5, - growth_factor=1.8, - backoff_factor=1.0/1.8, - growth_interval=50, + growth_factor=2, + backoff_factor=1.0/2, + growth_interval=25, ) logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)") @@ -813,23 +841,16 @@ def main(args): last_epoch_saved_time = time.time() logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}") save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}") - __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision) + __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision) if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 1 and epoch < args.max_epochs - 1: logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}") save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}") - __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision) + __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision) del batch global_step += 1 - - if global_step == 500: - scaler.set_growth_factor(1.4) - scaler.set_backoff_factor(1/1.4) - if global_step == 1000: - scaler.set_growth_factor(1.2) - scaler.set_backoff_factor(1/1.2) - scaler.set_growth_interval(100) + update_grad_scaler(scaler, global_step, epoch, step) if args.amp else None # end of step steps_pbar.close() @@ -849,7 +870,7 @@ def main(args): # end of training save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}") - __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision) + __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision) total_elapsed_time = time.time() - training_start_time logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}") @@ -859,7 +880,7 @@ def main(args): except Exception as ex: logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}") save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}") - __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision) + __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision) raise ex logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}") diff --git a/utils/patch_unet.py b/utils/patch_unet.py index cc00349..154444e 100644 --- a/utils/patch_unet.py +++ b/utils/patch_unet.py @@ -25,9 +25,25 @@ def patch_unet(ckpt_path): with open(unet_cfg_path, "r") as f: unet_cfg = json.load(f) + scheduler_cfg_path = os.path.join(ckpt_path, "scheduler", "scheduler_config.json") + with open(scheduler_cfg_path, "r") as f: + scheduler_cfg = json.load(f) + is_sd1attn = unet_cfg["attention_head_dim"] == [8, 8, 8, 8] is_sd1attn = unet_cfg["attention_head_dim"] == 8 or is_sd1attn + prediction_type = scheduler_cfg["prediction_type"] + logging.info(f" unet attention_head_dim: {unet_cfg['attention_head_dim']}") - return is_sd1attn + yaml = '' + if prediction_type in ["v_prediction","v-prediction"] and not is_sd1attn: + yaml = "v2-inference-v.yaml" + elif prediction_type == "epsilon" and not is_sd1attn: + yaml = "v2-inference.yaml" + elif prediction_type == "epsilon" and is_sd1attn: + yaml = "v2-inference.yaml" + else: + raise ValueError(f"Unknown model format for: {prediction_type} and attention_head_dim {unet_cfg['attention_head_dim']}") + + return is_sd1attn, yaml