From a6cabe8d7d4850b190973a9b2c2c9946529b5e6c Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Sun, 22 Jan 2023 16:13:50 -0800 Subject: [PATCH 01/22] Add static methods to ImageCaption for deriving captions from various sources --- data/image_train_item.py | 154 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 147 insertions(+), 7 deletions(-) diff --git a/data/image_train_item.py b/data/image_train_item.py index 08bf736..d3b20d3 100644 --- a/data/image_train_item.py +++ b/data/image_train_item.py @@ -18,6 +18,8 @@ import logging import math import os import random +import typing +import yaml import PIL import numpy as np @@ -25,6 +27,9 @@ from torchvision import transforms _RANDOM_TRIM = 0.04 +DEFAULT_MAX_CAPTION_LENGTH = 2048 + +OptionalImageCaption = typing.Optional['ImageCaption'] class ImageCaption: """ @@ -60,13 +65,15 @@ class ImageCaption: :param seed used to initialize the randomizer :return: generated caption string """ - max_target_tag_length = self.__max_target_length - len(self.__main_prompt) + if self.__tags: + max_target_tag_length = self.__max_target_length - len(self.__main_prompt) - if self.__use_weights: - tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length) - else: - tags_caption = self.__get_shuffled_tags(seed, self.__tags) + if self.__use_weights: + tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length) + else: + tags_caption = self.__get_shuffled_tags(seed, self.__tags) + return self.__main_prompt + ", " + tags_caption return self.__main_prompt + ", " + tags_caption def get_caption(self) -> str: @@ -91,7 +98,10 @@ class ImageCaption: weights_copy.pop(pos) tag = tags_copy.pop(pos) - caption += ", " + tag + + if caption: + caption += ", " + caption += tag return caption @@ -100,6 +110,136 @@ class ImageCaption: random.Random(seed).shuffle(tags) return ", ".join(tags) + @staticmethod + def parse(string: str) -> 'ImageCaption': + """ + Parses a string to get the caption. + + :param string: String to parse. + :return: `ImageCaption` object. + """ + split_caption = list(map(str.strip, string.split(","))) + main_prompt = split_caption[0] + tags = split_caption[1:] + tag_weights = [1.0] * len(tags) + + return ImageCaption(main_prompt, 1.0, tags, tag_weights, DEFAULT_MAX_CAPTION_LENGTH, False) + + @staticmethod + def from_file_name(file_path: str) -> 'ImageCaption': + """ + Parses the file name to get the caption. + + :param file_path: Path to the image file. + :return: `ImageCaption` object. + """ + (file_name, _) = os.path.splitext(os.path.basename(file_path)) + caption = file_name.split("_")[0] + return ImageCaption(caption, 1.0, [], [], DEFAULT_MAX_CAPTION_LENGTH, False) + + @staticmethod + def from_text_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption: + """ + Parses a text file to get the caption. Returns the default caption if + the file does not exist or is invalid. + + :param file_path: Path to the text file. + :param default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid. + :return: `ImageCaption` object or `None`. 
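As a rough usage sketch of the sidecar-text helper added here (the file name and fallback text are invented, not taken from the patch), a missing or unreadable caption file falls back to the supplied default:

    from data.image_train_item import ImageCaption

    # Use a parsed literal caption as the fallback when the .txt sidecar is absent.
    fallback = ImageCaption.parse("a photo")
    caption = ImageCaption.from_text_file("train/dog_001.txt", default_caption=fallback)
    print(caption.get_caption())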
+ """ + try: + with open(file_path, encoding='utf-8', mode='r') as caption_file: + caption_text = caption_file.read() + return ImageCaption.parse(caption_text) + except: + logging.error(f" *** Error reading {file_path} to get caption") + return default_caption + + @staticmethod + def from_yaml_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption: + """ + Parses a yaml file to get the caption. Returns the default caption if + the file does not exist or is invalid. + + :param file_path: path to the yaml file + :param default_caption: caption to return if the file does not exist or is invalid + :return: `ImageCaption` object or `None`. + """ + try: + with open(file_path, "r") as stream: + file_content = yaml.safe_load(stream) + main_prompt = file_content.get("main_prompt", "") + rating = file_content.get("rating", 1.0) + unparsed_tags = file_content.get("tags", []) + + max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH) + + tags = [] + tag_weights = [] + last_weight = None + weights_differ = False + for unparsed_tag in unparsed_tags: + tag = unparsed_tag.get("tag", "").strip() + if len(tag) == 0: + continue + + tags.append(tag) + tag_weight = unparsed_tag.get("weight", 1.0) + tag_weights.append(tag_weight) + + if last_weight is not None and weights_differ is False: + weights_differ = last_weight != tag_weight + + last_weight = tag_weight + + return ImageCaption(main_prompt, rating, tags, tag_weights, max_caption_length, weights_differ) + except: + logging.error(f" *** Error reading {file_path} to get caption") + return default_caption + + @staticmethod + def from_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption: + """ + Try to resolve a caption from a file path or return `default_caption`. + + :string: The path to the file to parse. + :default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid. + :return: `ImageCaption` object or `None`. + """ + if os.path.exists(file_path): + (file_path_without_ext, ext) = os.path.splitext(file_path) + match ext: + case ".yaml" | ".yml": + return ImageCaption.from_yaml_file(file_path, default_caption) + + case ".txt" | ".caption": + return ImageCaption.from_text_file(file_path, default_caption) + + case '.jpg'| '.jpeg'| '.png'| '.bmp'| '.webp'| '.jfif': + for ext in [".yaml", ".yml", ".txt", ".caption"]: + file_path = file_path_without_ext + ext + image_caption = ImageCaption.from_file(file_path) + if image_caption is not None: + return image_caption + return ImageCaption.from_file_name(file_path) + + case _: + return default_caption + else: + return default_caption + + @staticmethod + def resolve(string: str) -> 'ImageCaption': + """ + Try to resolve a caption from a string. If the string is a file path, + the caption will be read from the file, otherwise the string will be + parsed as a caption. + + :string: The string to resolve. + :return: `ImageCaption` object. 
+ """ + return ImageCaption.from_file(string, None) or ImageCaption.parse(string) + class ImageTrainItem: """ @@ -229,4 +369,4 @@ class ImageTrainItem: min_xy = min(x, y) image = image.crop((x_crop, y_crop, x_crop + min_xy, y_crop + min_xy)) - return image + return image \ No newline at end of file From f4f684a91560fd2f908f8359818ca6720b49d5c1 Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Sun, 22 Jan 2023 22:43:59 -0800 Subject: [PATCH 02/22] Add unit tests for data.image_train_item --- test/data/.gitkeep | 0 test/test_image_train_item.py | 71 +++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+) create mode 100644 test/data/.gitkeep create mode 100644 test/test_image_train_item.py diff --git a/test/data/.gitkeep b/test/data/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/test/test_image_train_item.py b/test/test_image_train_item.py new file mode 100644 index 0000000..bf12b43 --- /dev/null +++ b/test/test_image_train_item.py @@ -0,0 +1,71 @@ +import unittest +import os +import pathlib +import PIL.Image as Image + +from data.image_train_item import ImageCaption, ImageTrainItem + +DATA_PATH = pathlib.Path('./test/data') + +class TestImageCaption(unittest.TestCase): + + def setUp(self) -> None: + with open(DATA_PATH / "test1.txt", encoding='utf-8', mode='w') as f: + f.write("caption for test1") + + Image.new("RGB", (512,512)).save(DATA_PATH / "test1.jpg") + Image.new("RGB", (512,512)).save(DATA_PATH / "test2.jpg") + + with open(DATA_PATH / "test_caption.caption", encoding='utf-8', mode='w') as f: + f.write("caption for test2") + + return super().setUp() + + def tearDown(self) -> None: + for file in DATA_PATH.glob("test*"): + file.unlink() + + return super().tearDown() + + def test_constructor(self): + caption = ImageCaption("hello world", 1.0, ["one", "two", "three"], [1.0]*3, 2048, False) + self.assertEqual(caption.get_caption(), "hello world, one, two, three") + + caption = ImageCaption("hello world", 1.0, [], [], 2048, False) + self.assertEqual(caption.get_caption(), "hello world") + + def test_parse(self): + caption = ImageCaption.parse("hello world, one, two, three") + + self.assertEqual(caption.get_caption(), "hello world, one, two, three") + + def test_from_file_name(self): + caption = ImageCaption.from_file_name("foo bar_1_2_3.jpg") + self.assertEqual(caption.get_caption(), "foo bar") + + def test_from_text_file(self): + caption = ImageCaption.from_text_file("test/data/test1.txt") + self.assertEqual(caption.get_caption(), "caption for test1") + + def test_from_file(self): + caption = ImageCaption.from_file("test/data/test1.txt") + self.assertEqual(caption.get_caption(), "caption for test1") + + caption = ImageCaption.from_file("test/data/test_caption.caption") + self.assertEqual(caption.get_caption(), "caption for test2") + + def test_resolve(self): + caption = ImageCaption.resolve("test/data/test1.txt") + self.assertEqual(caption.get_caption(), "caption for test1") + + caption = ImageCaption.resolve("test/data/test_caption.caption") + self.assertEqual(caption.get_caption(), "caption for test2") + + caption = ImageCaption.resolve("hello world") + self.assertEqual(caption.get_caption(), "hello world") + + caption = ImageCaption.resolve("test/data/test1.jpg") + self.assertEqual(caption.get_caption(), "caption for test1") + + caption = ImageCaption.resolve("test/data/test2.jpg") + self.assertEqual(caption.get_caption(), "test2") \ No newline at end of file From 85b6aad6a99f6fd12701dfd4584af469956b4a6e Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks 
Date: Sun, 22 Jan 2023 22:44:44 -0800 Subject: [PATCH 03/22] If tags are empty, returne __main_prompt --- data/image_train_item.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/data/image_train_item.py b/data/image_train_item.py index d3b20d3..0e05145 100644 --- a/data/image_train_item.py +++ b/data/image_train_item.py @@ -77,7 +77,9 @@ class ImageCaption: return self.__main_prompt + ", " + tags_caption def get_caption(self) -> str: - return self.__main_prompt + ", " + ", ".join(self.__tags) + if self.__tags: + return self.__main_prompt + ", " + ", ".join(self.__tags) + return self.__main_prompt @staticmethod def __get_weighted_shuffled_tags(seed: int, tags: list[str], weights: list[float], max_target_tag_length: int) -> str: From 914a51b057ffc2026f6cc4fc96a6f9efbd36a531 Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Sun, 22 Jan 2023 22:46:04 -0800 Subject: [PATCH 04/22] Add data.resolver module for training data resolution --- data/resolver.py | 259 +++++++++++++++++++++++++++++++++++++ test/test_data_resolver.py | 117 +++++++++++++++++ 2 files changed, 376 insertions(+) create mode 100644 data/resolver.py create mode 100644 test/test_data_resolver.py diff --git a/data/resolver.py b/data/resolver.py new file mode 100644 index 0000000..16ea565 --- /dev/null +++ b/data/resolver.py @@ -0,0 +1,259 @@ +import json +import logging +import os +import typing +import zipfile + +import PIL.Image as Image +import tqdm +from colorama import Fore, Style + +from data.image_train_item import ImageCaption, ImageTrainItem + + +OptionalCallable = typing.Optional[typing.Callable] + +class Event: + def __init__(self, name: str): + self.name = name + +class UndersizedImageEvent(Event): + def __init__(self, image_path: str, image_size: typing.Tuple[int, int], target_size: typing.Tuple[int, int]): + super().__init__('undersized_image') + self.image_path = image_path + self.image_size = image_size + self.target_size = target_size + +class DataResolver: + def __init__(self, aspects: list[typing.Tuple[int, int]], flip_p=0.0, on_event: OptionalCallable=None): + self.aspects = aspects + self.flip_p = flip_p + self.on_event = on_event or (lambda data: None) + + def image_train_items(self, data_root: str) -> list[ImageTrainItem]: + """ + Get the list of `ImageTrainItem` for the given data root. + + :param data_root: The data root, a directory, a file, etc.. + :return: The list of `ImageTrainItem`. + """ + raise NotImplementedError() + + def compute_target_width_height(self, image_path: str) -> typing.Optional[typing.Tuple[int, int]]: + # Compute the target width and height for the image based on the aspect ratio. + with Image.open(image_path) as image: + width, height = image.size + image_aspect = width / height + target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect)) + + if width * height < target_wh[0] * target_wh[1]: + event = UndersizedImageEvent(image_path, (width, height), target_wh) + self.on_event(event) + + return target_wh + + def image_train_item(self, image_path: str, caption: ImageCaption) -> ImageTrainItem: + #try: + target_wh = self.compute_target_width_height(image_path) + return ImageTrainItem(image=None, caption=caption, target_wh=target_wh, pathname=image_path, flip_p=self.flip_p) + # except Exception as e: + # logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{image_path}{Fore.LIGHTRED_EX} to get metadata. 
File may be corrupt and will be skipped.{Style.RESET_ALL}") + # logging.error(f" *** exception: {e}") + + +class JSONResolver(DataResolver): + def image_train_items(self, json_path: str) -> list[ImageTrainItem]: + items = [] + with open(json_path, encoding='utf-8', mode='r') as f: + json_data = json.load(f) + + for data in json_data: + caption = JSONResolver.image_caption(data) + if caption: + image_value = JSONResolver.get_image_value(data) + item = self.image_train_item(image_value, caption) + if item: + items.append(item) + + return items + + @staticmethod + def get_image_value(json_data: dict) -> typing.Optional[str]: + """ + Get the image from the json data if possible. + + :param json_data: The json data, a dict. + :return: The image, or None if not found. + """ + image_value = json_data.get("image", None) + if isinstance(image_value, str): + image_value = image_value.strip() + if os.path.exists(image_value): + return image_value + + @staticmethod + def get_caption_value(json_data: dict) -> typing.Optional[str]: + """ + Get the caption from the json data if possible. + + :param json_data: The json data, a dict. + :return: The caption, or None if not found. + """ + caption_value = json_data.get("caption", None) + if isinstance(caption_value, str): + return caption_value.strip() + + @staticmethod + def image_caption(json_data: dict) -> typing.Optional[ImageCaption]: + """ + Get the caption from the json data if possible. + + :param json_data: The json data, a dict. + :return: The `ImageCaption`, or None if not found. + """ + image_value = JSONResolver.get_image_value(json_data) + caption_value = JSONResolver.get_caption_value(json_data) + if image_value: + if caption_value: + return ImageCaption.resolve(caption_value) + return ImageCaption.from_file(image_value) + + +class DirectoryResolver(DataResolver): + def image_train_items(self, data_root: str) -> list[ImageTrainItem]: + """ + Create `ImageTrainItem` objects with metadata for hydration later. + Unzips all zip files in `data_root` and then recursively searches the + `data_root` for images and captions. 
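A minimal sketch of driving the resolver module as introduced in this patch (the data directory is hypothetical; note that the on_event callback shown here is replaced by a returned event list in PATCH 08 later in this series):

    import data.aspects as aspects
    import data.resolver as resolver

    buckets = aspects.get_aspect_buckets(512)
    # Resolves a directory (or a JSON manifest) into ImageTrainItem objects;
    # undersized images are reported through the optional callback.
    items = resolver.resolve("./training_data", buckets, flip_p=0.0,
                             on_event=lambda event: print(event.name))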
+ + :param data_root: The root directory to recurse through + """ + DirectoryResolver.unzip_all(data_root) + image_paths = list(DirectoryResolver.recurse_data_root(data_root)) + items = [] + + for pathname in tqdm.tqdm(image_paths): + caption = ImageCaption.from_file(pathname) + item = self.image_train_item(pathname, caption) + + if item: + items.append(item) + return items + + @staticmethod + def unzip_all(path): + try: + for root, dirs, files in os.walk(path): + for file in files: + if file.endswith('.zip'): + logging.info(f"Unzipping {file}") + with zipfile.ZipFile(path, 'r') as zip_ref: + zip_ref.extractall(path) + except Exception as e: + logging.error(f"Error unzipping files {e}") + + @staticmethod + def recurse_data_root(recurse_root): + multiply = 1 + multiply_path = os.path.join(recurse_root, "multiply.txt") + if os.path.exists(multiply_path): + try: + with open(multiply_path, encoding='utf-8', mode='r') as f: + multiply = int(float(f.read().strip())) + logging.info(f" * DLMA multiply.txt in {recurse_root} set to {multiply}") + except: + logging.error(f" *** Error reading multiply.txt in {recurse_root}, defaulting to 1") + pass + + for f in os.listdir(recurse_root): + current = os.path.join(recurse_root, f) + + if os.path.isfile(current): + ext = os.path.splitext(f)[1].lower() + if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']: + # Add image multiplyrepeats number of times + for _ in range(multiply): + yield current + + for d in os.listdir(recurse_root): + current = os.path.join(recurse_root, d) + if os.path.isdir(current): + for file in DirectoryResolver.recurse_data_root(recurse_root=dir): + yield file + + +def strategy(data_root: str): + if os.path.isfile(data_root) and data_root.endswith('.json'): + return JSONResolver + + if os.path.isdir(data_root): + return DirectoryResolver + + raise ValueError(f"data_root '{data_root}' is not a valid directory or JSON file.") + + +def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0, on_event: OptionalCallable=None) -> list[ImageTrainItem]: + """ + :param data_root: Directory or JSON file. + :param aspects: The list of aspect ratios to use + :param flip_p: The probability of flipping the image + """ + if os.path.isfile(path) and path.endswith('.json'): + resolver = JSONResolver(aspects, flip_p, on_event) + + if os.path.isdir(path): + resolver = DirectoryResolver(aspects, flip_p, on_event) + + if not resolver: + raise ValueError(f"data_root '{path}' is not a valid directory or JSON file.") + + return resolver.image_train_items(path) + +def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0, on_event: OptionalCallable=None) -> list[ImageTrainItem]: + """ + Resolve the training data from the value. + :param value: The value to resolve, either a dict or a string. + :param aspects: The list of aspect ratios to use + :param flip_p: The probability of flipping the image + :param on_event: The callback to call when an event occurs (e.g. 
undersized image detected) + """ + if isinstance(value, str): + return resolve_root(value, aspects, flip_p, on_event) + + if isinstance(value, dict): + resolver = value.get('resolver', None) + match resolver: + case 'directory' | 'json': + path = value.get('path', None) + return resolve_root(path, aspects, flip_p, on_event) + case 'multi': + items = [] + for resolver in value.get('resolvers', []): + items += resolve(resolver, aspects, flip_p, on_event) + return items + case _: + raise ValueError(f"Cannot resolve training data for resolver value '{resolver}'") + + +# example = { +# 'resolver': 'directory', +# 'data_root': 'data', +# } + +# example = { +# 'resolver': 'json', +# 'data_root': 'data.json', +# } + +# example = { +# 'resolver': 'multi', +# 'resolvers': [ +# { +# 'resolver': 'directory', +# 'data_root': 'data', +# }, { +# 'resolver': 'json', +# 'data_root': 'data.json', +# }, +# ] +# } \ No newline at end of file diff --git a/test/test_data_resolver.py b/test/test_data_resolver.py new file mode 100644 index 0000000..db3d176 --- /dev/null +++ b/test/test_data_resolver.py @@ -0,0 +1,117 @@ +import json +import glob +import os +import unittest + +import PIL.Image as Image + +import data.aspects as aspects +import data.resolver as resolver + +DATA_PATH = os.path.abspath('./test/data') +JSON_ROOT_PATH = os.path.join(DATA_PATH, 'test_root.json') +ASPECTS = aspects.get_aspect_buckets(512) +FLIP_P = 0.0 + +IMAGE_1_PATH = os.path.join(DATA_PATH, 'test1.jpg') +CAPTION_1_PATH = os.path.join(DATA_PATH, 'test1.txt') +IMAGE_2_PATH = os.path.join(DATA_PATH, 'test2.jpg') +IMAGE_3_PATH = os.path.join(DATA_PATH, 'test3.jpg') + +class TestResolve(unittest.TestCase): + @classmethod + def setUpClass(cls): + Image.new('RGB', (512, 512)).save(IMAGE_1_PATH) + with open(CAPTION_1_PATH, 'w') as f: + f.write('caption for test1') + + Image.new('RGB', (512, 512)).save(IMAGE_2_PATH) + # Undersized image. Should cause an event. 
+ Image.new('RGB', (256, 256)).save(IMAGE_3_PATH) + + json_data = [ + { + 'image': IMAGE_1_PATH, + 'caption': CAPTION_1_PATH + }, + { + 'image': IMAGE_2_PATH, + 'caption': 'caption for test2' + }, + { + 'image': IMAGE_3_PATH, + } + ] + + with open(JSON_ROOT_PATH, 'w') as f: + json.dump(json_data, f, indent=4) + + + @classmethod + def tearDownClass(cls): + for file in glob.glob(os.path.join(DATA_PATH, 'test*')): + os.remove(file) + + def setUp(self) -> None: + self.events = [] + self.on_event = lambda event: self.events.append(event.name) + return super().setUp() + + def tearDown(self) -> None: + self.events = [] + self.on_event = None + return super().tearDown() + + def test_directory_resolve_with_str(self): + image_train_items = resolver.resolve(DATA_PATH, ASPECTS, FLIP_P, self.on_event) + image_paths = [item.pathname for item in image_train_items] + image_captions = [item.caption for item in image_train_items] + captions = [caption.get_caption() for caption in image_captions] + + self.assertEqual(len(image_train_items), 3) + self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH]) + self.assertEqual(captions, ['caption for test1', 'test2', 'test3']) + self.assertEqual(self.events, ['undersized_image']) + + def test_directory_resolve_with_dict(self): + data_root_spec = { + 'resolver': 'directory', + 'path': DATA_PATH, + } + + image_train_items = resolver.resolve(data_root_spec, ASPECTS, FLIP_P, self.on_event) + image_paths = [item.pathname for item in image_train_items] + image_captions = [item.caption for item in image_train_items] + captions = [caption.get_caption() for caption in image_captions] + + self.assertEqual(len(image_train_items), 3) + self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH]) + self.assertEqual(captions, ['caption for test1', 'test2', 'test3']) + self.assertEqual(self.events, ['undersized_image']) + + def test_json_resolve_with_str(self): + image_train_items = resolver.resolve(JSON_ROOT_PATH, ASPECTS, FLIP_P, self.on_event) + image_paths = [item.pathname for item in image_train_items] + image_captions = [item.caption for item in image_train_items] + captions = [caption.get_caption() for caption in image_captions] + + self.assertEqual(len(image_train_items), 3) + self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH]) + self.assertEqual(captions, ['caption for test1', 'caption for test2', 'test3']) + self.assertEqual(self.events, ['undersized_image']) + + def test_json_resolve_with_dict(self): + data_root_spec = { + 'resolver': 'json', + 'path': JSON_ROOT_PATH, + } + + image_train_items = resolver.resolve(data_root_spec, ASPECTS, FLIP_P, self.on_event) + image_paths = [item.pathname for item in image_train_items] + image_captions = [item.caption for item in image_train_items] + captions = [caption.get_caption() for caption in image_captions] + + self.assertEqual(len(image_train_items), 3) + self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH]) + self.assertEqual(captions, ['caption for test1', 'caption for test2', 'test3']) + self.assertEqual(self.events, ['undersized_image']) \ No newline at end of file From aa0a2a176540d8219e17d17db39c50c8517cc3d5 Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Sun, 22 Jan 2023 23:09:09 -0800 Subject: [PATCH 05/22] Sync recurse_data_root changes with data loader --- data/resolver.py | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/data/resolver.py b/data/resolver.py index 16ea565..92f5165 100644 --- 
a/data/resolver.py +++ b/data/resolver.py @@ -154,33 +154,23 @@ class DirectoryResolver(DataResolver): @staticmethod def recurse_data_root(recurse_root): - multiply = 1 - multiply_path = os.path.join(recurse_root, "multiply.txt") - if os.path.exists(multiply_path): - try: - with open(multiply_path, encoding='utf-8', mode='r') as f: - multiply = int(float(f.read().strip())) - logging.info(f" * DLMA multiply.txt in {recurse_root} set to {multiply}") - except: - logging.error(f" *** Error reading multiply.txt in {recurse_root}, defaulting to 1") - pass - for f in os.listdir(recurse_root): current = os.path.join(recurse_root, f) if os.path.isfile(current): ext = os.path.splitext(f)[1].lower() if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']: - # Add image multiplyrepeats number of times - for _ in range(multiply): - yield current + yield current + + sub_dirs = [] for d in os.listdir(recurse_root): current = os.path.join(recurse_root, d) if os.path.isdir(current): - for file in DirectoryResolver.recurse_data_root(recurse_root=dir): - yield file - + sub_dirs.append(current) + + for dir in sub_dirs: + DirectoryResolver.__recurse_data_root(dir) def strategy(data_root: str): if os.path.isfile(data_root) and data_root.endswith('.json'): From 08813eabb5c6c9bcb2de8911b590e5b7386dca97 Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Sun, 22 Jan 2023 23:13:05 -0800 Subject: [PATCH 06/22] Use DirectoryResolver.recurse_data_root --- data/data_loader.py | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/data/data_loader.py b/data/data_loader.py index 60edce2..5172624 100644 --- a/data/data_loader.py +++ b/data/data_loader.py @@ -23,6 +23,8 @@ from PIL import Image import random from data.image_train_item import ImageTrainItem, ImageCaption import data.aspects as aspects +import data.resolver as resolver +from data.resolver import DirectoryResolver from colorama import Fore, Style import zipfile import tqdm @@ -54,8 +56,10 @@ class DataLoaderMultiAspect(): logging.info(" Preloading images...") self.unzip_all(data_root) + + for image_path in DirectoryResolver.recurse_data_root(data_root): + self.image_paths.append(image_path) - self.__recurse_data_root(self=self, recurse_root=data_root) random.Random(seed).shuffle(self.image_paths) self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p) (self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings() @@ -349,23 +353,3 @@ class DataLoaderMultiAspect(): prepared_train_data.pop(pos) return picked_images - - @staticmethod - def __recurse_data_root(self, recurse_root): - for f in os.listdir(recurse_root): - current = os.path.join(recurse_root, f) - - if os.path.isfile(current): - ext = os.path.splitext(f)[1].lower() - if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']: - self.image_paths.append(current) - - sub_dirs = [] - - for d in os.listdir(recurse_root): - current = os.path.join(recurse_root, d) - if os.path.isdir(current): - sub_dirs.append(current) - - for dir in sub_dirs: - self.__recurse_data_root(self=self, recurse_root=dir) From 9c6df69e4ee06a8982212c72fb40577416be4d6f Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Sun, 22 Jan 2023 23:14:16 -0800 Subject: [PATCH 07/22] Use DirectoryResolver.unzip_all --- data/data_loader.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/data/data_loader.py b/data/data_loader.py index 5172624..ae276bd 100644 --- a/data/data_loader.py +++ b/data/data_loader.py @@ -55,7 
+55,7 @@ class DataLoaderMultiAspect(): logging.info(f"* DLMA resolution {resolution}, buckets: {self.aspects}") logging.info(" Preloading images...") - self.unzip_all(data_root) + DirectoryResolver.unzip_all(data_root) for image_path in DirectoryResolver.recurse_data_root(data_root): self.image_paths.append(image_path) @@ -149,18 +149,6 @@ class DataLoaderMultiAspect(): return image_caption_pairs - @staticmethod - def unzip_all(path): - try: - for root, dirs, files in os.walk(path): - for file in files: - if file.endswith('.zip'): - logging.info(f"Unzipping {file}") - with zipfile.ZipFile(path, 'r') as zip_ref: - zip_ref.extractall(path) - except Exception as e: - logging.error(f"Error unzipping files {e}") - def __sort_and_precalc_image_ratings(self) -> tuple[float, list[float]]: self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating()) From 4e6c5f4d003e8a4f905d6b5cb85d01b8f383b084 Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Sun, 22 Jan 2023 23:58:25 -0800 Subject: [PATCH 08/22] Get rid of on_event callback --- data/resolver.py | 104 ++++++++++++++++++++----------------- test/test_data_resolver.py | 60 ++++++++++----------- 2 files changed, 84 insertions(+), 80 deletions(-) diff --git a/data/resolver.py b/data/resolver.py index 92f5165..d432b9d 100644 --- a/data/resolver.py +++ b/data/resolver.py @@ -25,10 +25,10 @@ class UndersizedImageEvent(Event): self.target_size = target_size class DataResolver: - def __init__(self, aspects: list[typing.Tuple[int, int]], flip_p=0.0, on_event: OptionalCallable=None): + def __init__(self, aspects: list[typing.Tuple[int, int]], flip_p=0.0): self.aspects = aspects self.flip_p = flip_p - self.on_event = on_event or (lambda data: None) + self.events = [] def image_train_items(self, data_root: str) -> list[ImageTrainItem]: """ @@ -48,17 +48,25 @@ class DataResolver: if width * height < target_wh[0] * target_wh[1]: event = UndersizedImageEvent(image_path, (width, height), target_wh) - self.on_event(event) + self.events.append(event) return target_wh - def image_train_item(self, image_path: str, caption: ImageCaption) -> ImageTrainItem: - #try: - target_wh = self.compute_target_width_height(image_path) - return ImageTrainItem(image=None, caption=caption, target_wh=target_wh, pathname=image_path, flip_p=self.flip_p) - # except Exception as e: - # logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{image_path}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}") - # logging.error(f" *** exception: {e}") + def image_train_item(self, image_path: str, caption: ImageCaption, multiplier: float=1) -> ImageTrainItem: + try: + target_wh = self.compute_target_width_height(image_path) + return ImageTrainItem( + image=None, + caption=caption, + target_wh=target_wh, + pathname=image_path, + flip_p=self.flip_p, + multiplier=multiplier + ) + # TODO: This should only handle Image errors. + except Exception as e: + logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{image_path}{Fore.LIGHTRED_EX} to get metadata. 
File may be corrupt and will be skipped.{Style.RESET_ALL}") + logging.error(f" *** exception: {e}") class JSONResolver(DataResolver): @@ -131,10 +139,30 @@ class DirectoryResolver(DataResolver): DirectoryResolver.unzip_all(data_root) image_paths = list(DirectoryResolver.recurse_data_root(data_root)) items = [] + multipliers = {} + skip_folders = [] for pathname in tqdm.tqdm(image_paths): - caption = ImageCaption.from_file(pathname) - item = self.image_train_item(pathname, caption) + current_dir = os.path.dirname(pathname) + + if current_dir not in multipliers: + multiply_txt_path = os.path.join(current_dir, "multiply.txt") + if os.path.exists(multiply_txt_path): + try: + with open(multiply_txt_path, 'r') as f: + val = float(f.read().strip()) + multipliers[current_dir] = val + logging.info(f" * DLMA multiply.txt in {current_dir} set to {val}") + except Exception as e: + logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}") + skip_folders.append(current_dir) + multipliers[current_dir] = 1.0 + else: + skip_folders.append(current_dir) + multipliers[current_dir] = 1.0 + + caption = ImageCaption.resolve(pathname) + item = self.image_train_item(pathname, caption, multiplier=multipliers[current_dir]) if item: items.append(item) @@ -182,68 +210,48 @@ def strategy(data_root: str): raise ValueError(f"data_root '{data_root}' is not a valid directory or JSON file.") -def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0, on_event: OptionalCallable=None) -> list[ImageTrainItem]: +def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0) -> typing.Tuple[list[ImageTrainItem], list[Event]]: """ :param data_root: Directory or JSON file. :param aspects: The list of aspect ratios to use :param flip_p: The probability of flipping the image """ if os.path.isfile(path) and path.endswith('.json'): - resolver = JSONResolver(aspects, flip_p, on_event) + resolver = JSONResolver(aspects, flip_p) if os.path.isdir(path): - resolver = DirectoryResolver(aspects, flip_p, on_event) + resolver = DirectoryResolver(aspects, flip_p) if not resolver: raise ValueError(f"data_root '{path}' is not a valid directory or JSON file.") - return resolver.image_train_items(path) + items = resolver.image_train_items(path) + events = resolver.events + return items, events -def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0, on_event: OptionalCallable=None) -> list[ImageTrainItem]: +def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0) -> typing.Tuple[list[ImageTrainItem], list[Event]]: """ Resolve the training data from the value. :param value: The value to resolve, either a dict or a string. :param aspects: The list of aspect ratios to use :param flip_p: The probability of flipping the image - :param on_event: The callback to call when an event occurs (e.g. 
undersized image detected) """ if isinstance(value, str): - return resolve_root(value, aspects, flip_p, on_event) + return resolve_root(value, aspects, flip_p) if isinstance(value, dict): resolver = value.get('resolver', None) match resolver: case 'directory' | 'json': path = value.get('path', None) - return resolve_root(path, aspects, flip_p, on_event) + return resolve_root(path, aspects, flip_p) case 'multi': - items = [] + resolved_items = [] + resolved_events = [] for resolver in value.get('resolvers', []): - items += resolve(resolver, aspects, flip_p, on_event) - return items + items, events = resolve(resolver, aspects, flip_p) + resolved_items.extend(items) + resolved_events.extend(events) + return resolved_items, resolved_events case _: - raise ValueError(f"Cannot resolve training data for resolver value '{resolver}'") - - -# example = { -# 'resolver': 'directory', -# 'data_root': 'data', -# } - -# example = { -# 'resolver': 'json', -# 'data_root': 'data.json', -# } - -# example = { -# 'resolver': 'multi', -# 'resolvers': [ -# { -# 'resolver': 'directory', -# 'data_root': 'data', -# }, { -# 'resolver': 'json', -# 'data_root': 'data.json', -# }, -# ] -# } \ No newline at end of file + raise ValueError(f"Cannot resolve training data for resolver value '{resolver}'") \ No newline at end of file diff --git a/test/test_data_resolver.py b/test/test_data_resolver.py index db3d176..299a7b6 100644 --- a/test/test_data_resolver.py +++ b/test/test_data_resolver.py @@ -11,7 +11,6 @@ import data.resolver as resolver DATA_PATH = os.path.abspath('./test/data') JSON_ROOT_PATH = os.path.join(DATA_PATH, 'test_root.json') ASPECTS = aspects.get_aspect_buckets(512) -FLIP_P = 0.0 IMAGE_1_PATH = os.path.join(DATA_PATH, 'test1.jpg') CAPTION_1_PATH = os.path.join(DATA_PATH, 'test1.txt') @@ -46,32 +45,23 @@ class TestResolve(unittest.TestCase): with open(JSON_ROOT_PATH, 'w') as f: json.dump(json_data, f, indent=4) - @classmethod def tearDownClass(cls): for file in glob.glob(os.path.join(DATA_PATH, 'test*')): os.remove(file) - def setUp(self) -> None: - self.events = [] - self.on_event = lambda event: self.events.append(event.name) - return super().setUp() - - def tearDown(self) -> None: - self.events = [] - self.on_event = None - return super().tearDown() - def test_directory_resolve_with_str(self): - image_train_items = resolver.resolve(DATA_PATH, ASPECTS, FLIP_P, self.on_event) - image_paths = [item.pathname for item in image_train_items] - image_captions = [item.caption for item in image_train_items] + items, events = resolver.resolve(DATA_PATH, ASPECTS) + image_paths = [item.pathname for item in items] + image_captions = [item.caption for item in items] captions = [caption.get_caption() for caption in image_captions] - self.assertEqual(len(image_train_items), 3) + self.assertEqual(len(items), 3) self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH]) self.assertEqual(captions, ['caption for test1', 'test2', 'test3']) - self.assertEqual(self.events, ['undersized_image']) + + events = list(map(lambda e: e.name, events)) + self.assertEqual(events, ['undersized_image']) def test_directory_resolve_with_dict(self): data_root_spec = { @@ -79,26 +69,30 @@ class TestResolve(unittest.TestCase): 'path': DATA_PATH, } - image_train_items = resolver.resolve(data_root_spec, ASPECTS, FLIP_P, self.on_event) - image_paths = [item.pathname for item in image_train_items] - image_captions = [item.caption for item in image_train_items] + items, events = resolver.resolve(data_root_spec, ASPECTS) + image_paths 
= [item.pathname for item in items] + image_captions = [item.caption for item in items] captions = [caption.get_caption() for caption in image_captions] - self.assertEqual(len(image_train_items), 3) + self.assertEqual(len(items), 3) self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH]) self.assertEqual(captions, ['caption for test1', 'test2', 'test3']) - self.assertEqual(self.events, ['undersized_image']) + + events = list(map(lambda e: e.name, events)) + self.assertEqual(events, ['undersized_image']) def test_json_resolve_with_str(self): - image_train_items = resolver.resolve(JSON_ROOT_PATH, ASPECTS, FLIP_P, self.on_event) - image_paths = [item.pathname for item in image_train_items] - image_captions = [item.caption for item in image_train_items] + items, events = resolver.resolve(JSON_ROOT_PATH, ASPECTS) + image_paths = [item.pathname for item in items] + image_captions = [item.caption for item in items] captions = [caption.get_caption() for caption in image_captions] - self.assertEqual(len(image_train_items), 3) + self.assertEqual(len(items), 3) self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH]) self.assertEqual(captions, ['caption for test1', 'caption for test2', 'test3']) - self.assertEqual(self.events, ['undersized_image']) + + events = list(map(lambda e: e.name, events)) + self.assertEqual(events, ['undersized_image']) def test_json_resolve_with_dict(self): data_root_spec = { @@ -106,12 +100,14 @@ class TestResolve(unittest.TestCase): 'path': JSON_ROOT_PATH, } - image_train_items = resolver.resolve(data_root_spec, ASPECTS, FLIP_P, self.on_event) - image_paths = [item.pathname for item in image_train_items] - image_captions = [item.caption for item in image_train_items] + items, events = resolver.resolve(data_root_spec, ASPECTS) + image_paths = [item.pathname for item in items] + image_captions = [item.caption for item in items] captions = [caption.get_caption() for caption in image_captions] - self.assertEqual(len(image_train_items), 3) + self.assertEqual(len(items), 3) self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH]) self.assertEqual(captions, ['caption for test1', 'caption for test2', 'test3']) - self.assertEqual(self.events, ['undersized_image']) \ No newline at end of file + + events = list(map(lambda e: e.name, events)) + self.assertEqual(events, ['undersized_image']) \ No newline at end of file From 0cf2cd71de72a812340953e6ad9aa69ad5823b41 Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Mon, 23 Jan 2023 00:14:50 -0800 Subject: [PATCH 09/22] Fix mistake in ImageCaption.parse --- data/image_train_item.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data/image_train_item.py b/data/image_train_item.py index 0e05145..db6f4e4 100644 --- a/data/image_train_item.py +++ b/data/image_train_item.py @@ -137,7 +137,7 @@ class ImageCaption: """ (file_name, _) = os.path.splitext(os.path.basename(file_path)) caption = file_name.split("_")[0] - return ImageCaption(caption, 1.0, [], [], DEFAULT_MAX_CAPTION_LENGTH, False) + return ImageCaption.parse(caption) @staticmethod def from_text_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption: From 316df2db7ea5e54470042fd01a7374826b87e048 Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Mon, 23 Jan 2023 00:15:32 -0800 Subject: [PATCH 10/22] Use data_resolver.resolve for data loading in data_loader --- data/data_loader.py | 178 ++++++-------------------------------------- 1 file changed, 21 insertions(+), 157 
deletions(-) diff --git a/data/data_loader.py b/data/data_loader.py index ae276bd..0c0fb3c 100644 --- a/data/data_loader.py +++ b/data/data_loader.py @@ -18,22 +18,15 @@ import math import os import logging -import yaml -from PIL import Image import random -from data.image_train_item import ImageTrainItem, ImageCaption +from data.image_train_item import ImageTrainItem import data.aspects as aspects import data.resolver as resolver -from data.resolver import DirectoryResolver from colorama import Fore, Style -import zipfile -import tqdm import PIL PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default -DEFAULT_MAX_CAPTION_LENGTH = 2048 - class DataLoaderMultiAspect(): """ Data loader for multi-aspect-ratio training and bucketing @@ -43,25 +36,18 @@ class DataLoaderMultiAspect(): flip_p: probability of flipping image horizontally (i.e. 0-0.5) """ def __init__(self, data_root, seed=555, debug_level=0, batch_size=1, flip_p=0.0, resolution=512, log_folder=None): - self.image_paths = [] + self.data_root = data_root self.debug_level = debug_level self.flip_p = flip_p self.log_folder = log_folder self.seed = seed self.batch_size = batch_size self.has_scanned = False - self.aspects = aspects.get_aspect_buckets(resolution=resolution, square_only=False) + logging.info(f"* DLMA resolution {resolution}, buckets: {self.aspects}") logging.info(" Preloading images...") - - DirectoryResolver.unzip_all(data_root) - - for image_path in DirectoryResolver.recurse_data_root(data_root): - self.image_paths.append(image_path) - - random.Random(seed).shuffle(self.image_paths) - self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p) + self.__prepare_train_data() (self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings() @@ -160,150 +146,28 @@ class DataLoaderMultiAspect(): return rating_overall_sum, ratings_summed - @staticmethod - def __read_caption_from_file(file_path, fallback_caption: ImageCaption) -> ImageCaption: - try: - with open(file_path, encoding='utf-8', mode='r') as caption_file: - caption_text = caption_file.read() - caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_text) - except: - logging.error(f" *** Error reading {file_path} to get caption, falling back to filename") - caption = fallback_caption - pass - return caption - - @staticmethod - def __read_caption_from_yaml(file_path: str, fallback_caption: ImageCaption) -> ImageCaption: - with open(file_path, "r") as stream: - try: - file_content = yaml.safe_load(stream) - main_prompt = file_content.get("main_prompt", "") - rating = file_content.get("rating", 1.0) - unparsed_tags = file_content.get("tags", []) - - max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH) - - tags = [] - tag_weights = [] - last_weight = None - weights_differ = False - for unparsed_tag in unparsed_tags: - tag = unparsed_tag.get("tag", "").strip() - if len(tag) == 0: - continue - - tags.append(tag) - tag_weight = unparsed_tag.get("weight", 1.0) - tag_weights.append(tag_weight) - - if last_weight is not None and weights_differ is False: - weights_differ = last_weight != tag_weight - - last_weight = tag_weight - - return ImageCaption(main_prompt, rating, tags, tag_weights, max_caption_length, weights_differ) - - except: - logging.error(f" *** Error reading {file_path} to get caption, falling back to filename") - return fallback_caption - - @staticmethod - def __split_caption_into_tags(caption_string: str) -> ImageCaption: - """ - Splits a 
string by "," into the main prompt and additional tags with equal weights - """ - split_caption = caption_string.split(",") - main_prompt = split_caption.pop(0).strip() - tags = [] - for tag in split_caption: - tags.append(tag.strip()) - - return ImageCaption(main_prompt, 1.0, tags, [1.0] * len(tags), DEFAULT_MAX_CAPTION_LENGTH, False) - - def __prescan_images(self, image_paths: list, flip_p=0.0) -> list[ImageTrainItem]: + def __prepare_train_data(self, flip_p=0.0) -> list[ImageTrainItem]: """ Create ImageTrainItem objects with metadata for hydration later """ - decorated_image_train_items = [] - - if not self.has_scanned: - undersized_images = [] - - multipliers = {} - skip_folders = [] - - for pathname in tqdm.tqdm(image_paths): - caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0] - caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_from_filename) - - file_path_without_ext = os.path.splitext(pathname)[0] - yaml_file_path = file_path_without_ext + ".yaml" - txt_file_path = file_path_without_ext + ".txt" - caption_file_path = file_path_without_ext + ".caption" - - current_dir = os.path.dirname(pathname) - - try: - if current_dir not in multipliers: - multiply_txt_path = os.path.join(current_dir, "multiply.txt") - #print(current_dir, multiply_txt_path) - if os.path.exists(multiply_txt_path): - with open(multiply_txt_path, 'r') as f: - val = float(f.read().strip()) - multipliers[current_dir] = val - logging.info(f" * DLMA multiply.txt in {current_dir} set to {val}") - else: - skip_folders.append(current_dir) - multipliers[current_dir] = 1.0 - except Exception as e: - logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}") - skip_folders.append(current_dir) - multipliers[current_dir] = 1.0 - - if os.path.exists(yaml_file_path): - caption = self.__read_caption_from_yaml(yaml_file_path, caption) - elif os.path.exists(txt_file_path): - caption = self.__read_caption_from_file(txt_file_path, caption) - elif os.path.exists(caption_file_path): - caption = self.__read_caption_from_file(caption_file_path, caption) - - try: - image = Image.open(pathname) - width, height = image.size - image_aspect = width / height - - target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect)) - if not self.has_scanned: - if width * height < target_wh[0] * target_wh[1]: - undersized_images.append(f" {pathname}, size: {width},{height}, target size: {target_wh}") - - image_train_item = ImageTrainItem(image=None, # image loaded at runtime to apply jitter - caption=caption, - target_wh=target_wh, - pathname=pathname, - flip_p=flip_p, - multiplier=multipliers[current_dir], - ) - - decorated_image_train_items.append(image_train_item) - - except Exception as e: - logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{pathname}{Fore.LIGHTRED_EX} to get metadata. 
File may be corrupt and will be skipped.{Style.RESET_ALL}") - logging.error(f" *** exception: {e}") - pass - if not self.has_scanned: self.has_scanned = True - if len(undersized_images) > 0: - underized_log_path = os.path.join(self.log_folder, "undersized_images.txt") - logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}") - logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}") - with open(underized_log_path, "w") as undersized_images_file: - undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:") - for undersized_image in undersized_images: - undersized_images_file.write(f"{undersized_image}\n") - - return decorated_image_train_items + self.prepared_train_data, events = resolver.resolve(self.data_root, self.aspects, flip_p=flip_p) + random.Random(self.seed).shuffle(self.prepared_train_data) + self.__report_undersized_images(events) + + def __report_undersized_images(self, events: list[resolver.Event]): + events = [event for event in events if isinstance(event, resolver.UndersizedImageEvent)] + + if len(events) > 0: + underized_log_path = os.path.join(self.log_folder, "undersized_images.txt") + logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}") + logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}") + with open(underized_log_path, "w") as undersized_images_file: + undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:") + for event in events: + message = f" *** {event.image_path} with size: {event.image_size} is smaller than target size: {event.target_size}, consider using larger images" + undersized_images_file.write(message) def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]: """ From 993eabf99ad7cd72963341e05d6d0ad67df37ac7 Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Sun, 22 Jan 2023 16:08:50 -0800 Subject: [PATCH 11/22] Add static methods on ImageCaption for deriving captions from various sources --- data/image_train_item.py | 152 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 146 insertions(+), 6 deletions(-) diff --git a/data/image_train_item.py b/data/image_train_item.py index 08bf736..9d1bdcb 100644 --- a/data/image_train_item.py +++ b/data/image_train_item.py @@ -18,6 +18,8 @@ import logging import math import os import random +import typing +import yaml import PIL import numpy as np @@ -25,6 +27,9 @@ from torchvision import transforms _RANDOM_TRIM = 0.04 +DEFAULT_MAX_CAPTION_LENGTH = 2048 + +OptionalImageCaption = typing.Optional['ImageCaption'] class ImageCaption: """ @@ -60,13 +65,15 @@ class ImageCaption: :param seed used to initialize the randomizer :return: generated caption string """ - max_target_tag_length = self.__max_target_length - len(self.__main_prompt) + if self.__tags: + max_target_tag_length = self.__max_target_length - len(self.__main_prompt) - if self.__use_weights: - tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length) - else: - tags_caption = self.__get_shuffled_tags(seed, self.__tags) + if self.__use_weights: + tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, 
max_target_tag_length) + else: + tags_caption = self.__get_shuffled_tags(seed, self.__tags) + return self.__main_prompt + ", " + tags_caption return self.__main_prompt + ", " + tags_caption def get_caption(self) -> str: @@ -91,7 +98,10 @@ class ImageCaption: weights_copy.pop(pos) tag = tags_copy.pop(pos) - caption += ", " + tag + + if caption: + caption += ", " + caption += tag return caption @@ -100,6 +110,136 @@ class ImageCaption: random.Random(seed).shuffle(tags) return ", ".join(tags) + @staticmethod + def parse(string: str) -> 'ImageCaption': + """ + Parses a string to get the caption. + + :param string: String to parse. + :return: `ImageCaption` object. + """ + split_caption = list(map(str.strip, string.split(","))) + main_prompt = split_caption[0] + tags = split_caption[1:] + tag_weights = [1.0] * len(tags) + + return ImageCaption(main_prompt, 1.0, tags, tag_weights, DEFAULT_MAX_CAPTION_LENGTH, False) + + @staticmethod + def from_file_name(file_path: str) -> 'ImageCaption': + """ + Parses the file name to get the caption. + + :param file_path: Path to the image file. + :return: `ImageCaption` object. + """ + (file_name, _) = os.path.splitext(os.path.basename(file_path)) + caption = file_name.split("_")[0] + return ImageCaption(caption, 1.0, [], [], DEFAULT_MAX_CAPTION_LENGTH, False) + + @staticmethod + def from_text_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption: + """ + Parses a text file to get the caption. Returns the default caption if + the file does not exist or is invalid. + + :param file_path: Path to the text file. + :param default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid. + :return: `ImageCaption` object or `None`. + """ + try: + with open(file_path, encoding='utf-8', mode='r') as caption_file: + caption_text = caption_file.read() + return ImageCaption.parse(caption_text) + except: + logging.error(f" *** Error reading {file_path} to get caption") + return default_caption + + @staticmethod + def from_yaml_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption: + """ + Parses a yaml file to get the caption. Returns the default caption if + the file does not exist or is invalid. + + :param file_path: path to the yaml file + :param default_caption: caption to return if the file does not exist or is invalid + :return: `ImageCaption` object or `None`. 
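As a rough illustration (the file name and values are invented), a sidecar YAML file of the shape this parser expects, together with the call that reads it:

    from data.image_train_item import ImageCaption

    # Contents of train/dog_001.yaml:
    #   main_prompt: a dog
    #   rating: 1.0
    #   max_caption_length: 2048
    #   tags:
    #     - tag: park
    #     - tag: golden hour
    #       weight: 1.5
    caption = ImageCaption.from_yaml_file("train/dog_001.yaml")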
+ """ + try: + with open(file_path, "r") as stream: + file_content = yaml.safe_load(stream) + main_prompt = file_content.get("main_prompt", "") + rating = file_content.get("rating", 1.0) + unparsed_tags = file_content.get("tags", []) + + max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH) + + tags = [] + tag_weights = [] + last_weight = None + weights_differ = False + for unparsed_tag in unparsed_tags: + tag = unparsed_tag.get("tag", "").strip() + if len(tag) == 0: + continue + + tags.append(tag) + tag_weight = unparsed_tag.get("weight", 1.0) + tag_weights.append(tag_weight) + + if last_weight is not None and weights_differ is False: + weights_differ = last_weight != tag_weight + + last_weight = tag_weight + + return ImageCaption(main_prompt, rating, tags, tag_weights, max_caption_length, weights_differ) + except: + logging.error(f" *** Error reading {file_path} to get caption") + return default_caption + + @staticmethod + def from_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption: + """ + Try to resolve a caption from a file path or return `default_caption`. + + :string: The path to the file to parse. + :default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid. + :return: `ImageCaption` object or `None`. + """ + if os.path.exists(file_path): + (file_path_without_ext, ext) = os.path.splitext(file_path) + match ext: + case ".yaml" | ".yml": + return ImageCaption.from_yaml_file(file_path, default_caption) + + case ".txt" | ".caption": + return ImageCaption.from_text_file(file_path, default_caption) + + case '.jpg'| '.jpeg'| '.png'| '.bmp'| '.webp'| '.jfif': + for ext in [".yaml", ".yml", ".txt", ".caption"]: + file_path = file_path_without_ext + ext + image_caption = ImageCaption.from_file(file_path) + if image_caption is not None: + return image_caption + return ImageCaption.from_file_name(file_path) + + case _: + return default_caption + else: + return default_caption + + @staticmethod + def resolve(string: str) -> 'ImageCaption': + """ + Try to resolve a caption from a string. If the string is a file path, + the caption will be read from the file, otherwise the string will be + parsed as a caption. + + :string: The string to resolve. + :return: `ImageCaption` object. 
+ """ + return ImageCaption.from_file(string, None) or ImageCaption.parse(string) + class ImageTrainItem: """ From c1a66317cda6000a5fe6303533ea418dbb2865aa Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Mon, 23 Jan 2023 11:16:34 -0800 Subject: [PATCH 12/22] Forgot to set prepared_train_data --- data/data_loader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/data/data_loader.py b/data/data_loader.py index 139ea3f..0df4e82 100644 --- a/data/data_loader.py +++ b/data/data_loader.py @@ -166,6 +166,7 @@ class DataLoaderMultiAspect(): print (f" * DLMA: {len(items)} images loaded from {len(image_paths)} files") + self.prepared_train_data = items random.Random(self.seed).shuffle(self.prepared_train_data) self.__report_undersized_images(events) From 646f3831888806f7fb0996d6ea866c8b90098486 Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Mon, 23 Jan 2023 11:27:12 -0800 Subject: [PATCH 13/22] If there are no tags just return the main prompt --- data/image_train_item.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data/image_train_item.py b/data/image_train_item.py index 6ce10fb..84f92de 100644 --- a/data/image_train_item.py +++ b/data/image_train_item.py @@ -74,7 +74,7 @@ class ImageCaption: tags_caption = self.__get_shuffled_tags(seed, self.__tags) return self.__main_prompt + ", " + tags_caption - return self.__main_prompt + ", " + tags_caption + return self.__main_prompt def get_caption(self) -> str: if self.__tags: From 1dfda8d6d4cadcc217a06c44e458169aa5f2954a Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Mon, 23 Jan 2023 11:28:11 -0800 Subject: [PATCH 14/22] Remove unused OptionalCallable alias --- data/resolver.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/data/resolver.py b/data/resolver.py index 7e031d6..90653cd 100644 --- a/data/resolver.py +++ b/data/resolver.py @@ -11,9 +11,6 @@ from colorama import Fore, Style from data.image_train_item import ImageCaption, ImageTrainItem - -OptionalCallable = typing.Optional[typing.Callable] - class Event: def __init__(self, name: str): self.name = name From 1a0b7994f4734679e968d5fb9fa703736ed3a738 Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Mon, 23 Jan 2023 12:00:42 -0800 Subject: [PATCH 15/22] Move target_wh calculation to ImageTrainItem --- data/data_loader.py | 19 ++++++++---- data/image_train_item.py | 20 +++++++++++-- data/resolver.py | 64 +++++++++------------------------------- 3 files changed, 45 insertions(+), 58 deletions(-) diff --git a/data/data_loader.py b/data/data_loader.py index 0df4e82..f910b6e 100644 --- a/data/data_loader.py +++ b/data/data_loader.py @@ -161,27 +161,34 @@ class DataLoaderMultiAspect(): logging.info(" Preloading images...") - items, events = resolver.resolve(self.data_root, self.aspects, flip_p=flip_p, seed=self.seed) + items = resolver.resolve(self.data_root, self.aspects, flip_p=flip_p, seed=self.seed) image_paths = set(map(lambda item: item.pathname, items)) print (f" * DLMA: {len(items)} images loaded from {len(image_paths)} files") self.prepared_train_data = items random.Random(self.seed).shuffle(self.prepared_train_data) - self.__report_undersized_images(events) + self.__report_errors(items) - def __report_undersized_images(self, events: list[resolver.Event]): - events = [event for event in events if isinstance(event, resolver.UndersizedImageEvent)] + def __report_errors(self, items: list[ImageTrainItem]): + for item in items: + if item.error is not None: + logging.error(f"{Fore.LIGHTRED_EX} *** Error opening 
{Fore.LIGHTYELLOW_EX}{item.image_path}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}") + logging.error(f" *** exception: {item.error}") + + undersized_items = [item for item in items if item.is_undersized] - if len(events) > 0: + if len(undersized_items) > 0: underized_log_path = os.path.join(self.log_folder, "undersized_images.txt") logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}") logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}") with open(underized_log_path, "w") as undersized_images_file: undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:") - for event in events: + for event in undersized_items: message = f" *** {event.image_path} with size: {event.image_size} is smaller than target size: {event.target_size}, consider using larger images" undersized_images_file.write(message) + + def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]: """ diff --git a/data/image_train_item.py b/data/image_train_item.py index 84f92de..44a03be 100644 --- a/data/image_train_item.py +++ b/data/image_train_item.py @@ -22,6 +22,7 @@ import typing import yaml import PIL +import PIL.Image as Image import numpy as np from torchvision import transforms @@ -256,9 +257,9 @@ class ImageTrainItem: flip_p: probability of flipping image (0.0 to 1.0) rating: the relative rating of the images. The rating is measured in comparison to the other images. """ - def __init__(self, image: PIL.Image, caption: ImageCaption, target_wh: list, pathname: str, flip_p=0.0, multiplier: float=1.0): + def __init__(self, image: PIL.Image, caption: ImageCaption, aspects: list[float], pathname: str, flip_p=0.0, multiplier: float=1.0): self.caption = caption - self.target_wh = target_wh + self.aspects = aspects self.pathname = pathname self.flip = transforms.RandomHorizontalFlip(p=flip_p) self.cropped_img = None @@ -269,6 +270,9 @@ class ImageTrainItem: self.image = [] else: self.image = image + + self.error = None + self.__compute_target_width_height() def hydrate(self, crop=False, save=False, crop_jitter=20): """ @@ -345,6 +349,18 @@ class ImageTrainItem: # print(self.image.shape) return self + + def __compute_target_width_height(self): + try: + with Image.open(self.image_path) as image: + width, height = image.size + image_aspect = width / height + target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect)) + + self.is_undersized = width * height < target_wh[0] * target_wh[1] + self.target_wh = target_wh + except Exception as e: + self.error = e @staticmethod def __autocrop(image: PIL.Image, q=.404): diff --git a/data/resolver.py b/data/resolver.py index 90653cd..185a0fb 100644 --- a/data/resolver.py +++ b/data/resolver.py @@ -11,22 +11,10 @@ from colorama import Fore, Style from data.image_train_item import ImageCaption, ImageTrainItem -class Event: - def __init__(self, name: str): - self.name = name - -class UndersizedImageEvent(Event): - def __init__(self, image_path: str, image_size: typing.Tuple[int, int], target_size: typing.Tuple[int, int]): - super().__init__('undersized_image') - self.image_path = image_path - self.image_size = image_size - self.target_size = target_size - class DataResolver: def __init__(self, aspects: list[typing.Tuple[int, int]], flip_p=0.0, seed=555): self.aspects = 
aspects self.flip_p = flip_p - self.events = [] def image_train_items(self, data_root: str) -> list[ImageTrainItem]: """ @@ -37,35 +25,15 @@ class DataResolver: """ raise NotImplementedError() - def compute_target_width_height(self, image_path: str) -> typing.Optional[typing.Tuple[int, int]]: - # Compute the target width and height for the image based on the aspect ratio. - with Image.open(image_path) as image: - width, height = image.size - image_aspect = width / height - target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect)) - - if width * height < target_wh[0] * target_wh[1]: - event = UndersizedImageEvent(image_path, (width, height), target_wh) - self.events.append(event) - - return target_wh - def image_train_item(self, image_path: str, caption: ImageCaption, multiplier: float=1) -> ImageTrainItem: - try: - target_wh = self.compute_target_width_height(image_path) - return ImageTrainItem( - image=None, - caption=caption, - target_wh=target_wh, - pathname=image_path, - flip_p=self.flip_p, - multiplier=multiplier - ) - # TODO: This should only handle Image errors. - except Exception as e: - logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{image_path}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}") - logging.error(f" *** exception: {e}") - + return ImageTrainItem( + image=None, + caption=caption, + aspects=self.aspects, + pathname=image_path, + flip_p=self.flip_p, + multiplier=multiplier + ) class JSONResolver(DataResolver): def image_train_items(self, json_path: str) -> list[ImageTrainItem]: @@ -219,7 +187,7 @@ def strategy(data_root: str): raise ValueError(f"data_root '{data_root}' is not a valid directory or JSON file.") -def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0, seed) -> typing.Tuple[list[ImageTrainItem], list[Event]]: +def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0, seed=555) -> list[ImageTrainItem]: """ :param data_root: Directory or JSON file. :param aspects: The list of aspect ratios to use @@ -235,10 +203,9 @@ def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0, seed) -> raise ValueError(f"data_root '{path}' is not a valid directory or JSON file.") items = resolver.image_train_items(path) - events = resolver.events - return items, events + return items -def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0, seed=555) -> typing.Tuple[list[ImageTrainItem], list[Event]]: +def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0, seed=555) -> list[ImageTrainItem]: """ Resolve the training data from the value. :param value: The value to resolve, either a dict or a string. 
@@ -255,12 +222,9 @@ def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float= path = value.get('path', None) return resolve_root(path, aspects, flip_p, seed) case 'multi': - resolved_items = [] - resolved_events = [] + items = [] for resolver in value.get('resolvers', []): - items, events = resolve(resolver, aspects, flip_p, seed) - resolved_items.extend(items) - resolved_events.extend(events) - return resolved_items, resolved_events + items.extend(resolve(resolver, aspects, flip_p, seed)) + return items case _: raise ValueError(f"Cannot resolve training data for resolver value '{resolver}'") \ No newline at end of file From d9081e8198ad17a341105016e78f54cc540b6fd4 Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Mon, 23 Jan 2023 12:11:42 -0800 Subject: [PATCH 16/22] Use tqdm in JSONResolver, add a docstring --- data/resolver.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/data/resolver.py b/data/resolver.py index 185a0fb..973c0b2 100644 --- a/data/resolver.py +++ b/data/resolver.py @@ -37,11 +37,17 @@ class DataResolver: class JSONResolver(DataResolver): def image_train_items(self, json_path: str) -> list[ImageTrainItem]: + """ + Create `ImageTrainItem` objects with metadata for hydration later. + Extracts images and captions from a JSON file. + + :param json_path: The path to the JSON file. + """ items = [] with open(json_path, encoding='utf-8', mode='r') as f: json_data = json.load(f) - for data in json_data: + for data in tqdm.tqdm(json_data): caption = JSONResolver.image_caption(data) if caption: image_value = JSONResolver.get_image_value(data) From 685bada57029237b5595a8d4e1fb21ff7398d420 Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Mon, 23 Jan 2023 12:15:35 -0800 Subject: [PATCH 17/22] Update tests --- test/test_data_resolver.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/test/test_data_resolver.py b/test/test_data_resolver.py index 299a7b6..095165b 100644 --- a/test/test_data_resolver.py +++ b/test/test_data_resolver.py @@ -51,7 +51,7 @@ class TestResolve(unittest.TestCase): os.remove(file) def test_directory_resolve_with_str(self): - items, events = resolver.resolve(DATA_PATH, ASPECTS) + items = resolver.resolve(DATA_PATH, ASPECTS) image_paths = [item.pathname for item in items] image_captions = [item.caption for item in items] captions = [caption.get_caption() for caption in image_captions] @@ -60,8 +60,8 @@ class TestResolve(unittest.TestCase): self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH]) self.assertEqual(captions, ['caption for test1', 'test2', 'test3']) - events = list(map(lambda e: e.name, events)) - self.assertEqual(events, ['undersized_image']) + undersized_images = list(filter(lambda i: i.is_undersized, items)) + self.assertEqual(undersized_images, 1) def test_directory_resolve_with_dict(self): data_root_spec = { @@ -69,7 +69,7 @@ class TestResolve(unittest.TestCase): 'path': DATA_PATH, } - items, events = resolver.resolve(data_root_spec, ASPECTS) + items = resolver.resolve(data_root_spec, ASPECTS) image_paths = [item.pathname for item in items] image_captions = [item.caption for item in items] captions = [caption.get_caption() for caption in image_captions] @@ -78,11 +78,11 @@ class TestResolve(unittest.TestCase): self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH]) self.assertEqual(captions, ['caption for test1', 'test2', 'test3']) - events = list(map(lambda e: e.name, events)) - self.assertEqual(events, 
['undersized_image']) + undersized_images = list(filter(lambda i: i.is_undersized, items)) + self.assertEqual(undersized_images, 1) def test_json_resolve_with_str(self): - items, events = resolver.resolve(JSON_ROOT_PATH, ASPECTS) + items = resolver.resolve(JSON_ROOT_PATH, ASPECTS) image_paths = [item.pathname for item in items] image_captions = [item.caption for item in items] captions = [caption.get_caption() for caption in image_captions] @@ -91,8 +91,8 @@ class TestResolve(unittest.TestCase): self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH]) self.assertEqual(captions, ['caption for test1', 'caption for test2', 'test3']) - events = list(map(lambda e: e.name, events)) - self.assertEqual(events, ['undersized_image']) + undersized_images = list(filter(lambda i: i.is_undersized, items)) + self.assertEqual(undersized_images, 1) def test_json_resolve_with_dict(self): data_root_spec = { @@ -100,7 +100,7 @@ class TestResolve(unittest.TestCase): 'path': JSON_ROOT_PATH, } - items, events = resolver.resolve(data_root_spec, ASPECTS) + items = resolver.resolve(data_root_spec, ASPECTS) image_paths = [item.pathname for item in items] image_captions = [item.caption for item in items] captions = [caption.get_caption() for caption in image_captions] @@ -109,5 +109,5 @@ class TestResolve(unittest.TestCase): self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH]) self.assertEqual(captions, ['caption for test1', 'caption for test2', 'test3']) - events = list(map(lambda e: e.name, events)) - self.assertEqual(events, ['undersized_image']) \ No newline at end of file + undersized_images = list(filter(lambda i: i.is_undersized, items)) + self.assertEqual(undersized_images, 1) \ No newline at end of file From 9491ae430ca637659ca38dd380d9fa498dd21bf3 Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Mon, 23 Jan 2023 12:44:48 -0800 Subject: [PATCH 18/22] Initialize is_undersized to False, fix bug in directory resolver --- data/image_train_item.py | 1 + data/resolver.py | 6 ++---- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/data/image_train_item.py b/data/image_train_item.py index 44a03be..fa13e3b 100644 --- a/data/image_train_item.py +++ b/data/image_train_item.py @@ -271,6 +271,7 @@ class ImageTrainItem: else: self.image = image + self.is_undersized = False self.error = None self.__compute_target_width_height() diff --git a/data/resolver.py b/data/resolver.py index 973c0b2..0662991 100644 --- a/data/resolver.py +++ b/data/resolver.py @@ -13,6 +13,7 @@ from data.image_train_item import ImageCaption, ImageTrainItem class DataResolver: def __init__(self, aspects: list[typing.Tuple[int, int]], flip_p=0.0, seed=555): + self.seed = seed self.aspects = aspects self.flip_p = flip_p @@ -146,9 +147,6 @@ class DirectoryResolver(DataResolver): if cur_file_multiplier > 0: if randomizer.random() < cur_file_multiplier: items.append(item) - - if item: - items.append(item) return items @staticmethod @@ -230,7 +228,7 @@ def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float= case 'multi': items = [] for resolver in value.get('resolvers', []): - items.extend(resolve(resolver, aspects, flip_p, seed)) + items += resolve(resolver, aspects, flip_p, seed) return items case _: raise ValueError(f"Cannot resolve training data for resolver value '{resolver}'") \ No newline at end of file From e99a948d3cbf4c21f334ecf1821ef686d345e51e Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Mon, 23 Jan 2023 12:52:53 -0800 Subject: [PATCH 19/22] Fix property 
name bug --- data/image_train_item.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/data/image_train_item.py b/data/image_train_item.py index fa13e3b..9e1565b 100644 --- a/data/image_train_item.py +++ b/data/image_train_item.py @@ -353,11 +353,11 @@ class ImageTrainItem: def __compute_target_width_height(self): try: - with Image.open(self.image_path) as image: + with Image.open(self.pathname) as image: width, height = image.size image_aspect = width / height target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect)) - + self.is_undersized = width * height < target_wh[0] * target_wh[1] self.target_wh = target_wh except Exception as e: From 620b157e6a39aa3c21f496c97c64525441f8576d Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Mon, 23 Jan 2023 16:57:02 -0800 Subject: [PATCH 20/22] Fix some name errors --- data/data_loader.py | 6 +++--- data/image_train_item.py | 3 +++ test/test_data_resolver.py | 10 +++++----- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/data/data_loader.py b/data/data_loader.py index f910b6e..8ea2d4b 100644 --- a/data/data_loader.py +++ b/data/data_loader.py @@ -173,7 +173,7 @@ class DataLoaderMultiAspect(): def __report_errors(self, items: list[ImageTrainItem]): for item in items: if item.error is not None: - logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.image_path}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}") + logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}") logging.error(f" *** exception: {item.error}") undersized_items = [item for item in items if item.is_undersized] @@ -184,8 +184,8 @@ class DataLoaderMultiAspect(): logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}") with open(underized_log_path, "w") as undersized_images_file: undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:") - for event in undersized_items: - message = f" *** {event.image_path} with size: {event.image_size} is smaller than target size: {event.target_size}, consider using larger images" + for undersized_item in undersized_items: + message = f" *** {undersized_item.pathname} with size: {undersized_item.image_size} is smaller than target size: {undersized_item.target_wh}, consider using larger images" undersized_images_file.write(message) diff --git a/data/image_train_item.py b/data/image_train_item.py index 9e1565b..d882678 100644 --- a/data/image_train_item.py +++ b/data/image_train_item.py @@ -266,10 +266,13 @@ class ImageTrainItem: self.runt_size = 0 self.multiplier = multiplier + self.image_size = None if image is None: self.image = [] else: self.image = image + self.image_size = image.size + self.target_size = None self.is_undersized = False self.error = None diff --git a/test/test_data_resolver.py b/test/test_data_resolver.py index 095165b..625f228 100644 --- a/test/test_data_resolver.py +++ b/test/test_data_resolver.py @@ -25,7 +25,7 @@ class TestResolve(unittest.TestCase): f.write('caption for test1') Image.new('RGB', (512, 512)).save(IMAGE_2_PATH) - # Undersized image. Should cause an event. 
+ # Undersized image Image.new('RGB', (256, 256)).save(IMAGE_3_PATH) json_data = [ @@ -61,7 +61,7 @@ class TestResolve(unittest.TestCase): self.assertEqual(captions, ['caption for test1', 'test2', 'test3']) undersized_images = list(filter(lambda i: i.is_undersized, items)) - self.assertEqual(undersized_images, 1) + self.assertEqual(len(undersized_images), 1) def test_directory_resolve_with_dict(self): data_root_spec = { @@ -79,7 +79,7 @@ class TestResolve(unittest.TestCase): self.assertEqual(captions, ['caption for test1', 'test2', 'test3']) undersized_images = list(filter(lambda i: i.is_undersized, items)) - self.assertEqual(undersized_images, 1) + self.assertEqual(len(undersized_images), 1) def test_json_resolve_with_str(self): items = resolver.resolve(JSON_ROOT_PATH, ASPECTS) @@ -92,7 +92,7 @@ class TestResolve(unittest.TestCase): self.assertEqual(captions, ['caption for test1', 'caption for test2', 'test3']) undersized_images = list(filter(lambda i: i.is_undersized, items)) - self.assertEqual(undersized_images, 1) + self.assertEqual(len(undersized_images), 1) def test_json_resolve_with_dict(self): data_root_spec = { @@ -110,4 +110,4 @@ class TestResolve(unittest.TestCase): self.assertEqual(captions, ['caption for test1', 'caption for test2', 'test3']) undersized_images = list(filter(lambda i: i.is_undersized, items)) - self.assertEqual(undersized_images, 1) \ No newline at end of file + self.assertEqual(len(undersized_images), 1) \ No newline at end of file From c106de827c0b6efe5ad52ea20d9eb581a4b1ba10 Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Mon, 23 Jan 2023 16:58:13 -0800 Subject: [PATCH 21/22] Fix name typo --- data/resolver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data/resolver.py b/data/resolver.py index 0662991..7973e60 100644 --- a/data/resolver.py +++ b/data/resolver.py @@ -179,7 +179,7 @@ class DirectoryResolver(DataResolver): sub_dirs.append(current) for dir in sub_dirs: - DirectoryResolver.__recurse_data_root(dir) + DirectoryResolver.recurse_data_root(dir) def strategy(data_root: str): if os.path.isfile(data_root) and data_root.endswith('.json'): From 7e2a7ae3874db44a050a60262c22c0b8a1e5566e Mon Sep 17 00:00:00 2001 From: Joel Holdbrooks Date: Mon, 23 Jan 2023 17:12:46 -0800 Subject: [PATCH 22/22] Need to yield from --- data/resolver.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/data/resolver.py b/data/resolver.py index 7973e60..94168a8 100644 --- a/data/resolver.py +++ b/data/resolver.py @@ -171,15 +171,11 @@ class DirectoryResolver(DataResolver): if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']: yield current - sub_dirs = [] - for d in os.listdir(recurse_root): current = os.path.join(recurse_root, d) if os.path.isdir(current): - sub_dirs.append(current) - - for dir in sub_dirs: - DirectoryResolver.recurse_data_root(dir) + yield from DirectoryResolver.recurse_data_root(current) + def strategy(data_root: str): if os.path.isfile(data_root) and data_root.endswith('.json'):
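
Taken together, patches 15 through 22 replace the resolver's separate event reporting with per-item metadata: resolve(...) now returns a flat list[ImageTrainItem], and each item computes and carries its own target_wh, is_undersized, and error. Below is a minimal usage sketch, not part of the patch series itself: the aspect buckets and data_root path are illustrative placeholders, and the data.resolver import assumes the repository layout exercised by the updated tests.

# Sketch only: exercises the reworked resolver API the same way the
# updated tests do. ASPECTS and data_root are hypothetical values.
import data.resolver as resolver

ASPECTS = [(512, 512), (576, 448), (448, 576)]   # assumed aspect buckets
data_root = "/path/to/training_data"             # directory or JSON manifest

items = resolver.resolve(data_root, ASPECTS, flip_p=0.0, seed=555)

# Error and undersized reporting now read attributes on each item
# instead of a separate list of Event objects.
unreadable = [item for item in items if item.error is not None]
undersized = [item for item in items if item.is_undersized]

print(f"{len(items)} items, {len(unreadable)} unreadable, {len(undersized)} undersized")
for item in undersized:
    print(f"  {item.pathname} is smaller than target size {item.target_wh}")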