From ba25992140f30ba2c2bf1b9daf0641f4ba9e022f Mon Sep 17 00:00:00 2001
From: Victor Hall
Date: Sun, 15 Jan 2023 22:07:37 -0500
Subject: [PATCH] merge

---
 train.json           |   1 +
 train.py             | 128 +++++++++++++++++++++++++++----------------
 utils/log_wrapper.py |  57 ++++++++++++++++---
 3 files changed, 132 insertions(+), 54 deletions(-)

diff --git a/train.json b/train.json
index 1184eab..374f3ca 100644
--- a/train.json
+++ b/train.json
@@ -21,6 +21,7 @@
     "lr_scheduler": "constant",
     "lr_warmup_steps": null,
     "max_epochs": 30,
+    "notebook": false,
     "project_name": "project_abc",
     "resolution": 512,
     "resume_ckpt": "sd_v1-5_vae",
diff --git a/train.py b/train.py
index 56b2b90..140ae60 100644
--- a/train.py
+++ b/train.py
@@ -222,12 +222,15 @@ def setup_args(args):
     Sets defaults for missing args (possible if missing from json config)
     Forces some args to be set based on others for compatibility reasons
     """
+    if args.disable_unet_training and args.disable_textenc_training:
+        raise ValueError("Both unet and textenc are disabled, nothing to train")
+
     if args.resume_ckpt == "findlast":
         logging.info(f"{Fore.LIGHTCYAN_EX} Finding last checkpoint in logdir: {args.logdir}{Style.RESET_ALL}")
         # find the last checkpoint in the logdir
         args.resume_ckpt = find_last_checkpoint(args.logdir)

-    if args.ed1_mode and not args.disable_xformers:
+    if args.ed1_mode and args.mixed_precision == "fp32" and not args.disable_xformers:
         args.disable_xformers = True
         logging.info(" ED1 mode: Overiding disable_xformers to True")
@@ -238,7 +241,7 @@ def setup_args(args):
         args.shuffle_tags = False

     args.clip_skip = max(min(4, args.clip_skip), 0)
-
+
     if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
         logging.info(f"{Fore.LIGHTCYAN_EX} No checkpoint saving specified, defaulting to every 20 minutes.{Style.RESET_ALL}")
         args.ckpt_every_n_minutes = 20
@@ -248,7 +251,7 @@ def setup_args(args):
     if args.save_every_n_epochs is None or args.save_every_n_epochs < 1:
         args.save_every_n_epochs = _VERY_LARGE_NUMBER
-
+
     if args.save_every_n_epochs < _VERY_LARGE_NUMBER and args.ckpt_every_n_minutes < _VERY_LARGE_NUMBER:
         logging.warning(f"{Fore.LIGHTYELLOW_EX}** Both save_every_n_epochs and ckpt_every_n_minutes are set, this will potentially spam a lot of checkpoints{Style.RESET_ALL}")
         logging.warning(f"{Fore.LIGHTYELLOW_EX}** save_every_n_epochs: {args.save_every_n_epochs}, ckpt_every_n_minutes: {args.ckpt_every_n_minutes}{Style.RESET_ALL}")
@@ -269,6 +272,9 @@ def setup_args(args):
     if args.save_ckpt_dir is not None and not os.path.exists(args.save_ckpt_dir):
         os.makedirs(args.save_ckpt_dir)

+    if args.mixed_precision != "fp32" and (args.clip_grad_norm is None or args.clip_grad_norm <= 0):
+        args.clip_grad_norm = 1.0
+
     if args.rated_dataset:
         args.rated_dataset_target_dropout_percent = min(max(args.rated_dataset_target_dropout_percent, 0), 100)
@@ -286,9 +292,11 @@ def main(args):
     if args.notebook:
         from tqdm.notebook import tqdm
     else:
-        from tqdm.auto import tqdm
+        from tqdm.auto import tqdm
+
     logging.info(f" Seed: {args.seed}")
     seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
+    logging.info(f" Seed: {seed}")
     set_seed(seed)
     gpu = GPU()
     device = torch.device(f"cuda:{args.gpuid}")
@@ -441,7 +449,7 @@ def main(args):
         hf_ckpt_path = convert_to_hf(args.resume_ckpt)
         text_encoder = CLIPTextModel.from_pretrained(hf_ckpt_path, subfolder="text_encoder")
         vae = AutoencoderKL.from_pretrained(hf_ckpt_path, subfolder="vae")
-        unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet")
+        unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet", upcast_attention=not args.ed1_mode)
         sample_scheduler = DDIMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
         noise_scheduler = DDPMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
         tokenizer = CLIPTokenizer.from_pretrained(hf_ckpt_path, subfolder="tokenizer", use_fast=False)
@@ -468,22 +476,38 @@ def main(args):
     default_lr = 2e-6
     curr_lr = args.lr if args.lr is not None else default_lr

-    # vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16)
-    # unet = unet.to(device, dtype=torch.float32 if not args.amp else torch.float16)
-    # text_encoder = text_encoder.to(device, dtype=torch.float32 if not args.amp else torch.float16)
-    vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16)
-    unet = unet.to(device, dtype=torch.float32)
-    text_encoder = text_encoder.to(device, dtype=torch.float32)
+    d_type = torch.float32
+    if args.mixed_precision == "fp16":
+        d_type = torch.float16
+        logging.info(" * Using fp16 *")
+        args.amp = True
+    elif args.mixed_precision == "bf16":
+        d_type = torch.bfloat16
+        logging.info(" * Using bf16 *")
+        args.amp = True
+    else:
+        logging.info(" * Using FP32 *")
+
+
+    vae = vae.to(device, dtype=torch.float16 if (args.amp and d_type == torch.float32) else d_type)
+    unet = unet.to(device, dtype=d_type)
+    text_encoder = text_encoder.to(device, dtype=d_type)

     if args.disable_textenc_training:
         logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
         params_to_train = itertools.chain(unet.parameters())
+    elif args.disable_unet_training:
+        logging.info(f"{Fore.CYAN} * NOT Training UNet, only Text Encoder *{Style.RESET_ALL}")
+        params_to_train = itertools.chain(text_encoder.parameters())
     else:
         logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}")
         params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())

     betas = (0.9, 0.999)
-    epsilon = 1e-8 if not args.amp else 1e-8
+    epsilon = 1e-8
+    if args.amp or args.mixed_precision == "fp16":
+        epsilon = 1e-8
+
     weight_decay = 0.01

     if args.useadam8bit:
         import bitsandbytes as bnb
@@ -502,6 +526,8 @@ def main(args):
             amsgrad=False,
         )

+    log_optimizer(optimizer, betas, epsilon)
+
     train_batch = EveryDreamBatch(
         data_root=args.data_root,
         flip_p=args.flip_p,
@@ -540,11 +566,8 @@ def main(args):
             sample_prompts.append(line.strip())

-    if False: #args.wandb is not None and args.wandb: # not yet supported
-        log_writer = wandb.init(project="EveryDream2FineTunes",
-                                name=args.project_name,
-                                dir=log_folder,
-                                )
+    if args.wandb is not None and args.wandb:
+        wandb.init(project=args.project_name, sync_tensorboard=True)
     else:
         log_writer = SummaryWriter(log_dir=log_folder,
                                    flush_secs=5,
@@ -602,7 +625,6 @@ def main(args):
     logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")

-
     def collate_fn(batch):
         """
         Collates batches
@@ -632,7 +654,7 @@ def main(args):
         collate_fn=collate_fn
     )

-    unet.train()
+    unet.train() if not args.disable_unet_training else unet.eval()
     text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()

     logging.info(f" unet device: {unet.device}, precision: {unet.dtype}, training: {unet.training}")
@@ -643,9 +665,20 @@ def main(args):
     logging.info(f" {Fore.GREEN}Project name: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.project_name}{Style.RESET_ALL}")
     logging.info(f" {Fore.GREEN}grad_accum: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.grad_accum}{Style.RESET_ALL}"),
     logging.info(f" {Fore.GREEN}batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.batch_size}{Style.RESET_ALL}")
-    #logging.info(f" {Fore.GREEN}total_batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{total_batch_size}")
     logging.info(f" {Fore.GREEN}epoch_len: {Fore.LIGHTGREEN_EX}{epoch_len}{Style.RESET_ALL}")

+    if args.amp or d_type != torch.float32:
+        #scaler = torch.cuda.amp.GradScaler()
+        scaler = torch.cuda.amp.GradScaler(
+            enabled=False,
+            #enabled=True,
+            init_scale=2048.0,
+            growth_factor=1.5,
+            backoff_factor=0.707,
+            growth_interval=50,
+        )
+        logging.info(f" Grad scaler enabled: {scaler.is_enabled()}")
+
     epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True)
     epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
@@ -661,20 +694,6 @@ def main(args):

     append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer)

-    #loss = torch.tensor(0.0, device=device, dtype=torch.float32)
-
-    if args.amp:
-        #scaler = torch.cuda.amp.GradScaler()
-        scaler = torch.cuda.amp.GradScaler(
-            #enabled=False,
-            enabled=True,
-            init_scale=1024.0,
-            growth_factor=2.0,
-            backoff_factor=0.5,
-            growth_interval=50,
-        )
-        logging.info(f" Grad scaler enabled: {scaler.is_enabled()}")
-
     loss_log_step = []

     try:
@@ -723,8 +742,8 @@ def main(args):
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
                 del noise, latents, cuda_caption

-                #with autocast(enabled=args.amp):
-                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                with autocast(enabled=args.amp or d_type != torch.float32):
+                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                 del timesteps, encoder_hidden_states, noisy_latents
                 #with autocast(enabled=args.amp):
@@ -732,15 +751,17 @@ def main(args):

                 del target, model_pred

-                if args.clip_grad_norm is not None:
-                    torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
-                    torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
-
                 if args.amp:
                     scaler.scale(loss).backward()
                 else:
                     loss.backward()

+                if args.clip_grad_norm is not None:
+                    if not args.disable_unet_training:
+                        torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
+                    if not args.disable_textenc_training:
+                        torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
+
                 if batch["runt_size"] > 0:
                     grad_scale = batch["runt_size"] / args.batch_size
                     with torch.no_grad(): # not required? just in case for now, needs more testing
@@ -753,7 +774,7 @@ def main(args):
                         param.grad *= grad_scale

                 if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1):
-                    if args.amp:
+                    if args.amp and d_type == torch.float32:
                         scaler.step(optimizer)
                         scaler.update()
                     else:
@@ -779,6 +800,7 @@ def main(args):
                         loss_log_step = []
                         logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
                         log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step)
+                        log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step)
                         sum_img = sum(images_per_sec_log_step)
                         avg = sum_img / len(images_per_sec_log_step)
                         images_per_sec_log_step = []
@@ -861,17 +883,25 @@ def update_old_args(t_args):
     Update old args to new args to deal with json config loading and missing args for compatibility
     """
     if not hasattr(t_args, "shuffle_tags"):
-        print(f" Config json is missing 'shuffle_tags'")
+        print(f" Config json is missing 'shuffle_tags' flag")
         t_args.__dict__["shuffle_tags"] = False
     if not hasattr(t_args, "save_full_precision"):
-        print(f" Config json is missing 'save_full_precision'")
+        print(f" Config json is missing 'save_full_precision' flag")
         t_args.__dict__["save_full_precision"] = False
     if not hasattr(t_args, "notebook"):
-        print(f" Config json is missing 'notebook'")
+        print(f" Config json is missing 'notebook' flag")
         t_args.__dict__["notebook"] = False
+    if not hasattr(t_args, "disable_unet_training"):
+        print(f" Config json is missing 'disable_unet_training' flag")
+        t_args.__dict__["disable_unet_training"] = False
+    if not hasattr(t_args, "mixed_precision"):
+        print(f" Config json is missing 'mixed_precision' flag")
+        t_args.__dict__["mixed_precision"] = "fp32"
+
 if __name__ == "__main__":
     supported_resolutions = [256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
+    supported_precisions = ['fp16', 'fp32']
     argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
     argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
     args, _ = argparser.parse_known_args()
@@ -881,9 +911,11 @@ if __name__ == "__main__":
         with open(args.config, 'rt') as f:
             t_args = argparse.Namespace()
             t_args.__dict__.update(json.load(f))
+            print(t_args.__dict__)
             update_old_args(t_args) # update args to support older configs
             print(t_args.__dict__)
             args = argparser.parse_args(namespace=t_args)
+            print(f"mixed_precision: {args.mixed_precision}")
     else:
         print("No config file specified, using command line args")
         argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
@@ -894,7 +926,8 @@ if __name__ == "__main__":
     argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3, 4])
     argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
    argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are")
-    argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False) NOT RECOMMENDED")
+    argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False)")
+    argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED")
     argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
     argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
     argparser.add_argument("--ed1_mode", action="store_true", default=False, help="Disables xformers and reduces attention heads to 8 (SD1.x style)")
@@ -909,6 +942,8 @@ if __name__ == "__main__":
     argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
     argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
     argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
+    argparser.add_argument("--mixed_precision", type=str, default='fp32', help="precision to use for model training (def: fp32)", choices=supported_precisions)
+    argparser.add_argument("--notebook", action="store_true", default=False, help="disables keypresses and uses tqdm.notebook for Jupyter notebooks (def: False)")
     argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
     argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
     argparser.add_argument("--resume_ckpt", type=str, required=True, default="sd_v1-5_vae.ckpt")
@@ -916,6 +951,7 @@ if __name__ == "__main__":
     argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)")
     argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
     argparser.add_argument("--save_every_n_epochs", type=int, default=None, help="Save checkpoint every n epochs, def: 0 (disabled)")
+    argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32")
     argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
     argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
     argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
@@ -923,8 +959,6 @@ if __name__ == "__main__":
     argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
     argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
     argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
-    argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32")
-    argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
     argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
     argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
diff --git a/utils/log_wrapper.py b/utils/log_wrapper.py
index 5008d6b..65aa4a1 100644
--- a/utils/log_wrapper.py
+++ b/utils/log_wrapper.py
@@ -16,24 +16,67 @@ limitations under the License.
 import logging
 import os
 import time
+from colorama import Fore, Style

-class LogWrapper(object):
+from torch.utils.tensorboard import SummaryWriter
+import wandb
+
+class LogWrapper():
     """
     singleton for logging
     """
-    def __init__(self, log_dir, project_name):
-        self.log_dir = log_dir
+    def __init__(self, args, use_wandb=False):
+        self.logdir = args.logdir
+        self.wandb = use_wandb
+
+        if use_wandb:
+            wandb.init(project=args.project_name, sync_tensorboard=True)
+        else:
+            self.log_writer = SummaryWriter(log_dir=args.logdir,
+                                            flush_secs=5,
+                                            comment="EveryDream2FineTunes",
+                                            )
         start_time = time.strftime("%Y%m%d-%H%M")
-        self.log_file = os.path.join(log_dir, f"log-{project_name}-{start_time}.txt")
+        log_file = os.path.join(args.logdir, f"log-{args.project_name}-{start_time}.txt")
         self.logger = logging.getLogger(__name__)
         console = logging.StreamHandler()
         self.logger.addHandler(console)
-        file = logging.FileHandler(self.log_file, mode="a", encoding=None, delay=False)
+        file = logging.FileHandler(log_file, mode="a", encoding=None, delay=False)
         self.logger.addHandler(file)

-    def __call__(self):
-        return self.logger
+    def add_image(self):
+        """
+        log_writer.add_image(tag=f"sample_{i}", img_tensor=tfimage, global_step=gs)
+        else:
+            log_writer.add_image(tag=f"sample_{i}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=gs)
+        """
+        pass
+
+    def add_scalar(self, tag: str, scalar_value: float, global_step: int):
+        if self.wandb:
+            wandb.log({tag: scalar_value}, step=global_step)
+        else:
+            self.log_writer.add_scalar(tag, scalar_value, global_step)
+
+    def append_epoch_log(self, global_step: int, epoch_pbar, gpu, log_writer, **logs):
+        """
+        updates the vram usage for the epoch
+        """
+        gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
+        self.add_scalar("performance/vram", gpu_used_mem, global_step)
+        epoch_mem_color = Style.RESET_ALL
+        if gpu_used_mem > 0.93 * gpu_total_mem:
+            epoch_mem_color = Fore.LIGHTRED_EX
+        elif gpu_used_mem > 0.85 * gpu_total_mem:
+            epoch_mem_color = Fore.LIGHTYELLOW_EX
+        elif gpu_used_mem > 0.7 * gpu_total_mem:
+            epoch_mem_color = Fore.LIGHTGREEN_EX
+        elif gpu_used_mem < 0.5 * gpu_total_mem:
+            epoch_mem_color = Fore.LIGHTBLUE_EX
+
+        if logs is not None:
+            epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
\ No newline at end of file
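
Usage note: this patch adds utils/log_wrapper.py but does not yet wire it into train.py (train.py still constructs its SummaryWriter directly). A minimal sketch of how the wrapper is presumably meant to be driven, assuming the class as patched above; the logdir path shown is illustrative and not taken from the patch:

    # hypothetical wiring of LogWrapper; not part of this patch
    import os
    from argparse import Namespace
    from utils.log_wrapper import LogWrapper

    args = Namespace(logdir="logs/project_abc", project_name="project_abc")  # placeholder values
    os.makedirs(args.logdir, exist_ok=True)        # LogWrapper opens a log file inside logdir
    log = LogWrapper(args, use_wandb=False)        # False -> falls back to a TensorBoard SummaryWriter
    log.add_scalar("loss/log_step", 0.123, 0)      # same tag style train.py already logs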