Push DLMA into main, pass config to resolve

This patch

* passes the configuration (`argparse.Namespace`) to the resolver,
* pushes the DLMA code into the main function,
* makes DLMA take a `list[ImageTrainItem]` instead of `data_root`,
* makes `EveryDreamBatch` take a `DLMA` instance instead of `data_root` and related arguments, and
* allows `data_root` to be a list.

By doing these things, both `EveryDreamBatch` and DLMA are freed from data
resolution logic, and fewer arguments need to be passed down to EDB and DLMA.
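
For orientation, here is a minimal sketch of the resulting wiring: main resolves the data, builds the DLMA, and hands it to `EveryDreamBatch`. The helper name `build_dataset`, the module paths for the two classes, and the omitted keyword arguments are assumptions; only the constructor and `resolve` signatures come from the diffs below.

import data.resolver as resolver
from data.data_loader import DataLoaderMultiAspect  # module path assumed
from data.every_dream import EveryDreamBatch        # module path assumed

def build_dataset(args):
    # `args` is the EveryDream configuration (argparse.Namespace). With this
    # patch, args.data_root may be a str, a dict spec, or a list of either.
    image_train_items = resolver.resolve(args.data_root, args)
    data_loader = DataLoaderMultiAspect(
        image_train_items,
        seed=args.seed,
        batch_size=args.batch_size,
    )
    # Remaining EveryDreamBatch kwargs (tokenizer, crop_jitter, ...) omitted.
    return EveryDreamBatch(data_loader=data_loader, seed=args.seed)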
Joel Holdbrooks 2023-01-29 17:08:54 -08:00
parent bc273d0512
commit 326d861a86
4 changed files with 59 additions and 124 deletions

View File

@ -15,15 +15,10 @@ limitations under the License.
"""
import bisect
import math
import os
import logging
import copy
import random
from data.image_train_item import ImageTrainItem
import data.aspects as aspects
import data.resolver as resolver
from colorama import Fore, Style
import PIL
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
@ -34,22 +29,20 @@ class DataLoaderMultiAspect():
data_root: root folder of training data
batch_size: number of images per batch
flip_p: probability of flipping image horizontally (i.e. 0-0.5)
"""
def __init__(self, data_root, seed=555, debug_level=0, batch_size=1, flip_p=0.0, resolution=512, log_folder=None):
self.data_root = data_root
self.debug_level = debug_level
self.flip_p = flip_p
self.log_folder = log_folder
def __init__(self, image_train_items, seed=555, batch_size=1):
self.seed = seed
self.batch_size = batch_size
self.has_scanned = False
self.aspects = aspects.get_aspect_buckets(resolution=resolution, square_only=False)
logging.info(f"* DLMA resolution {resolution}, buckets: {self.aspects}")
self.__prepare_train_data()
(self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings()
# Prepare data
self.prepared_train_data = image_train_items
random.Random(self.seed).shuffle(self.prepared_train_data)
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
# Initialize ratings
self.rating_overall_sum: float = 0.0
self.ratings_summed: list[float] = []
for image in self.prepared_train_data:
self.rating_overall_sum += image.caption.rating()
self.ratings_summed.append(self.rating_overall_sum)
def __pick_multiplied_set(self, randomizer):
"""
@ -138,54 +131,6 @@ class DataLoaderMultiAspect():
return image_caption_pairs
def __sort_and_precalc_image_ratings(self) -> tuple[float, list[float]]:
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
rating_overall_sum: float = 0.0
ratings_summed: list[float] = []
for image in self.prepared_train_data:
rating_overall_sum += image.caption.rating()
ratings_summed.append(rating_overall_sum)
return rating_overall_sum, ratings_summed
def __prepare_train_data(self, flip_p=0.0) -> list[ImageTrainItem]:
"""
Create ImageTrainItem objects with metadata for hydration later
"""
if not self.has_scanned:
self.has_scanned = True
logging.info(" Preloading images...")
items = resolver.resolve(self.data_root, self.aspects, flip_p=flip_p, seed=self.seed)
image_paths = set(map(lambda item: item.pathname, items))
print (f" * DLMA: {len(items)} images loaded from {len(image_paths)} files")
self.prepared_train_data = [item for item in items if item.error is None]
random.Random(self.seed).shuffle(self.prepared_train_data)
self.__report_errors(items)
def __report_errors(self, items: list[ImageTrainItem]):
for item in items:
if item.error is not None:
logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
logging.error(f" *** exception: {item.error}")
undersized_items = [item for item in items if item.is_undersized]
if len(undersized_items) > 0:
underized_log_path = os.path.join(self.log_folder, "undersized_images.txt")
logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}")
logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}")
with open(underized_log_path, "w") as undersized_images_file:
undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:")
for undersized_item in undersized_items:
message = f" *** {undersized_item.pathname} with size: {undersized_item.image_size} is smaller than target size: {undersized_item.target_wh}\n"
undersized_images_file.write(message)
def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
"""
Picks a random subset of all images

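Aside: the rewritten constructor above sorts the items by caption rating and precomputes a running sum (`ratings_summed`). Below is a minimal sketch of how such a cumulative list can be consulted with `bisect` (imported at the top of this file) for rating-weighted picks; the helper is hypothetical, not part of this patch.

import bisect
import random

def pick_rated_index(ratings_summed: list[float], rating_overall_sum: float,
                     rng: random.Random) -> int:
    # Draw a point uniformly over the total rating mass, then locate the image
    # whose cumulative slot contains it; higher-rated images own wider slots.
    point = rng.uniform(0, rating_overall_sum)
    # Guard against the point landing exactly on the upper bound.
    return min(bisect.bisect_left(ratings_summed, point), len(ratings_summed) - 1)
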
View File

@ -38,12 +38,9 @@ class EveryDreamBatch(Dataset):
jitter: number of pixels to jitter the crop by, only for non-square images
"""
def __init__(self,
data_root,
flip_p=0.0,
data_loader: dlma,
debug_level=0,
batch_size=1,
conditional_dropout=0.02,
resolution=512,
crop_jitter=20,
seed=555,
tokenizer=None,
@ -54,8 +51,8 @@ class EveryDreamBatch(Dataset):
rated_dataset=False,
rated_dataset_dropout_target=0.5
):
self.data_root = data_root
self.batch_size = batch_size
self.data_loader = data_loader
self.batch_size = data_loader.batch_size
self.debug_level = debug_level
self.conditional_dropout = conditional_dropout
self.crop_jitter = crop_jitter
@ -70,26 +67,11 @@ class EveryDreamBatch(Dataset):
self.seed = seed
self.rated_dataset = rated_dataset
self.rated_dataset_dropout_target = rated_dataset_dropout_target
if seed == -1:
seed = random.randint(0, 99999)
if not dls.shared_dataloader:
logging.info(" * Creating new dataloader singleton")
dls.shared_dataloader = dlma(data_root=data_root,
seed=seed,
debug_level=debug_level,
batch_size=self.batch_size,
flip_p=flip_p,
resolution=resolution,
log_folder=self.log_folder,
)
self.image_train_items = dls.shared_dataloader.get_shuffled_image_buckets(1.0) # First epoch always trains on all images
self.image_train_items = self.data_loader.get_shuffled_image_buckets(1.0) # First epoch always trains on all images
num_images = len(self.image_train_items)
logging.info(f" ** Trainer Set: {num_images / batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}")
logging.info(f" ** Trainer Set: {num_images / self.batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}")
if self.write_schedule:
self.__write_batch_schedule(0)

View File

@ -4,18 +4,21 @@ import os
import random
import typing
import zipfile
import argparse
import PIL.Image as Image
import tqdm
from colorama import Fore, Style
from data.image_train_item import ImageCaption, ImageTrainItem
class DataResolver:
def __init__(self, aspects: list[typing.Tuple[int, int]], flip_p=0.0, seed=555):
self.seed = seed
self.aspects = aspects
self.flip_p = flip_p
def __init__(self, args: argparse.Namespace):
"""
:param args: EveryDream configuration, an `argparse.Namespace` object.
"""
self.aspects = args.aspects
self.flip_p = args.flip_p
self.seed = args.seed
def image_train_items(self, data_root: str) -> list[ImageTrainItem]:
"""
@ -173,8 +176,11 @@ class DirectoryResolver(DataResolver):
if os.path.isdir(current):
yield from DirectoryResolver.recurse_data_root(current)
def strategy(data_root: str):
def strategy(data_root: str) -> typing.Type[DataResolver]:
"""
Determine the strategy to use for resolving the data.
:param data_root: The root directory or JSON file to resolve.
"""
if os.path.isfile(data_root) and data_root.endswith('.json'):
return JSONResolver
@ -183,41 +189,37 @@ def strategy(data_root: str):
raise ValueError(f"data_root '{data_root}' is not a valid directory or JSON file.")
def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0, seed=555) -> list[ImageTrainItem]:
def resolve_root(path: str, args: argparse.Namespace) -> list[ImageTrainItem]:
"""
:param data_root: Directory or JSON file.
:param aspects: The list of aspect ratios to use
:param flip_p: The probability of flipping the image
Resolve the training data from the root path.
:param path: The root path to resolve.
:param args: EveryDream configuration, an `argparse.Namespace` object.
"""
if os.path.isfile(path) and path.endswith('.json'):
return JSONResolver(aspects, flip_p, seed).image_train_items(path)
if os.path.isdir(path):
return DirectoryResolver(aspects, flip_p, seed).image_train_items(path)
raise ValueError(f"data_root '{path}' is not a valid directory or JSON file.")
resolver = strategy(path)
return resolver(args).image_train_items(path)
def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0, seed=555) -> list[ImageTrainItem]:
def resolve(value: typing.Union[dict, str], args: argparse.Namespace) -> list[ImageTrainItem]:
"""
Resolve the training data from the value.
:param value: The value to resolve, either a dict or a string.
:param aspects: The list of aspect ratios to use
:param flip_p: The probability of flipping the image
:param value: The value to resolve, either a dict, a list, or a string.
:param args: EveryDream configuration, an `argparse.Namespace` object.
"""
if isinstance(value, str):
return resolve_root(value, aspects, flip_p)
return resolve_root(value, args)
if isinstance(value, dict):
resolver = value.get('resolver', None)
match resolver:
case 'directory' | 'json':
path = value.get('path', None)
return resolve_root(path, aspects, flip_p, seed)
return resolve_root(path, args)
case 'multi':
items = []
for resolver in value.get('resolvers', []):
items += resolve(resolver, aspects, flip_p, seed)
return items
return resolve(value.get('resolvers', []), args)
case _:
raise ValueError(f"Cannot resolve training data for resolver value '{resolver}'")
raise ValueError(f"Cannot resolve training data for resolver value '{resolver}'")
if isinstance(value, list):
items = []
for item in value:
items += resolve(item, args)
return items
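
Taken together, the reworked `resolve` entry point accepts a plain path, a dict spec, or a list of either. A usage sketch follows; the paths are placeholders and the Namespace carries only the fields the resolvers read (`aspects`, `flip_p`, `seed`), mirroring the test setup below.

import argparse
import data.aspects as aspects
import data.resolver as resolver

args = argparse.Namespace(
    aspects=aspects.get_aspect_buckets(512),
    flip_p=0.0,
    seed=555,
)

# A directory or a .json manifest, given as a plain string:
items = resolver.resolve("/path/to/train_data", args)

# A dict spec; 'multi' fans out to a list of nested resolvers:
spec = {
    "resolver": "multi",
    "resolvers": [
        {"resolver": "directory", "path": "/path/to/set_a"},
        {"resolver": "json", "path": "/path/to/set_b.json"},
    ],
}
items = resolver.resolve(spec, args)

# A bare list is also accepted and resolved item by item:
items = resolver.resolve(["/path/to/set_a", "/path/to/set_b.json"], args)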

View File

@ -2,6 +2,7 @@ import json
import glob
import os
import unittest
import argparse
import PIL.Image as Image
@ -10,13 +11,18 @@ import data.resolver as resolver
DATA_PATH = os.path.abspath('./test/data')
JSON_ROOT_PATH = os.path.join(DATA_PATH, 'test_root.json')
ASPECTS = aspects.get_aspect_buckets(512)
IMAGE_1_PATH = os.path.join(DATA_PATH, 'test1.jpg')
CAPTION_1_PATH = os.path.join(DATA_PATH, 'test1.txt')
IMAGE_2_PATH = os.path.join(DATA_PATH, 'test2.jpg')
IMAGE_3_PATH = os.path.join(DATA_PATH, 'test3.jpg')
ARGS = argparse.Namespace(
aspects=aspects.get_aspect_buckets(512),
flip_p=0.5,
seed=42,
)
class TestResolve(unittest.TestCase):
@classmethod
def setUpClass(cls):
@ -51,7 +57,7 @@ class TestResolve(unittest.TestCase):
os.remove(file)
def test_directory_resolve_with_str(self):
items = resolver.resolve(DATA_PATH, ASPECTS)
items = resolver.resolve(DATA_PATH, ARGS)
image_paths = [item.pathname for item in items]
image_captions = [item.caption for item in items]
captions = [caption.get_caption() for caption in image_captions]
@ -69,7 +75,7 @@ class TestResolve(unittest.TestCase):
'path': DATA_PATH,
}
items = resolver.resolve(data_root_spec, ASPECTS)
items = resolver.resolve(data_root_spec, ARGS)
image_paths = [item.pathname for item in items]
image_captions = [item.caption for item in items]
captions = [caption.get_caption() for caption in image_captions]
@ -82,7 +88,7 @@ class TestResolve(unittest.TestCase):
self.assertEqual(len(undersized_images), 1)
def test_json_resolve_with_str(self):
items = resolver.resolve(JSON_ROOT_PATH, ASPECTS)
items = resolver.resolve(JSON_ROOT_PATH, ARGS)
image_paths = [item.pathname for item in items]
image_captions = [item.caption for item in items]
captions = [caption.get_caption() for caption in image_captions]
@ -100,7 +106,7 @@ class TestResolve(unittest.TestCase):
'path': JSON_ROOT_PATH,
}
items = resolver.resolve(data_root_spec, ASPECTS)
items = resolver.resolve(data_root_spec, ARGS)
image_paths = [item.pathname for item in items]
image_captions = [item.caption for item in items]
captions = [caption.get_caption() for caption in image_captions]