Push DLMA into main, pass config to resolve
This patch * passes the configuration (`argparse.Namespace`) to the resolver, * pushes the DLMA code into the main function, * makes DLMA take a `list[ImageTrainItem]` instead of `data_root`, * makes `EveryDreamBatch` take `DLMA` instead of `data_root`, etc. * allows `data_root` to be a list. By doing these things, both `EveryDreamBatch` and DLMA can be free from data resolution logic. It also reduces the number of arguments which need to be passed down to EDB and DLMA.
This commit is contained in:
parent
bc273d0512
commit
326d861a86
|
@ -15,15 +15,10 @@ limitations under the License.
|
|||
"""
|
||||
import bisect
|
||||
import math
|
||||
import os
|
||||
import logging
|
||||
import copy
|
||||
|
||||
import random
|
||||
from data.image_train_item import ImageTrainItem
|
||||
import data.aspects as aspects
|
||||
import data.resolver as resolver
|
||||
from colorama import Fore, Style
|
||||
import PIL
|
||||
|
||||
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
|
||||
|
@ -34,22 +29,20 @@ class DataLoaderMultiAspect():
|
|||
|
||||
data_root: root folder of training data
|
||||
batch_size: number of images per batch
|
||||
flip_p: probability of flipping image horizontally (i.e. 0-0.5)
|
||||
"""
|
||||
def __init__(self, data_root, seed=555, debug_level=0, batch_size=1, flip_p=0.0, resolution=512, log_folder=None):
|
||||
self.data_root = data_root
|
||||
self.debug_level = debug_level
|
||||
self.flip_p = flip_p
|
||||
self.log_folder = log_folder
|
||||
def __init__(self, image_train_items, seed=555, batch_size=1):
|
||||
self.seed = seed
|
||||
self.batch_size = batch_size
|
||||
self.has_scanned = False
|
||||
self.aspects = aspects.get_aspect_buckets(resolution=resolution, square_only=False)
|
||||
|
||||
logging.info(f"* DLMA resolution {resolution}, buckets: {self.aspects}")
|
||||
self.__prepare_train_data()
|
||||
(self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings()
|
||||
|
||||
# Prepare data
|
||||
self.prepared_train_data = image_train_items
|
||||
random.Random(self.seed).shuffle(self.prepared_train_data)
|
||||
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
|
||||
# Initialize ratings
|
||||
self.rating_overall_sum: float = 0.0
|
||||
self.ratings_summed: list[float] = []
|
||||
for image in self.prepared_train_data:
|
||||
self.rating_overall_sum += image.caption.rating()
|
||||
self.ratings_summed.append(self.rating_overall_sum)
|
||||
|
||||
def __pick_multiplied_set(self, randomizer):
|
||||
"""
|
||||
|
@ -138,54 +131,6 @@ class DataLoaderMultiAspect():
|
|||
|
||||
return image_caption_pairs
|
||||
|
||||
def __sort_and_precalc_image_ratings(self) -> tuple[float, list[float]]:
|
||||
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
|
||||
|
||||
rating_overall_sum: float = 0.0
|
||||
ratings_summed: list[float] = []
|
||||
for image in self.prepared_train_data:
|
||||
rating_overall_sum += image.caption.rating()
|
||||
ratings_summed.append(rating_overall_sum)
|
||||
|
||||
return rating_overall_sum, ratings_summed
|
||||
|
||||
def __prepare_train_data(self, flip_p=0.0) -> list[ImageTrainItem]:
|
||||
"""
|
||||
Create ImageTrainItem objects with metadata for hydration later
|
||||
"""
|
||||
|
||||
if not self.has_scanned:
|
||||
self.has_scanned = True
|
||||
|
||||
logging.info(" Preloading images...")
|
||||
|
||||
items = resolver.resolve(self.data_root, self.aspects, flip_p=flip_p, seed=self.seed)
|
||||
image_paths = set(map(lambda item: item.pathname, items))
|
||||
|
||||
print (f" * DLMA: {len(items)} images loaded from {len(image_paths)} files")
|
||||
|
||||
self.prepared_train_data = [item for item in items if item.error is None]
|
||||
random.Random(self.seed).shuffle(self.prepared_train_data)
|
||||
self.__report_errors(items)
|
||||
|
||||
def __report_errors(self, items: list[ImageTrainItem]):
|
||||
for item in items:
|
||||
if item.error is not None:
|
||||
logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
|
||||
logging.error(f" *** exception: {item.error}")
|
||||
|
||||
undersized_items = [item for item in items if item.is_undersized]
|
||||
|
||||
if len(undersized_items) > 0:
|
||||
underized_log_path = os.path.join(self.log_folder, "undersized_images.txt")
|
||||
logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}")
|
||||
logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}")
|
||||
with open(underized_log_path, "w") as undersized_images_file:
|
||||
undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:")
|
||||
for undersized_item in undersized_items:
|
||||
message = f" *** {undersized_item.pathname} with size: {undersized_item.image_size} is smaller than target size: {undersized_item.target_wh}\n"
|
||||
undersized_images_file.write(message)
|
||||
|
||||
def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
|
||||
"""
|
||||
Picks a random subset of all images
|
||||
|
|
|
@ -38,12 +38,9 @@ class EveryDreamBatch(Dataset):
|
|||
jitter: number of pixels to jitter the crop by, only for non-square images
|
||||
"""
|
||||
def __init__(self,
|
||||
data_root,
|
||||
flip_p=0.0,
|
||||
data_loader: dlma,
|
||||
debug_level=0,
|
||||
batch_size=1,
|
||||
conditional_dropout=0.02,
|
||||
resolution=512,
|
||||
crop_jitter=20,
|
||||
seed=555,
|
||||
tokenizer=None,
|
||||
|
@ -54,8 +51,8 @@ class EveryDreamBatch(Dataset):
|
|||
rated_dataset=False,
|
||||
rated_dataset_dropout_target=0.5
|
||||
):
|
||||
self.data_root = data_root
|
||||
self.batch_size = batch_size
|
||||
self.data_loader = data_loader
|
||||
self.batch_size = data_loader.batch_size
|
||||
self.debug_level = debug_level
|
||||
self.conditional_dropout = conditional_dropout
|
||||
self.crop_jitter = crop_jitter
|
||||
|
@ -70,26 +67,11 @@ class EveryDreamBatch(Dataset):
|
|||
self.seed = seed
|
||||
self.rated_dataset = rated_dataset
|
||||
self.rated_dataset_dropout_target = rated_dataset_dropout_target
|
||||
|
||||
if seed == -1:
|
||||
seed = random.randint(0, 99999)
|
||||
|
||||
if not dls.shared_dataloader:
|
||||
logging.info(" * Creating new dataloader singleton")
|
||||
dls.shared_dataloader = dlma(data_root=data_root,
|
||||
seed=seed,
|
||||
debug_level=debug_level,
|
||||
batch_size=self.batch_size,
|
||||
flip_p=flip_p,
|
||||
resolution=resolution,
|
||||
log_folder=self.log_folder,
|
||||
)
|
||||
|
||||
self.image_train_items = dls.shared_dataloader.get_shuffled_image_buckets(1.0) # First epoch always trains on all images
|
||||
self.image_train_items = self.data_loader.get_shuffled_image_buckets(1.0) # First epoch always trains on all images
|
||||
|
||||
num_images = len(self.image_train_items)
|
||||
|
||||
logging.info(f" ** Trainer Set: {num_images / batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}")
|
||||
logging.info(f" ** Trainer Set: {num_images / self.batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}")
|
||||
if self.write_schedule:
|
||||
self.__write_batch_schedule(0)
|
||||
|
||||
|
|
|
@ -4,18 +4,21 @@ import os
|
|||
import random
|
||||
import typing
|
||||
import zipfile
|
||||
import argparse
|
||||
|
||||
import PIL.Image as Image
|
||||
import tqdm
|
||||
from colorama import Fore, Style
|
||||
|
||||
from data.image_train_item import ImageCaption, ImageTrainItem
|
||||
|
||||
class DataResolver:
|
||||
def __init__(self, aspects: list[typing.Tuple[int, int]], flip_p=0.0, seed=555):
|
||||
self.seed = seed
|
||||
self.aspects = aspects
|
||||
self.flip_p = flip_p
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
"""
|
||||
:param args: EveryDream configuration, an `argparse.Namespace` object.
|
||||
"""
|
||||
self.aspects = args.aspects
|
||||
self.flip_p = args.flip_p
|
||||
self.seed = args.seed
|
||||
|
||||
def image_train_items(self, data_root: str) -> list[ImageTrainItem]:
|
||||
"""
|
||||
|
@ -173,8 +176,11 @@ class DirectoryResolver(DataResolver):
|
|||
if os.path.isdir(current):
|
||||
yield from DirectoryResolver.recurse_data_root(current)
|
||||
|
||||
|
||||
def strategy(data_root: str):
|
||||
def strategy(data_root: str) -> typing.Type[DataResolver]:
|
||||
"""
|
||||
Determine the strategy to use for resolving the data.
|
||||
:param data_root: The root directory or JSON file to resolve.
|
||||
"""
|
||||
if os.path.isfile(data_root) and data_root.endswith('.json'):
|
||||
return JSONResolver
|
||||
|
||||
|
@ -183,41 +189,37 @@ def strategy(data_root: str):
|
|||
|
||||
raise ValueError(f"data_root '{data_root}' is not a valid directory or JSON file.")
|
||||
|
||||
|
||||
def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0, seed=555) -> list[ImageTrainItem]:
|
||||
def resolve_root(path: str, args: argparse.Namespace) -> list[ImageTrainItem]:
|
||||
"""
|
||||
:param data_root: Directory or JSON file.
|
||||
:param aspects: The list of aspect ratios to use
|
||||
:param flip_p: The probability of flipping the image
|
||||
Resolve the training data from the root path.
|
||||
:param path: The root path to resolve.
|
||||
:param args: EveryDream configuration, an `argparse.Namespace` object.
|
||||
"""
|
||||
if os.path.isfile(path) and path.endswith('.json'):
|
||||
return JSONResolver(aspects, flip_p, seed).image_train_items(path)
|
||||
|
||||
if os.path.isdir(path):
|
||||
return DirectoryResolver(aspects, flip_p, seed).image_train_items(path)
|
||||
|
||||
raise ValueError(f"data_root '{path}' is not a valid directory or JSON file.")
|
||||
resolver = strategy(path)
|
||||
return resolver(args).image_train_items(path)
|
||||
|
||||
def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0, seed=555) -> list[ImageTrainItem]:
|
||||
def resolve(value: typing.Union[dict, str], args: argparse.Namespace) -> list[ImageTrainItem]:
|
||||
"""
|
||||
Resolve the training data from the value.
|
||||
:param value: The value to resolve, either a dict or a string.
|
||||
:param aspects: The list of aspect ratios to use
|
||||
:param flip_p: The probability of flipping the image
|
||||
:param value: The value to resolve, either a dict, an array, or a string.
|
||||
:param args: EveryDream configuration, an `argparse.Namespace` object.
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
return resolve_root(value, aspects, flip_p)
|
||||
return resolve_root(value, args)
|
||||
|
||||
if isinstance(value, dict):
|
||||
resolver = value.get('resolver', None)
|
||||
match resolver:
|
||||
case 'directory' | 'json':
|
||||
path = value.get('path', None)
|
||||
return resolve_root(path, aspects, flip_p, seed)
|
||||
return resolve_root(path, args)
|
||||
case 'multi':
|
||||
items = []
|
||||
for resolver in value.get('resolvers', []):
|
||||
items += resolve(resolver, aspects, flip_p, seed)
|
||||
return items
|
||||
return resolve(value.get('resolvers', []), args)
|
||||
case _:
|
||||
raise ValueError(f"Cannot resolve training data for resolver value '{resolver}'")
|
||||
raise ValueError(f"Cannot resolve training data for resolver value '{resolver}'")
|
||||
|
||||
if isinstance(value, list):
|
||||
items = []
|
||||
for item in value:
|
||||
items += resolve(item, args)
|
||||
return items
|
|
@ -2,6 +2,7 @@ import json
|
|||
import glob
|
||||
import os
|
||||
import unittest
|
||||
import argparse
|
||||
|
||||
import PIL.Image as Image
|
||||
|
||||
|
@ -10,13 +11,18 @@ import data.resolver as resolver
|
|||
|
||||
DATA_PATH = os.path.abspath('./test/data')
|
||||
JSON_ROOT_PATH = os.path.join(DATA_PATH, 'test_root.json')
|
||||
ASPECTS = aspects.get_aspect_buckets(512)
|
||||
|
||||
IMAGE_1_PATH = os.path.join(DATA_PATH, 'test1.jpg')
|
||||
CAPTION_1_PATH = os.path.join(DATA_PATH, 'test1.txt')
|
||||
IMAGE_2_PATH = os.path.join(DATA_PATH, 'test2.jpg')
|
||||
IMAGE_3_PATH = os.path.join(DATA_PATH, 'test3.jpg')
|
||||
|
||||
ARGS = argparse.Namespace(
|
||||
aspects=aspects.get_aspect_buckets(512),
|
||||
flip_p=0.5,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
class TestResolve(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
@ -51,7 +57,7 @@ class TestResolve(unittest.TestCase):
|
|||
os.remove(file)
|
||||
|
||||
def test_directory_resolve_with_str(self):
|
||||
items = resolver.resolve(DATA_PATH, ASPECTS)
|
||||
items = resolver.resolve(DATA_PATH, ARGS)
|
||||
image_paths = [item.pathname for item in items]
|
||||
image_captions = [item.caption for item in items]
|
||||
captions = [caption.get_caption() for caption in image_captions]
|
||||
|
@ -69,7 +75,7 @@ class TestResolve(unittest.TestCase):
|
|||
'path': DATA_PATH,
|
||||
}
|
||||
|
||||
items = resolver.resolve(data_root_spec, ASPECTS)
|
||||
items = resolver.resolve(data_root_spec, ARGS)
|
||||
image_paths = [item.pathname for item in items]
|
||||
image_captions = [item.caption for item in items]
|
||||
captions = [caption.get_caption() for caption in image_captions]
|
||||
|
@ -82,7 +88,7 @@ class TestResolve(unittest.TestCase):
|
|||
self.assertEqual(len(undersized_images), 1)
|
||||
|
||||
def test_json_resolve_with_str(self):
|
||||
items = resolver.resolve(JSON_ROOT_PATH, ASPECTS)
|
||||
items = resolver.resolve(JSON_ROOT_PATH, ARGS)
|
||||
image_paths = [item.pathname for item in items]
|
||||
image_captions = [item.caption for item in items]
|
||||
captions = [caption.get_caption() for caption in image_captions]
|
||||
|
@ -100,7 +106,7 @@ class TestResolve(unittest.TestCase):
|
|||
'path': JSON_ROOT_PATH,
|
||||
}
|
||||
|
||||
items = resolver.resolve(data_root_spec, ASPECTS)
|
||||
items = resolver.resolve(data_root_spec, ARGS)
|
||||
image_paths = [item.pathname for item in items]
|
||||
image_captions = [item.caption for item in items]
|
||||
captions = [caption.get_caption() for caption in image_captions]
|
||||
|
|
Loading…
Reference in New Issue