From 56f130c027f981eac63d20e310261cf6f10d1740 Mon Sep 17 00:00:00 2001
From: Joel Holdbrooks
Date: Sun, 29 Jan 2023 18:11:34 -0800
Subject: [PATCH] Forgot to add train.py earlier :facepalm:; move
 write_batch_schedule to train.py

---
 data/every_dream.py | 18 -----------
 train.py            | 78 ++++++++++++++++++++++++++++++++++++++-------
 2 files changed, 66 insertions(+), 30 deletions(-)

diff --git a/data/every_dream.py b/data/every_dream.py
index e21d639..38af008 100644
--- a/data/every_dream.py
+++ b/data/every_dream.py
@@ -38,9 +38,7 @@ class EveryDreamBatch(Dataset):
                  crop_jitter=20,
                  seed=555,
                  tokenizer=None,
-                 log_folder=None,
                  retain_contrast=False,
-                 write_schedule=False,
                  shuffle_tags=False,
                  rated_dataset=False,
                  rated_dataset_dropout_target=0.5
@@ -52,10 +50,8 @@ class EveryDreamBatch(Dataset):
         self.crop_jitter = crop_jitter
         self.unloaded_to_idx = 0
         self.tokenizer = tokenizer
-        self.log_folder = log_folder
         self.max_token_length = self.tokenizer.model_max_length
         self.retain_contrast = retain_contrast
-        self.write_schedule = write_schedule
         self.shuffle_tags = shuffle_tags
         self.seed = seed
         self.rated_dataset = rated_dataset
@@ -64,18 +60,7 @@ class EveryDreamBatch(Dataset):
         self.image_train_items = self.data_loader.get_shuffled_image_buckets(1.0)
 
         num_images = len(self.image_train_items)
-        logging.info(f" ** Trainer Set: {num_images / self.batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}")
-        if self.write_schedule:
-            self.__write_batch_schedule(0)
-
-    def __write_batch_schedule(self, epoch_n):
-        with open(f"{self.log_folder}/ep{epoch_n}_batch_schedule.txt", "w", encoding='utf-8') as f:
-            for i in range(len(self.image_train_items)):
-                try:
-                    f.write(f"step:{int(i / self.batch_size):05}, wh:{self.image_train_items[i].target_wh}, r:{self.image_train_items[i].runt_size}, path:{self.image_train_items[i].pathname}\n")
-                except Exception as e:
-                    logging.error(f" * Error writing to batch schedule for file path: {self.image_train_items[i].pathname}")
 
     def shuffle(self, epoch_n: int, max_epochs: int):
         self.seed += 1
@@ -87,9 +72,6 @@ class EveryDreamBatch(Dataset):
 
         self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
 
-        if self.write_schedule:
-            self.__write_batch_schedule(epoch_n + 1)
-
     def __len__(self):
         return len(self.image_train_items)
 
diff --git a/train.py b/train.py
index 1ad92d6..db4b908 100644
--- a/train.py
+++ b/train.py
@@ -15,6 +15,7 @@ limitations under the License.
""" import os +import pprint import sys import math import signal @@ -48,11 +49,15 @@ from accelerate.utils import set_seed import wandb from torch.utils.tensorboard import SummaryWriter +from data.data_loader import DataLoaderMultiAspect from data.every_dream import EveryDreamBatch +from data.image_train_item import ImageTrainItem from utils.huggingface_downloader import try_download_model_from_hf from utils.convert_diff_to_ckpt import convert as converter from utils.gpu import GPU +import data.aspects as aspects +import data.resolver as resolver _SIGTERM_EXIT_CODE = 130 _VERY_LARGE_NUMBER = 1e9 @@ -265,6 +270,8 @@ def setup_args(args): logging.info(logging.info(f"{Fore.CYAN} * Activating rated images learning with a target rate of {args.rated_dataset_target_dropout_percent}% {Style.RESET_ALL}")) + args.aspects = aspects.get_aspect_buckets(args.resolution) + return args def update_grad_scaler(scaler: GradScaler, global_step, epoch, step): @@ -288,6 +295,35 @@ def update_grad_scaler(scaler: GradScaler, global_step, epoch, step): scaler.set_growth_factor(factor) scaler.set_backoff_factor(1/factor) scaler.set_growth_interval(100) + + +def report_image_train_item_problems(log_folder, items: list[ImageTrainItem]) -> None: + for item in items: + if item.error is not None: + logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}") + logging.error(f" *** exception: {item.error}") + + undersized_items = [item for item in items if item.is_undersized] + + if len(undersized_items) > 0: + underized_log_path = os.path.join(log_folder, "undersized_images.txt") + logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}") + logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}") + with open(underized_log_path, "w") as undersized_images_file: + undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:") + for undersized_item in undersized_items: + message = f" *** {undersized_item.pathname} with size: {undersized_item.image_size} is smaller than target size: {undersized_item.target_wh}\n" + undersized_images_file.write(message) + +def write_batch_schedule(log_folder, train_batch, epoch): + if args.write_schedule: + with open(f"{log_folder}/ep{epoch}_batch_schedule.txt", "w", encoding='utf-8') as f: + for i in range(len(train_batch.image_train_items)): + try: + item = train_batch.image_train_items[i] + f.write(f"step:{int(i / train_batch.batch_size):05}, wh:{item.target_wh}, r:{item.runt_size}, path:{item.pathname}\n") + except Exception as e: + logging.error(f" * Error writing to batch schedule for file path: {item.pathname}") def main(args): """ @@ -313,6 +349,8 @@ def main(args): if not os.path.exists(log_folder): os.makedirs(log_folder) + + args.log_folder = log_folder @torch.no_grad() def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, yaml_name, save_full_precision=False): @@ -547,23 +585,34 @@ def main(args): ) log_optimizer(optimizer, betas, epsilon) + + logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}") + logging.info(" Preloading images...") + + image_train_items = resolver.resolve(args.data_root, args) + report_image_train_item_problems(log_folder, image_train_items) + image_paths = set(map(lambda item: 
item.pathname, image_train_items)) + # Remove erroneous items + image_train_items = [item for item in image_train_items if item.error is None] + print (f" * DLMA: {len(image_train_items)} images loaded from {len(image_paths)} files") + + data_loader = DataLoaderMultiAspect( + image_train_items=image_train_items, + seed=seed, + batch_size=args.batch_size, + ) train_batch = EveryDreamBatch( - data_root=args.data_root, - flip_p=args.flip_p, + data_loader=data_loader, debug_level=1, - batch_size=args.batch_size, conditional_dropout=args.cond_dropout, - resolution=args.resolution, tokenizer=tokenizer, seed = seed, - log_folder=log_folder, - write_schedule=args.write_schedule, shuffle_tags=args.shuffle_tags, rated_dataset=args.rated_dataset, rated_dataset_dropout_target=(1.0 - (args.rated_dataset_target_dropout_percent / 100.0)) ) - + torch.cuda.benchmark = False epoch_len = math.ceil(len(train_batch) / args.batch_size) @@ -589,10 +638,11 @@ def main(args): if args.wandb is not None and args.wandb: wandb.init(project=args.project_name, sync_tensorboard=True, ) - log_writer = SummaryWriter(log_dir=log_folder, - flush_secs=5, - comment="EveryDream2FineTunes", - ) + log_writer = SummaryWriter( + log_dir=log_folder, + flush_secs=5, + comment="EveryDream2FineTunes", + ) def log_args(log_writer, args): arglog = "args:\n" @@ -729,6 +779,8 @@ def main(args): # # discard the grads, just want to pin memory # optimizer.zero_grad(set_to_none=True) + write_batch_schedule(log_folder, train_batch, 0) + for epoch in range(args.max_epochs): loss_epoch = [] epoch_start_time = time.time() @@ -879,6 +931,7 @@ def main(args): epoch_pbar.update(1) if epoch < args.max_epochs - 1: train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs) + write_batch_schedule(log_folder, train_batch, epoch + 1) loss_local = sum(loss_epoch) / len(loss_epoch) log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step) @@ -943,7 +996,6 @@ if __name__ == "__main__": t_args = argparse.Namespace() t_args.__dict__.update(json.load(f)) update_old_args(t_args) # update args to support older configs - print(f" args: \n{t_args.__dict__}") args = argparser.parse_args(namespace=t_args) else: print("No config file specified, using command line args") @@ -992,4 +1044,6 @@ if __name__ == "__main__": args, _ = argparser.parse_known_args() + print(f" Args:") + pprint.pprint(args.__dict__) main(args)
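
The sketch below is not part of the patch; it is a minimal, self-contained illustration of the behaviour that this change moves out of EveryDreamBatch and into train.py: write_batch_schedule now takes the log folder, the batch object, and an epoch number, writes one schedule line per image grouped into steps of batch_size, and is called once before the epoch loop and again after every shuffle. ItemStub and BatchStub are hypothetical stand-ins for ImageTrainItem and EveryDreamBatch so the flow can be run on its own; the real call sites in train.py use those classes and gate on args.write_schedule.

# Minimal sketch (not part of the patch); ItemStub/BatchStub are hypothetical stand-ins.
import logging
import os
from dataclasses import dataclass
from typing import List, Tuple

@dataclass
class ItemStub:                  # stands in for data.image_train_item.ImageTrainItem
    pathname: str
    target_wh: Tuple[int, int]
    runt_size: int = 0

@dataclass
class BatchStub:                 # stands in for data.every_dream.EveryDreamBatch
    image_train_items: List[ItemStub]
    batch_size: int

def write_batch_schedule(log_folder: str, train_batch, epoch: int) -> None:
    # One schedule file per epoch; one line per image, grouped into steps of batch_size.
    path = os.path.join(log_folder, f"ep{epoch}_batch_schedule.txt")
    with open(path, "w", encoding="utf-8") as f:
        for i, item in enumerate(train_batch.image_train_items):
            try:
                f.write(f"step:{int(i / train_batch.batch_size):05}, "
                        f"wh:{item.target_wh}, r:{item.runt_size}, path:{item.pathname}\n")
            except Exception as e:
                logging.error(f" * Error writing batch schedule for {item.pathname}: {e}")

if __name__ == "__main__":
    os.makedirs("logs", exist_ok=True)
    batch = BatchStub(
        image_train_items=[ItemStub(f"img_{n}.jpg", (512, 512)) for n in range(4)],
        batch_size=2,
    )
    # train.py writes the epoch-0 schedule once before the epoch loop...
    write_batch_schedule("logs", batch, 0)
    # ...and again after each shuffle inside the loop:
    #     train_batch.shuffle(epoch_n=epoch, max_epochs=args.max_epochs)
    #     write_batch_schedule(log_folder, train_batch, epoch + 1)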