From 326d861a86d3bae6493e37292c8775c5a2ebf837 Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Sun, 29 Jan 2023 17:08:54 -0800 Subject: [PATCH 1/8] Push DLMA into main, pass config to resolve This patch * passes the configuration (`argparse.Namespace`) to the resolver, * pushes the DLMA code into the main function, * makes DLMA take a `list[ImageTrainItem]` instead of `data_root`, * makes `EveryDreamBatch` take `DLMA` instead of `data_root`, etc. * allows `data_root` to be a list. By doing these things, both `EveryDreamBatch` and DLMA can be free from data resolution logic. It also reduces the number of arguments which need to be passed down to EDB and DLMA. --- data/data_loader.py | 77 ++++++-------------------------------- data/every_dream.py | 28 +++----------- data/resolver.py | 62 +++++++++++++++--------------- test/test_data_resolver.py | 16 +++++--- 4 files changed, 59 insertions(+), 124 deletions(-) diff --git a/data/data_loader.py b/data/data_loader.py index 105f7dd..8db6f89 100644 --- a/data/data_loader.py +++ b/data/data_loader.py @@ -15,15 +15,10 @@ limitations under the License. """ import bisect import math -import os -import logging import copy import random from data.image_train_item import ImageTrainItem -import data.aspects as aspects -import data.resolver as resolver -from colorama import Fore, Style import PIL PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default @@ -34,22 +29,20 @@ class DataLoaderMultiAspect(): data_root: root folder of training data batch_size: number of images per batch - flip_p: probability of flipping image horizontally (i.e. 0-0.5) """ - def __init__(self, data_root, seed=555, debug_level=0, batch_size=1, flip_p=0.0, resolution=512, log_folder=None): - self.data_root = data_root - self.debug_level = debug_level - self.flip_p = flip_p - self.log_folder = log_folder + def __init__(self, image_train_items, seed=555, batch_size=1): self.seed = seed self.batch_size = batch_size - self.has_scanned = False - self.aspects = aspects.get_aspect_buckets(resolution=resolution, square_only=False) - - logging.info(f"* DLMA resolution {resolution}, buckets: {self.aspects}") - self.__prepare_train_data() - (self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings() - + # Prepare data + self.prepared_train_data = image_train_items + random.Random(self.seed).shuffle(self.prepared_train_data) + self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating()) + # Initialize ratings + self.rating_overall_sum: float = 0.0 + self.ratings_summed: list[float] = [] + for image in self.prepared_train_data: + self.rating_overall_sum += image.caption.rating() + self.ratings_summed.append(self.rating_overall_sum) def __pick_multiplied_set(self, randomizer): """ @@ -138,54 +131,6 @@ class DataLoaderMultiAspect(): return image_caption_pairs - def __sort_and_precalc_image_ratings(self) -> tuple[float, list[float]]: - self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating()) - - rating_overall_sum: float = 0.0 - ratings_summed: list[float] = [] - for image in self.prepared_train_data: - rating_overall_sum += image.caption.rating() - ratings_summed.append(rating_overall_sum) - - return rating_overall_sum, ratings_summed - - def __prepare_train_data(self, flip_p=0.0) -> list[ImageTrainItem]: - """ - Create ImageTrainItem objects with metadata for hydration later - """ - - if not self.has_scanned: - self.has_scanned = True - - 
logging.info(" Preloading images...") - - items = resolver.resolve(self.data_root, self.aspects, flip_p=flip_p, seed=self.seed) - image_paths = set(map(lambda item: item.pathname, items)) - - print (f" * DLMA: {len(items)} images loaded from {len(image_paths)} files") - - self.prepared_train_data = [item for item in items if item.error is None] - random.Random(self.seed).shuffle(self.prepared_train_data) - self.__report_errors(items) - - def __report_errors(self, items: list[ImageTrainItem]): - for item in items: - if item.error is not None: - logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}") - logging.error(f" *** exception: {item.error}") - - undersized_items = [item for item in items if item.is_undersized] - - if len(undersized_items) > 0: - underized_log_path = os.path.join(self.log_folder, "undersized_images.txt") - logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}") - logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}") - with open(underized_log_path, "w") as undersized_images_file: - undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:") - for undersized_item in undersized_items: - message = f" *** {undersized_item.pathname} with size: {undersized_item.image_size} is smaller than target size: {undersized_item.target_wh}\n" - undersized_images_file.write(message) - def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]: """ Picks a random subset of all images diff --git a/data/every_dream.py b/data/every_dream.py index aaaccd4..465b6fc 100644 --- a/data/every_dream.py +++ b/data/every_dream.py @@ -38,12 +38,9 @@ class EveryDreamBatch(Dataset): jitter: number of pixels to jitter the crop by, only for non-square images """ def __init__(self, - data_root, - flip_p=0.0, + data_loader: dlma, debug_level=0, - batch_size=1, conditional_dropout=0.02, - resolution=512, crop_jitter=20, seed=555, tokenizer=None, @@ -54,8 +51,8 @@ class EveryDreamBatch(Dataset): rated_dataset=False, rated_dataset_dropout_target=0.5 ): - self.data_root = data_root - self.batch_size = batch_size + self.data_loader = data_loader + self.batch_size = data_loader.batch_size self.debug_level = debug_level self.conditional_dropout = conditional_dropout self.crop_jitter = crop_jitter @@ -70,26 +67,11 @@ class EveryDreamBatch(Dataset): self.seed = seed self.rated_dataset = rated_dataset self.rated_dataset_dropout_target = rated_dataset_dropout_target - - if seed == -1: - seed = random.randint(0, 99999) - - if not dls.shared_dataloader: - logging.info(" * Creating new dataloader singleton") - dls.shared_dataloader = dlma(data_root=data_root, - seed=seed, - debug_level=debug_level, - batch_size=self.batch_size, - flip_p=flip_p, - resolution=resolution, - log_folder=self.log_folder, - ) - - self.image_train_items = dls.shared_dataloader.get_shuffled_image_buckets(1.0) # First epoch always trains on all images + self.image_train_items = self.data_loader.get_shuffled_image_buckets(1.0) # First epoch always trains on all images num_images = len(self.image_train_items) - logging.info(f" ** Trainer Set: {num_images / batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}") + logging.info(f" ** Trainer Set: {num_images / 
self.batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}") if self.write_schedule: self.__write_batch_schedule(0) diff --git a/data/resolver.py b/data/resolver.py index 23e1b44..fbcd076 100644 --- a/data/resolver.py +++ b/data/resolver.py @@ -4,18 +4,21 @@ import os import random import typing import zipfile +import argparse -import PIL.Image as Image import tqdm from colorama import Fore, Style from data.image_train_item import ImageCaption, ImageTrainItem class DataResolver: - def __init__(self, aspects: list[typing.Tuple[int, int]], flip_p=0.0, seed=555): - self.seed = seed - self.aspects = aspects - self.flip_p = flip_p + def __init__(self, args: argparse.Namespace): + """ + :param args: EveryDream configuration, an `argparse.Namespace` object. + """ + self.aspects = args.aspects + self.flip_p = args.flip_p + self.seed = args.seed def image_train_items(self, data_root: str) -> list[ImageTrainItem]: """ @@ -173,8 +176,11 @@ class DirectoryResolver(DataResolver): if os.path.isdir(current): yield from DirectoryResolver.recurse_data_root(current) - -def strategy(data_root: str): +def strategy(data_root: str) -> typing.Type[DataResolver]: + """ + Determine the strategy to use for resolving the data. + :param data_root: The root directory or JSON file to resolve. + """ if os.path.isfile(data_root) and data_root.endswith('.json'): return JSONResolver @@ -183,41 +189,37 @@ def strategy(data_root: str): raise ValueError(f"data_root '{data_root}' is not a valid directory or JSON file.") - -def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0, seed=555) -> list[ImageTrainItem]: +def resolve_root(path: str, args: argparse.Namespace) -> list[ImageTrainItem]: """ - :param data_root: Directory or JSON file. - :param aspects: The list of aspect ratios to use - :param flip_p: The probability of flipping the image + Resolve the training data from the root path. + :param path: The root path to resolve. + :param args: EveryDream configuration, an `argparse.Namespace` object. """ - if os.path.isfile(path) and path.endswith('.json'): - return JSONResolver(aspects, flip_p, seed).image_train_items(path) - - if os.path.isdir(path): - return DirectoryResolver(aspects, flip_p, seed).image_train_items(path) - - raise ValueError(f"data_root '{path}' is not a valid directory or JSON file.") + resolver = strategy(path) + return resolver(args).image_train_items(path) -def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0, seed=555) -> list[ImageTrainItem]: +def resolve(value: typing.Union[dict, str], args: argparse.Namespace) -> list[ImageTrainItem]: """ Resolve the training data from the value. - :param value: The value to resolve, either a dict or a string. - :param aspects: The list of aspect ratios to use - :param flip_p: The probability of flipping the image + :param value: The value to resolve, either a dict, an array, or a string. + :param args: EveryDream configuration, an `argparse.Namespace` object. 
""" if isinstance(value, str): - return resolve_root(value, aspects, flip_p) + return resolve_root(value, args) if isinstance(value, dict): resolver = value.get('resolver', None) match resolver: case 'directory' | 'json': path = value.get('path', None) - return resolve_root(path, aspects, flip_p, seed) + return resolve_root(path, args) case 'multi': - items = [] - for resolver in value.get('resolvers', []): - items += resolve(resolver, aspects, flip_p, seed) - return items + return resolve(value.get('resolvers', []), args) case _: - raise ValueError(f"Cannot resolve training data for resolver value '{resolver}'") \ No newline at end of file + raise ValueError(f"Cannot resolve training data for resolver value '{resolver}'") + + if isinstance(value, list): + items = [] + for item in value: + items += resolve(item, args) + return items \ No newline at end of file diff --git a/test/test_data_resolver.py b/test/test_data_resolver.py index 625f228..575d9fe 100644 --- a/test/test_data_resolver.py +++ b/test/test_data_resolver.py @@ -2,6 +2,7 @@ import json import glob import os import unittest +import argparse import PIL.Image as Image @@ -10,13 +11,18 @@ import data.resolver as resolver DATA_PATH = os.path.abspath('./test/data') JSON_ROOT_PATH = os.path.join(DATA_PATH, 'test_root.json') -ASPECTS = aspects.get_aspect_buckets(512) IMAGE_1_PATH = os.path.join(DATA_PATH, 'test1.jpg') CAPTION_1_PATH = os.path.join(DATA_PATH, 'test1.txt') IMAGE_2_PATH = os.path.join(DATA_PATH, 'test2.jpg') IMAGE_3_PATH = os.path.join(DATA_PATH, 'test3.jpg') +ARGS = argparse.Namespace( + aspects=aspects.get_aspect_buckets(512), + flip_p=0.5, + seed=42, +) + class TestResolve(unittest.TestCase): @classmethod def setUpClass(cls): @@ -51,7 +57,7 @@ class TestResolve(unittest.TestCase): os.remove(file) def test_directory_resolve_with_str(self): - items = resolver.resolve(DATA_PATH, ASPECTS) + items = resolver.resolve(DATA_PATH, ARGS) image_paths = [item.pathname for item in items] image_captions = [item.caption for item in items] captions = [caption.get_caption() for caption in image_captions] @@ -69,7 +75,7 @@ class TestResolve(unittest.TestCase): 'path': DATA_PATH, } - items = resolver.resolve(data_root_spec, ASPECTS) + items = resolver.resolve(data_root_spec, ARGS) image_paths = [item.pathname for item in items] image_captions = [item.caption for item in items] captions = [caption.get_caption() for caption in image_captions] @@ -82,7 +88,7 @@ class TestResolve(unittest.TestCase): self.assertEqual(len(undersized_images), 1) def test_json_resolve_with_str(self): - items = resolver.resolve(JSON_ROOT_PATH, ASPECTS) + items = resolver.resolve(JSON_ROOT_PATH, ARGS) image_paths = [item.pathname for item in items] image_captions = [item.caption for item in items] captions = [caption.get_caption() for caption in image_captions] @@ -100,7 +106,7 @@ class TestResolve(unittest.TestCase): 'path': JSON_ROOT_PATH, } - items = resolver.resolve(data_root_spec, ASPECTS) + items = resolver.resolve(data_root_spec, ARGS) image_paths = [item.pathname for item in items] image_captions = [item.caption for item in items] captions = [caption.get_caption() for caption in image_captions] From 09d95fac58e89c717c9fa4d9202692f33d188781 Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Sun, 29 Jan 2023 17:21:12 -0800 Subject: [PATCH 2/8] Add test for list data resolver --- test/test_data_resolver.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/test/test_data_resolver.py b/test/test_data_resolver.py index 
575d9fe..f668974 100644 --- a/test/test_data_resolver.py +++ b/test/test_data_resolver.py @@ -116,4 +116,22 @@ class TestResolve(unittest.TestCase): self.assertEqual(captions, ['caption for test1', 'caption for test2', 'test3']) undersized_images = list(filter(lambda i: i.is_undersized, items)) - self.assertEqual(len(undersized_images), 1) \ No newline at end of file + self.assertEqual(len(undersized_images), 1) + + def test_resolve_with_list(self): + data_root_spec = [ + DATA_PATH, + JSON_ROOT_PATH, + ] + + items = resolver.resolve(data_root_spec, ARGS) + image_paths = [item.pathname for item in items] + image_captions = [item.caption for item in items] + captions = [caption.get_caption() for caption in image_captions] + + self.assertEqual(len(items), 6) + self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH] * 2) + self.assertEqual(captions, ['caption for test1', 'test2', 'test3', 'caption for test1', 'caption for test2', 'test3']) + + undersized_images = list(filter(lambda i: i.is_undersized, items)) + self.assertEqual(len(undersized_images), 2) \ No newline at end of file From c0ec46c03015660597327b5f81fce906ca111862 Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Sun, 29 Jan 2023 17:31:57 -0800 Subject: [PATCH 3/8] Don't need to set data loader singleton; formatting tweaks --- data/every_dream.py | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/data/every_dream.py b/data/every_dream.py index 465b6fc..0d0cd41 100644 --- a/data/every_dream.py +++ b/data/every_dream.py @@ -16,15 +16,12 @@ limitations under the License. import logging import torch from torch.utils.data import Dataset -from data.data_loader import DataLoaderMultiAspect as dlma -import math -import data.dl_singleton as dls +from data.data_loader import DataLoaderMultiAspect from data.image_train_item import ImageTrainItem import random from torchvision import transforms from transformers import CLIPTokenizer import torch.nn.functional as F -import numpy class EveryDreamBatch(Dataset): """ @@ -38,7 +35,7 @@ class EveryDreamBatch(Dataset): jitter: number of pixels to jitter the crop by, only for non-square images """ def __init__(self, - data_loader: dlma, + data_loader: DataLoaderMultiAspect, debug_level=0, conditional_dropout=0.02, crop_jitter=20, @@ -59,7 +56,6 @@ class EveryDreamBatch(Dataset): self.unloaded_to_idx = 0 self.tokenizer = tokenizer self.log_folder = log_folder - #print(f"tokenizer: {tokenizer}") self.max_token_length = self.tokenizer.model_max_length self.retain_contrast = retain_contrast self.write_schedule = write_schedule @@ -67,8 +63,9 @@ class EveryDreamBatch(Dataset): self.seed = seed self.rated_dataset = rated_dataset self.rated_dataset_dropout_target = rated_dataset_dropout_target - self.image_train_items = self.data_loader.get_shuffled_image_buckets(1.0) # First epoch always trains on all images - + # First epoch always trains on all images + self.image_train_items = self.data_loader.get_shuffled_image_buckets(1.0) + num_images = len(self.image_train_items) logging.info(f" ** Trainer Set: {num_images / self.batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}") @@ -83,20 +80,15 @@ class EveryDreamBatch(Dataset): except Exception as e: logging.error(f" * Error writing to batch schedule for file path: {self.image_train_items[i].pathname}") - def get_runts(): - return dls.shared_dataloader.runts - def shuffle(self, epoch_n: int, max_epochs: int): self.seed += 1 - if dls.shared_dataloader: - if 
self.rated_dataset: - dropout_fraction = (max_epochs - (epoch_n * self.rated_dataset_dropout_target)) / max_epochs - else: - dropout_fraction = 1.0 - - self.image_train_items = dls.shared_dataloader.get_shuffled_image_buckets(dropout_fraction) + + if self.rated_dataset: + dropout_fraction = (max_epochs - (epoch_n * self.rated_dataset_dropout_target)) / max_epochs else: - raise Exception("No dataloader singleton to shuffle") + dropout_fraction = 1.0 + + self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction) if self.write_schedule: self.__write_batch_schedule(epoch_n + 1) From 3fe335f3283a549bf414ab6c14d8c4c456bc4464 Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Sun, 29 Jan 2023 17:47:10 -0800 Subject: [PATCH 4/8] Update documentation --- data/data_loader.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/data/data_loader.py b/data/data_loader.py index 8db6f89..6222dc2 100644 --- a/data/data_loader.py +++ b/data/data_loader.py @@ -27,10 +27,11 @@ class DataLoaderMultiAspect(): """ Data loader for multi-aspect-ratio training and bucketing - data_root: root folder of training data + image_train_items: list of `lImageTrainItem` objects + seed: random seed batch_size: number of images per batch """ - def __init__(self, image_train_items, seed=555, batch_size=1): + def __init__(self, image_train_items: list[ImageTrainItem], seed=555, batch_size=1): self.seed = seed self.batch_size = batch_size # Prepare data From 12a0cb6286c985b7c04ff1e77aedccb5a583ac4d Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Sun, 29 Jan 2023 17:58:42 -0800 Subject: [PATCH 5/8] Update documentation --- data/every_dream.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/data/every_dream.py b/data/every_dream.py index 0d0cd41..e21d639 100644 --- a/data/every_dream.py +++ b/data/every_dream.py @@ -25,14 +25,11 @@ import torch.nn.functional as F class EveryDreamBatch(Dataset): """ - data_root: root path of all your training images, will be recursively searched for images - repeats: how many times to repeat each image in the dataset - flip_p: probability of flipping the image horizontally + data_loader: `DataLoaderMultiAspect` object debug_level: 0=none, 1=print drops due to unfilled batches on aspect ratio buckets, 2=debug info per image, 3=save crops to disk for inspection - batch_size: how many images to return in a batch conditional_dropout: probability of dropping the caption for a given image - resolution: max resolution (relative to square) - jitter: number of pixels to jitter the crop by, only for non-square images + crop_jitter: number of pixels to jitter the crop by, only for non-square images + seed: random seed """ def __init__(self, data_loader: DataLoaderMultiAspect, From 56f130c027f981eac63d20e310261cf6f10d1740 Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Sun, 29 Jan 2023 18:11:34 -0800 Subject: [PATCH 6/8] Forgot to add train.py earlier :facepalm:; move write_batch_schedule to train.py --- data/every_dream.py | 18 ----------- train.py | 78 ++++++++++++++++++++++++++++++++++++++------- 2 files changed, 66 insertions(+), 30 deletions(-) diff --git a/data/every_dream.py b/data/every_dream.py index e21d639..38af008 100644 --- a/data/every_dream.py +++ b/data/every_dream.py @@ -38,9 +38,7 @@ class EveryDreamBatch(Dataset): crop_jitter=20, seed=555, tokenizer=None, - log_folder=None, retain_contrast=False, - write_schedule=False, shuffle_tags=False, rated_dataset=False, rated_dataset_dropout_target=0.5 @@ -52,10 
+50,8 @@ class EveryDreamBatch(Dataset): self.crop_jitter = crop_jitter self.unloaded_to_idx = 0 self.tokenizer = tokenizer - self.log_folder = log_folder self.max_token_length = self.tokenizer.model_max_length self.retain_contrast = retain_contrast - self.write_schedule = write_schedule self.shuffle_tags = shuffle_tags self.seed = seed self.rated_dataset = rated_dataset @@ -64,18 +60,7 @@ class EveryDreamBatch(Dataset): self.image_train_items = self.data_loader.get_shuffled_image_buckets(1.0) num_images = len(self.image_train_items) - logging.info(f" ** Trainer Set: {num_images / self.batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}") - if self.write_schedule: - self.__write_batch_schedule(0) - - def __write_batch_schedule(self, epoch_n): - with open(f"{self.log_folder}/ep{epoch_n}_batch_schedule.txt", "w", encoding='utf-8') as f: - for i in range(len(self.image_train_items)): - try: - f.write(f"step:{int(i / self.batch_size):05}, wh:{self.image_train_items[i].target_wh}, r:{self.image_train_items[i].runt_size}, path:{self.image_train_items[i].pathname}\n") - except Exception as e: - logging.error(f" * Error writing to batch schedule for file path: {self.image_train_items[i].pathname}") def shuffle(self, epoch_n: int, max_epochs: int): self.seed += 1 @@ -87,9 +72,6 @@ class EveryDreamBatch(Dataset): self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction) - if self.write_schedule: - self.__write_batch_schedule(epoch_n + 1) - def __len__(self): return len(self.image_train_items) diff --git a/train.py b/train.py index 1ad92d6..db4b908 100644 --- a/train.py +++ b/train.py @@ -15,6 +15,7 @@ limitations under the License. """ import os +import pprint import sys import math import signal @@ -48,11 +49,15 @@ from accelerate.utils import set_seed import wandb from torch.utils.tensorboard import SummaryWriter +from data.data_loader import DataLoaderMultiAspect from data.every_dream import EveryDreamBatch +from data.image_train_item import ImageTrainItem from utils.huggingface_downloader import try_download_model_from_hf from utils.convert_diff_to_ckpt import convert as converter from utils.gpu import GPU +import data.aspects as aspects +import data.resolver as resolver _SIGTERM_EXIT_CODE = 130 _VERY_LARGE_NUMBER = 1e9 @@ -265,6 +270,8 @@ def setup_args(args): logging.info(logging.info(f"{Fore.CYAN} * Activating rated images learning with a target rate of {args.rated_dataset_target_dropout_percent}% {Style.RESET_ALL}")) + args.aspects = aspects.get_aspect_buckets(args.resolution) + return args def update_grad_scaler(scaler: GradScaler, global_step, epoch, step): @@ -288,6 +295,35 @@ def update_grad_scaler(scaler: GradScaler, global_step, epoch, step): scaler.set_growth_factor(factor) scaler.set_backoff_factor(1/factor) scaler.set_growth_interval(100) + + +def report_image_train_item_problems(log_folder, items: list[ImageTrainItem]) -> None: + for item in items: + if item.error is not None: + logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.pathname}{Fore.LIGHTRED_EX} to get metadata. 
File may be corrupt and will be skipped.{Style.RESET_ALL}") + logging.error(f" *** exception: {item.error}") + + undersized_items = [item for item in items if item.is_undersized] + + if len(undersized_items) > 0: + underized_log_path = os.path.join(log_folder, "undersized_images.txt") + logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}") + logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}") + with open(underized_log_path, "w") as undersized_images_file: + undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:") + for undersized_item in undersized_items: + message = f" *** {undersized_item.pathname} with size: {undersized_item.image_size} is smaller than target size: {undersized_item.target_wh}\n" + undersized_images_file.write(message) + +def write_batch_schedule(log_folder, train_batch, epoch): + if args.write_schedule: + with open(f"{log_folder}/ep{epoch}_batch_schedule.txt", "w", encoding='utf-8') as f: + for i in range(len(train_batch.image_train_items)): + try: + item = train_batch.image_train_items[i] + f.write(f"step:{int(i / train_batch.batch_size):05}, wh:{item.target_wh}, r:{item.runt_size}, path:{item.pathname}\n") + except Exception as e: + logging.error(f" * Error writing to batch schedule for file path: {item.pathname}") def main(args): """ @@ -313,6 +349,8 @@ def main(args): if not os.path.exists(log_folder): os.makedirs(log_folder) + + args.log_folder = log_folder @torch.no_grad() def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, yaml_name, save_full_precision=False): @@ -547,23 +585,34 @@ def main(args): ) log_optimizer(optimizer, betas, epsilon) + + logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}") + logging.info(" Preloading images...") + + image_train_items = resolver.resolve(args.data_root, args) + report_image_train_item_problems(log_folder, image_train_items) + image_paths = set(map(lambda item: item.pathname, image_train_items)) + # Remove erroneous items + image_train_items = [item for item in image_train_items if item.error is None] + print (f" * DLMA: {len(image_train_items)} images loaded from {len(image_paths)} files") + + data_loader = DataLoaderMultiAspect( + image_train_items=image_train_items, + seed=seed, + batch_size=args.batch_size, + ) train_batch = EveryDreamBatch( - data_root=args.data_root, - flip_p=args.flip_p, + data_loader=data_loader, debug_level=1, - batch_size=args.batch_size, conditional_dropout=args.cond_dropout, - resolution=args.resolution, tokenizer=tokenizer, seed = seed, - log_folder=log_folder, - write_schedule=args.write_schedule, shuffle_tags=args.shuffle_tags, rated_dataset=args.rated_dataset, rated_dataset_dropout_target=(1.0 - (args.rated_dataset_target_dropout_percent / 100.0)) ) - + torch.cuda.benchmark = False epoch_len = math.ceil(len(train_batch) / args.batch_size) @@ -589,10 +638,11 @@ def main(args): if args.wandb is not None and args.wandb: wandb.init(project=args.project_name, sync_tensorboard=True, ) - log_writer = SummaryWriter(log_dir=log_folder, - flush_secs=5, - comment="EveryDream2FineTunes", - ) + log_writer = SummaryWriter( + log_dir=log_folder, + flush_secs=5, + comment="EveryDream2FineTunes", + ) def log_args(log_writer, args): arglog = "args:\n" @@ -729,6 +779,8 @@ def main(args): # # discard the grads, just want to pin memory # 
optimizer.zero_grad(set_to_none=True) + write_batch_schedule(log_folder, train_batch, 0) + for epoch in range(args.max_epochs): loss_epoch = [] epoch_start_time = time.time() @@ -879,6 +931,7 @@ def main(args): epoch_pbar.update(1) if epoch < args.max_epochs - 1: train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs) + write_batch_schedule(log_folder, train_batch, epoch + 1) loss_local = sum(loss_epoch) / len(loss_epoch) log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step) @@ -943,7 +996,6 @@ if __name__ == "__main__": t_args = argparse.Namespace() t_args.__dict__.update(json.load(f)) update_old_args(t_args) # update args to support older configs - print(f" args: \n{t_args.__dict__}") args = argparser.parse_args(namespace=t_args) else: print("No config file specified, using command line args") @@ -992,4 +1044,6 @@ if __name__ == "__main__": args, _ = argparser.parse_known_args() + print(f" Args:") + pprint.pprint(args.__dict__) main(args) From f96d44ddb4f28e69da5c4036f155b7e133e227a8 Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Sun, 29 Jan 2023 18:20:40 -0800 Subject: [PATCH 7/8] Move image resolution into its own function --- train.py | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/train.py b/train.py index db4b908..af0d1ee 100644 --- a/train.py +++ b/train.py @@ -296,8 +296,7 @@ def update_grad_scaler(scaler: GradScaler, global_step, epoch, step): scaler.set_backoff_factor(1/factor) scaler.set_growth_interval(100) - -def report_image_train_item_problems(log_folder, items: list[ImageTrainItem]) -> None: +def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem]) -> None: for item in items: if item.error is not None: logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.pathname}{Fore.LIGHTRED_EX} to get metadata. 
File may be corrupt and will be skipped.{Style.RESET_ALL}") @@ -315,6 +314,21 @@ def report_image_train_item_problems(log_folder, items: list[ImageTrainItem]) -> message = f" *** {undersized_item.pathname} with size: {undersized_item.image_size} is smaller than target size: {undersized_item.target_wh}\n" undersized_images_file.write(message) +def resolve_image_train_items(args: argparse.Namespace, log_folder: str) -> list[ImageTrainItem]: + logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}") + logging.info(" Preloading images...") + + resolved_items = resolver.resolve(args.data_root, args) + report_image_train_item_problems(log_folder, resolved_items) + image_paths = set(map(lambda item: item.pathname, resolved_items)) + + # Remove erroneous items + image_train_items = [item for item in resolved_items if item.error is None] + + print (f" * DLMA: {len(image_train_items)} images loaded from {len(image_paths)} files") + + return image_train_items + def write_batch_schedule(log_folder, train_batch, epoch): if args.write_schedule: with open(f"{log_folder}/ep{epoch}_batch_schedule.txt", "w", encoding='utf-8') as f: @@ -349,8 +363,6 @@ def main(args): if not os.path.exists(log_folder): os.makedirs(log_folder) - - args.log_folder = log_folder @torch.no_grad() def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, yaml_name, save_full_precision=False): @@ -586,15 +598,7 @@ def main(args): log_optimizer(optimizer, betas, epsilon) - logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}") - logging.info(" Preloading images...") - - image_train_items = resolver.resolve(args.data_root, args) - report_image_train_item_problems(log_folder, image_train_items) - image_paths = set(map(lambda item: item.pathname, image_train_items)) - # Remove erroneous items - image_train_items = [item for item in image_train_items if item.error is None] - print (f" * DLMA: {len(image_train_items)} images loaded from {len(image_paths)} files") + image_train_items = resolve_image_train_items(args, log_folder) data_loader = DataLoaderMultiAspect( image_train_items=image_train_items, @@ -958,7 +962,6 @@ def main(args): logging.info(f"{Fore.LIGHTWHITE_EX} **** Finished training ****{Style.RESET_ALL}") logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}") - def update_old_args(t_args): """ Update old args to new args to deal with json config loading and missing args for compatibility From c8c658d181a2999a179f3a5b3321c4c203bb09fe Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Sun, 29 Jan 2023 18:28:07 -0800 Subject: [PATCH 8/8] Forgot to pass args to write_batch_schedule --- train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index af0d1ee..13de14f 100644 --- a/train.py +++ b/train.py @@ -329,7 +329,7 @@ def resolve_image_train_items(args: argparse.Namespace, log_folder: str) -> list return image_train_items -def write_batch_schedule(log_folder, train_batch, epoch): +def write_batch_schedule(args: argparse.Namespace, log_folder: str, train_batch: EveryDreamBatch, epoch: int): if args.write_schedule: with open(f"{log_folder}/ep{epoch}_batch_schedule.txt", "w", encoding='utf-8') as f: for i in range(len(train_batch.image_train_items)): @@ -783,7 +783,7 @@ def main(args): # # discard the grads, just want to pin memory # optimizer.zero_grad(set_to_none=True) - write_batch_schedule(log_folder, train_batch, 0) + write_batch_schedule(args, log_folder, train_batch, 0) for 
epoch in range(args.max_epochs): loss_epoch = [] @@ -935,7 +935,7 @@ def main(args): epoch_pbar.update(1) if epoch < args.max_epochs - 1: train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs) - write_batch_schedule(log_folder, train_batch, epoch + 1) + write_batch_schedule(args, log_folder, train_batch, epoch + 1) loss_local = sum(loss_epoch) / len(loss_epoch) log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
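
Taken together, these patches move all data-resolution logic out of `DataLoaderMultiAspect` and `EveryDreamBatch` and into the setup path of `train.py`. Below is a minimal sketch of the resulting wiring, condensed from the `train.py` hunks above. The data root, resolution, and batch size are placeholder values, and the tokenizer line is an assumption for the sake of a self-contained example (in `train.py` the tokenizer is loaded from the Stable Diffusion checkpoint); everything else follows the signatures shown in the diffs.

```python
# Sketch of the post-refactor setup: resolve ImageTrainItems first, then hand
# the clean list to DataLoaderMultiAspect, then pass that loader to
# EveryDreamBatch. Values below are placeholders, not project defaults.
import argparse

import data.aspects as aspects
import data.resolver as resolver
from data.data_loader import DataLoaderMultiAspect
from data.every_dream import EveryDreamBatch
from transformers import CLIPTokenizer

args = argparse.Namespace(
    data_root="./my_training_data",  # may also be a JSON file, a resolver dict, or a list of these
    resolution=512,
    flip_p=0.0,
    seed=555,
    batch_size=4,
)
# train.py attaches the aspect buckets to args so the resolver can use them.
args.aspects = aspects.get_aspect_buckets(args.resolution)

# Resolution logic now lives outside DLMA/EDB: resolve items, drop errored ones.
# (train.py also calls report_image_train_item_problems() on the raw list.)
items = resolver.resolve(args.data_root, args)
items = [item for item in items if item.error is None]

data_loader = DataLoaderMultiAspect(
    image_train_items=items,
    seed=args.seed,
    batch_size=args.batch_size,
)

# Assumption: a standalone CLIP tokenizer stands in for the one train.py
# extracts from the model directory.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

train_batch = EveryDreamBatch(
    data_loader=data_loader,
    tokenizer=tokenizer,
    seed=args.seed,
)

# Batch-schedule logging is now train.py's job, e.g.
# write_batch_schedule(args, log_folder, train_batch, 0) before the epoch loop.
```

The design point the sketch illustrates: because `resolver.resolve()` returns plain `ImageTrainItem` lists, both `DataLoaderMultiAspect` and `EveryDreamBatch` no longer need `data_root`, `flip_p`, `resolution`, or `log_folder`, which is why their argument lists shrink across these patches.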