# ---- data/data_loader.py (post-patch) ----
import os
import logging
import copy

import random

from data.image_train_item import ImageTrainItem
import data.aspects as aspects
import data.resolver as resolver
from colorama import Fore, Style
import PIL

PIL.Image.MAX_IMAGE_PIXELS = 715827880 * 4  # increase decompression bomb error limit to 4x default


class DataLoaderMultiAspect():
    """
    Data loader for multi-aspect-ratio training and bucketing

    data_root: root folder of training data
    batch_size: number of images per batch
    flip_p: probability of flipping image horizontally (i.e. 0-0.5)
    """
    def __init__(self, data_root, seed=555, debug_level=0, batch_size=1, flip_p=0.0, resolution=512, log_folder=None):
        self.data_root = data_root
        self.debug_level = debug_level
        self.flip_p = flip_p
        self.log_folder = log_folder
        self.seed = seed
        self.batch_size = batch_size
        self.has_scanned = False
        # FIX: the patch removed this assignment while the next line still logs
        # self.aspects and __prepare_train_data passes it to the resolver —
        # without it, construction raises AttributeError.
        self.aspects = aspects.get_aspect_buckets(resolution=resolution, square_only=False)
        logging.info(f"* DLMA resolution {resolution}, buckets: {self.aspects}")
        self.__prepare_train_data()
        (self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings()
logging.info(f"Unzipping {file}") - with zipfile.ZipFile(path, 'r') as zip_ref: - zip_ref.extractall(path) - except Exception as e: - logging.error(f"Error unzipping files {e}") - def __sort_and_precalc_image_ratings(self) -> tuple[float, list[float]]: self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating()) @@ -175,161 +151,44 @@ class DataLoaderMultiAspect(): return rating_overall_sum, ratings_summed - @staticmethod - def __read_caption_from_file(file_path, fallback_caption: ImageCaption) -> ImageCaption: - try: - with open(file_path, encoding='utf-8', mode='r') as caption_file: - caption_text = caption_file.read() - caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_text) - except: - logging.error(f" *** Error reading {file_path} to get caption, falling back to filename") - caption = fallback_caption - pass - return caption - - @staticmethod - def __read_caption_from_yaml(file_path: str, fallback_caption: ImageCaption) -> ImageCaption: - with open(file_path, "r") as stream: - try: - file_content = yaml.safe_load(stream) - main_prompt = file_content.get("main_prompt", "") - rating = file_content.get("rating", 1.0) - unparsed_tags = file_content.get("tags", []) - - max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH) - - tags = [] - tag_weights = [] - last_weight = None - weights_differ = False - for unparsed_tag in unparsed_tags: - tag = unparsed_tag.get("tag", "").strip() - if len(tag) == 0: - continue - - tags.append(tag) - tag_weight = unparsed_tag.get("weight", 1.0) - tag_weights.append(tag_weight) - - if last_weight is not None and weights_differ is False: - weights_differ = last_weight != tag_weight - - last_weight = tag_weight - - return ImageCaption(main_prompt, rating, tags, tag_weights, max_caption_length, weights_differ) - - except: - logging.error(f" *** Error reading {file_path} to get caption, falling back to filename") - return fallback_caption - - 
@staticmethod - def __split_caption_into_tags(caption_string: str) -> ImageCaption: - """ - Splits a string by "," into the main prompt and additional tags with equal weights - """ - split_caption = caption_string.split(",") - main_prompt = split_caption.pop(0).strip() - tags = [] - for tag in split_caption: - tags.append(tag.strip()) - - return ImageCaption(main_prompt, 1.0, tags, [1.0] * len(tags), DEFAULT_MAX_CAPTION_LENGTH, False) - - def __prescan_images(self, image_paths: list, flip_p=0.0) -> list[ImageTrainItem]: + def __prepare_train_data(self, flip_p=0.0) -> list[ImageTrainItem]: """ Create ImageTrainItem objects with metadata for hydration later """ - decorated_image_train_items = [] - - if not self.has_scanned: - undersized_images = [] - - multipliers = {} - skip_folders = [] - randomizer = random.Random(self.seed) - - for pathname in tqdm.tqdm(image_paths): - caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0] - caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_from_filename) - - file_path_without_ext = os.path.splitext(pathname)[0] - yaml_file_path = file_path_without_ext + ".yaml" - txt_file_path = file_path_without_ext + ".txt" - caption_file_path = file_path_without_ext + ".caption" - - current_dir = os.path.dirname(pathname) - - try: - if current_dir not in multipliers: - multiply_txt_path = os.path.join(current_dir, "multiply.txt") - #print(current_dir, multiply_txt_path) - if os.path.exists(multiply_txt_path): - with open(multiply_txt_path, 'r') as f: - val = float(f.read().strip()) - multipliers[current_dir] = val - logging.info(f" * DLMA multiply.txt in {current_dir} set to {val}") - else: - skip_folders.append(current_dir) - multipliers[current_dir] = 1.0 - except Exception as e: - logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}") - skip_folders.append(current_dir) - multipliers[current_dir] = 1.0 - - if 
os.path.exists(yaml_file_path): - caption = self.__read_caption_from_yaml(yaml_file_path, caption) - elif os.path.exists(txt_file_path): - caption = self.__read_caption_from_file(txt_file_path, caption) - elif os.path.exists(caption_file_path): - caption = self.__read_caption_from_file(caption_file_path, caption) - - try: - image = Image.open(pathname) - width, height = image.size - image_aspect = width / height - - target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect)) - if not self.has_scanned: - if width * height < target_wh[0] * target_wh[1]: - undersized_images.append(f" {pathname}, size: {width},{height}, target size: {target_wh}") - - image_train_item = ImageTrainItem(image=None, # image loaded at runtime to apply jitter - caption=caption, - target_wh=target_wh, - pathname=pathname, - flip_p=flip_p, - multiplier=multipliers[current_dir], - ) - - cur_file_multiplier = multipliers[current_dir] - - while cur_file_multiplier >= 1.0: - decorated_image_train_items.append(image_train_item) - cur_file_multiplier -= 1 - - if cur_file_multiplier > 0: - if randomizer.random() < cur_file_multiplier: - decorated_image_train_items.append(image_train_item) - - except Exception as e: - logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{pathname}{Fore.LIGHTRED_EX} to get metadata. 
File may be corrupt and will be skipped.{Style.RESET_ALL}") - logging.error(f" *** exception: {e}") - pass if not self.has_scanned: self.has_scanned = True - if len(undersized_images) > 0: - underized_log_path = os.path.join(self.log_folder, "undersized_images.txt") - logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}") - logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}") - with open(underized_log_path, "w") as undersized_images_file: - undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:") - for undersized_image in undersized_images: - undersized_images_file.write(f"{undersized_image}\n") - - print (f" * DLMA: {len(decorated_image_train_items)} images loaded from {len(image_paths)} files") + + logging.info(" Preloading images...") + + items = resolver.resolve(self.data_root, self.aspects, flip_p=flip_p, seed=self.seed) + image_paths = set(map(lambda item: item.pathname, items)) + + print (f" * DLMA: {len(items)} images loaded from {len(image_paths)} files") + + self.prepared_train_data = items + random.Random(self.seed).shuffle(self.prepared_train_data) + self.__report_errors(items) + + def __report_errors(self, items: list[ImageTrainItem]): + for item in items: + if item.error is not None: + logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.pathname}{Fore.LIGHTRED_EX} to get metadata. 
File may be corrupt and will be skipped.{Style.RESET_ALL}") + logging.error(f" *** exception: {item.error}") + + undersized_items = [item for item in items if item.is_undersized] + + if len(undersized_items) > 0: + underized_log_path = os.path.join(self.log_folder, "undersized_images.txt") + logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}") + logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}") + with open(underized_log_path, "w") as undersized_images_file: + undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:") + for undersized_item in undersized_items: + message = f" *** {undersized_item.pathname} with size: {undersized_item.image_size} is smaller than target size: {undersized_item.target_wh}, consider using larger images" + undersized_images_file.write(message) + - return decorated_image_train_items def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]: """ @@ -367,23 +226,3 @@ class DataLoaderMultiAspect(): prepared_train_data.pop(pos) return picked_images - - @staticmethod - def __recurse_data_root(self, recurse_root): - for f in os.listdir(recurse_root): - current = os.path.join(recurse_root, f) - - if os.path.isfile(current): - ext = os.path.splitext(f)[1].lower() - if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']: - self.image_paths.append(current) - - sub_dirs = [] - - for d in os.listdir(recurse_root): - current = os.path.join(recurse_root, d) - if os.path.isdir(current): - sub_dirs.append(current) - - for dir in sub_dirs: - self.__recurse_data_root(self=self, recurse_root=dir) diff --git a/data/image_train_item.py b/data/image_train_item.py index 08bf736..d882678 100644 --- a/data/image_train_item.py +++ b/data/image_train_item.py @@ -18,13 +18,19 @@ import logging import math 
import os import random +import typing +import yaml import PIL +import PIL.Image as Image import numpy as np from torchvision import transforms _RANDOM_TRIM = 0.04 +DEFAULT_MAX_CAPTION_LENGTH = 2048 + +OptionalImageCaption = typing.Optional['ImageCaption'] class ImageCaption: """ @@ -60,17 +66,21 @@ class ImageCaption: :param seed used to initialize the randomizer :return: generated caption string """ - max_target_tag_length = self.__max_target_length - len(self.__main_prompt) + if self.__tags: + max_target_tag_length = self.__max_target_length - len(self.__main_prompt) - if self.__use_weights: - tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length) - else: - tags_caption = self.__get_shuffled_tags(seed, self.__tags) + if self.__use_weights: + tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length) + else: + tags_caption = self.__get_shuffled_tags(seed, self.__tags) - return self.__main_prompt + ", " + tags_caption + return self.__main_prompt + ", " + tags_caption + return self.__main_prompt def get_caption(self) -> str: - return self.__main_prompt + ", " + ", ".join(self.__tags) + if self.__tags: + return self.__main_prompt + ", " + ", ".join(self.__tags) + return self.__main_prompt @staticmethod def __get_weighted_shuffled_tags(seed: int, tags: list[str], weights: list[float], max_target_tag_length: int) -> str: @@ -91,7 +101,14 @@ class ImageCaption: weights_copy.pop(pos) tag = tags_copy.pop(pos) - caption += ", " + tag + + if caption: + caption += ", " + caption += tag + + if caption: + caption += ", " + caption += tag return caption @@ -100,6 +117,136 @@ class ImageCaption: random.Random(seed).shuffle(tags) return ", ".join(tags) + @staticmethod + def parse(string: str) -> 'ImageCaption': + """ + Parses a string to get the caption. + + :param string: String to parse. + :return: `ImageCaption` object. 
+ """ + split_caption = list(map(str.strip, string.split(","))) + main_prompt = split_caption[0] + tags = split_caption[1:] + tag_weights = [1.0] * len(tags) + + return ImageCaption(main_prompt, 1.0, tags, tag_weights, DEFAULT_MAX_CAPTION_LENGTH, False) + + @staticmethod + def from_file_name(file_path: str) -> 'ImageCaption': + """ + Parses the file name to get the caption. + + :param file_path: Path to the image file. + :return: `ImageCaption` object. + """ + (file_name, _) = os.path.splitext(os.path.basename(file_path)) + caption = file_name.split("_")[0] + return ImageCaption.parse(caption) + + @staticmethod + def from_text_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption: + """ + Parses a text file to get the caption. Returns the default caption if + the file does not exist or is invalid. + + :param file_path: Path to the text file. + :param default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid. + :return: `ImageCaption` object or `None`. + """ + try: + with open(file_path, encoding='utf-8', mode='r') as caption_file: + caption_text = caption_file.read() + return ImageCaption.parse(caption_text) + except: + logging.error(f" *** Error reading {file_path} to get caption") + return default_caption + + @staticmethod + def from_yaml_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption: + """ + Parses a yaml file to get the caption. Returns the default caption if + the file does not exist or is invalid. + + :param file_path: path to the yaml file + :param default_caption: caption to return if the file does not exist or is invalid + :return: `ImageCaption` object or `None`. 
+ """ + try: + with open(file_path, "r") as stream: + file_content = yaml.safe_load(stream) + main_prompt = file_content.get("main_prompt", "") + rating = file_content.get("rating", 1.0) + unparsed_tags = file_content.get("tags", []) + + max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH) + + tags = [] + tag_weights = [] + last_weight = None + weights_differ = False + for unparsed_tag in unparsed_tags: + tag = unparsed_tag.get("tag", "").strip() + if len(tag) == 0: + continue + + tags.append(tag) + tag_weight = unparsed_tag.get("weight", 1.0) + tag_weights.append(tag_weight) + + if last_weight is not None and weights_differ is False: + weights_differ = last_weight != tag_weight + + last_weight = tag_weight + + return ImageCaption(main_prompt, rating, tags, tag_weights, max_caption_length, weights_differ) + except: + logging.error(f" *** Error reading {file_path} to get caption") + return default_caption + + @staticmethod + def from_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption: + """ + Try to resolve a caption from a file path or return `default_caption`. + + :string: The path to the file to parse. + :default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid. + :return: `ImageCaption` object or `None`. 
+ """ + if os.path.exists(file_path): + (file_path_without_ext, ext) = os.path.splitext(file_path) + match ext: + case ".yaml" | ".yml": + return ImageCaption.from_yaml_file(file_path, default_caption) + + case ".txt" | ".caption": + return ImageCaption.from_text_file(file_path, default_caption) + + case '.jpg'| '.jpeg'| '.png'| '.bmp'| '.webp'| '.jfif': + for ext in [".yaml", ".yml", ".txt", ".caption"]: + file_path = file_path_without_ext + ext + image_caption = ImageCaption.from_file(file_path) + if image_caption is not None: + return image_caption + return ImageCaption.from_file_name(file_path) + + case _: + return default_caption + else: + return default_caption + + @staticmethod + def resolve(string: str) -> 'ImageCaption': + """ + Try to resolve a caption from a string. If the string is a file path, + the caption will be read from the file, otherwise the string will be + parsed as a caption. + + :string: The string to resolve. + :return: `ImageCaption` object. + """ + return ImageCaption.from_file(string, None) or ImageCaption.parse(string) + class ImageTrainItem: """ @@ -110,19 +257,26 @@ class ImageTrainItem: flip_p: probability of flipping image (0.0 to 1.0) rating: the relative rating of the images. The rating is measured in comparison to the other images. 
""" - def __init__(self, image: PIL.Image, caption: ImageCaption, target_wh: list, pathname: str, flip_p=0.0, multiplier: float=1.0): + def __init__(self, image: PIL.Image, caption: ImageCaption, aspects: list[float], pathname: str, flip_p=0.0, multiplier: float=1.0): self.caption = caption - self.target_wh = target_wh + self.aspects = aspects self.pathname = pathname self.flip = transforms.RandomHorizontalFlip(p=flip_p) self.cropped_img = None self.runt_size = 0 self.multiplier = multiplier + self.image_size = None if image is None: self.image = [] else: self.image = image + self.image_size = image.size + self.target_size = None + + self.is_undersized = False + self.error = None + self.__compute_target_width_height() def hydrate(self, crop=False, save=False, crop_jitter=20): """ @@ -199,6 +353,18 @@ class ImageTrainItem: # print(self.image.shape) return self + + def __compute_target_width_height(self): + try: + with Image.open(self.pathname) as image: + width, height = image.size + image_aspect = width / height + target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect)) + + self.is_undersized = width * height < target_wh[0] * target_wh[1] + self.target_wh = target_wh + except Exception as e: + self.error = e @staticmethod def __autocrop(image: PIL.Image, q=.404): @@ -229,4 +395,4 @@ class ImageTrainItem: min_xy = min(x, y) image = image.crop((x_crop, y_crop, x_crop + min_xy, y_crop + min_xy)) - return image + return image \ No newline at end of file diff --git a/data/resolver.py b/data/resolver.py new file mode 100644 index 0000000..94168a8 --- /dev/null +++ b/data/resolver.py @@ -0,0 +1,230 @@ +import json +import logging +import os +import random +import typing +import zipfile + +import PIL.Image as Image +import tqdm +from colorama import Fore, Style + +from data.image_train_item import ImageCaption, ImageTrainItem + +class DataResolver: + def __init__(self, aspects: list[typing.Tuple[int, int]], flip_p=0.0, seed=555): + 
self.seed = seed + self.aspects = aspects + self.flip_p = flip_p + + def image_train_items(self, data_root: str) -> list[ImageTrainItem]: + """ + Get the list of `ImageTrainItem` for the given data root. + + :param data_root: The data root, a directory, a file, etc.. + :return: The list of `ImageTrainItem`. + """ + raise NotImplementedError() + + def image_train_item(self, image_path: str, caption: ImageCaption, multiplier: float=1) -> ImageTrainItem: + return ImageTrainItem( + image=None, + caption=caption, + aspects=self.aspects, + pathname=image_path, + flip_p=self.flip_p, + multiplier=multiplier + ) + +class JSONResolver(DataResolver): + def image_train_items(self, json_path: str) -> list[ImageTrainItem]: + """ + Create `ImageTrainItem` objects with metadata for hydration later. + Extracts images and captions from a JSON file. + + :param json_path: The path to the JSON file. + """ + items = [] + with open(json_path, encoding='utf-8', mode='r') as f: + json_data = json.load(f) + + for data in tqdm.tqdm(json_data): + caption = JSONResolver.image_caption(data) + if caption: + image_value = JSONResolver.get_image_value(data) + item = self.image_train_item(image_value, caption) + if item: + items.append(item) + + return items + + @staticmethod + def get_image_value(json_data: dict) -> typing.Optional[str]: + """ + Get the image from the json data if possible. + + :param json_data: The json data, a dict. + :return: The image, or None if not found. + """ + image_value = json_data.get("image", None) + if isinstance(image_value, str): + image_value = image_value.strip() + if os.path.exists(image_value): + return image_value + + @staticmethod + def get_caption_value(json_data: dict) -> typing.Optional[str]: + """ + Get the caption from the json data if possible. + + :param json_data: The json data, a dict. + :return: The caption, or None if not found. 
+ """ + caption_value = json_data.get("caption", None) + if isinstance(caption_value, str): + return caption_value.strip() + + @staticmethod + def image_caption(json_data: dict) -> typing.Optional[ImageCaption]: + """ + Get the caption from the json data if possible. + + :param json_data: The json data, a dict. + :return: The `ImageCaption`, or None if not found. + """ + image_value = JSONResolver.get_image_value(json_data) + caption_value = JSONResolver.get_caption_value(json_data) + if image_value: + if caption_value: + return ImageCaption.resolve(caption_value) + return ImageCaption.from_file(image_value) + + +class DirectoryResolver(DataResolver): + def image_train_items(self, data_root: str) -> list[ImageTrainItem]: + """ + Create `ImageTrainItem` objects with metadata for hydration later. + Unzips all zip files in `data_root` and then recursively searches the + `data_root` for images and captions. + + :param data_root: The root directory to recurse through + """ + DirectoryResolver.unzip_all(data_root) + image_paths = list(DirectoryResolver.recurse_data_root(data_root)) + items = [] + multipliers = {} + skip_folders = [] + randomizer = random.Random(self.seed) + + for pathname in tqdm.tqdm(image_paths): + current_dir = os.path.dirname(pathname) + + if current_dir not in multipliers: + multiply_txt_path = os.path.join(current_dir, "multiply.txt") + if os.path.exists(multiply_txt_path): + try: + with open(multiply_txt_path, 'r') as f: + val = float(f.read().strip()) + multipliers[current_dir] = val + logging.info(f" * DLMA multiply.txt in {current_dir} set to {val}") + except Exception as e: + logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}") + skip_folders.append(current_dir) + multipliers[current_dir] = 1.0 + else: + skip_folders.append(current_dir) + multipliers[current_dir] = 1.0 + + caption = ImageCaption.resolve(pathname) + item = self.image_train_item(pathname, caption, 
multiplier=multipliers[current_dir]) + + cur_file_multiplier = multipliers[current_dir] + + while cur_file_multiplier >= 1.0: + items.append(item) + cur_file_multiplier -= 1 + + if cur_file_multiplier > 0: + if randomizer.random() < cur_file_multiplier: + items.append(item) + return items + + @staticmethod + def unzip_all(path): + try: + for root, dirs, files in os.walk(path): + for file in files: + if file.endswith('.zip'): + logging.info(f"Unzipping {file}") + with zipfile.ZipFile(path, 'r') as zip_ref: + zip_ref.extractall(path) + except Exception as e: + logging.error(f"Error unzipping files {e}") + + @staticmethod + def recurse_data_root(recurse_root): + for f in os.listdir(recurse_root): + current = os.path.join(recurse_root, f) + + if os.path.isfile(current): + ext = os.path.splitext(f)[1].lower() + if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']: + yield current + + for d in os.listdir(recurse_root): + current = os.path.join(recurse_root, d) + if os.path.isdir(current): + yield from DirectoryResolver.recurse_data_root(current) + + +def strategy(data_root: str): + if os.path.isfile(data_root) and data_root.endswith('.json'): + return JSONResolver + + if os.path.isdir(data_root): + return DirectoryResolver + + raise ValueError(f"data_root '{data_root}' is not a valid directory or JSON file.") + + +def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0, seed=555) -> list[ImageTrainItem]: + """ + :param data_root: Directory or JSON file. 
+ :param aspects: The list of aspect ratios to use + :param flip_p: The probability of flipping the image + """ + if os.path.isfile(path) and path.endswith('.json'): + resolver = JSONResolver(aspects, flip_p, seed) + + if os.path.isdir(path): + resolver = DirectoryResolver(aspects, flip_p, seed) + + if not resolver: + raise ValueError(f"data_root '{path}' is not a valid directory or JSON file.") + + items = resolver.image_train_items(path) + return items + +def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0, seed=555) -> list[ImageTrainItem]: + """ + Resolve the training data from the value. + :param value: The value to resolve, either a dict or a string. + :param aspects: The list of aspect ratios to use + :param flip_p: The probability of flipping the image + """ + if isinstance(value, str): + return resolve_root(value, aspects, flip_p) + + if isinstance(value, dict): + resolver = value.get('resolver', None) + match resolver: + case 'directory' | 'json': + path = value.get('path', None) + return resolve_root(path, aspects, flip_p, seed) + case 'multi': + items = [] + for resolver in value.get('resolvers', []): + items += resolve(resolver, aspects, flip_p, seed) + return items + case _: + raise ValueError(f"Cannot resolve training data for resolver value '{resolver}'") \ No newline at end of file diff --git a/test/data/.gitkeep b/test/data/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/test/test_data_resolver.py b/test/test_data_resolver.py new file mode 100644 index 0000000..625f228 --- /dev/null +++ b/test/test_data_resolver.py @@ -0,0 +1,113 @@ +import json +import glob +import os +import unittest + +import PIL.Image as Image + +import data.aspects as aspects +import data.resolver as resolver + +DATA_PATH = os.path.abspath('./test/data') +JSON_ROOT_PATH = os.path.join(DATA_PATH, 'test_root.json') +ASPECTS = aspects.get_aspect_buckets(512) + +IMAGE_1_PATH = os.path.join(DATA_PATH, 'test1.jpg') +CAPTION_1_PATH = 
os.path.join(DATA_PATH, 'test1.txt')
IMAGE_2_PATH = os.path.join(DATA_PATH, 'test2.jpg')
IMAGE_3_PATH = os.path.join(DATA_PATH, 'test3.jpg')


class TestResolve(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        Image.new('RGB', (512, 512)).save(IMAGE_1_PATH)
        with open(CAPTION_1_PATH, 'w') as f:
            f.write('caption for test1')

        Image.new('RGB', (512, 512)).save(IMAGE_2_PATH)
        # Undersized image
        Image.new('RGB', (256, 256)).save(IMAGE_3_PATH)

        json_data = [
            {
                'image': IMAGE_1_PATH,
                'caption': CAPTION_1_PATH
            },
            {
                'image': IMAGE_2_PATH,
                'caption': 'caption for test2'
            },
            {
                'image': IMAGE_3_PATH,
            }
        ]

        with open(JSON_ROOT_PATH, 'w') as f:
            json.dump(json_data, f, indent=4)

    @classmethod
    def tearDownClass(cls):
        for file in glob.glob(os.path.join(DATA_PATH, 'test*')):
            os.remove(file)

    def test_directory_resolve_with_str(self):
        items = resolver.resolve(DATA_PATH, ASPECTS)
        # FIX: directory traversal order follows os.listdir, which is
        # arbitrary; sort before comparing to keep the test deterministic.
        items.sort(key=lambda item: item.pathname)
        image_paths = [item.pathname for item in items]
        image_captions = [item.caption for item in items]
        captions = [caption.get_caption() for caption in image_captions]

        self.assertEqual(len(items), 3)
        self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
        self.assertEqual(captions, ['caption for test1', 'test2', 'test3'])

        undersized_images = list(filter(lambda i: i.is_undersized, items))
        self.assertEqual(len(undersized_images), 1)

    def test_directory_resolve_with_dict(self):
        data_root_spec = {
            'resolver': 'directory',
            'path': DATA_PATH,
        }

        items = resolver.resolve(data_root_spec, ASPECTS)
        # FIX: sort for determinism (see test_directory_resolve_with_str).
        items.sort(key=lambda item: item.pathname)
        image_paths = [item.pathname for item in items]
        image_captions = [item.caption for item in items]
        captions = [caption.get_caption() for caption in image_captions]

        self.assertEqual(len(items), 3)
        self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
        self.assertEqual(captions, ['caption for test1', 'test2', 'test3'])

        undersized_images = list(filter(lambda i: i.is_undersized, items))
        self.assertEqual(len(undersized_images), 1)

    def test_json_resolve_with_str(self):
        # JSON entries preserve order, so ordered comparison is safe here.
        items = resolver.resolve(JSON_ROOT_PATH, ASPECTS)
        image_paths = [item.pathname for item in items]
        image_captions = [item.caption for item in items]
        captions = [caption.get_caption() for caption in image_captions]

        self.assertEqual(len(items), 3)
        self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
        self.assertEqual(captions, ['caption for test1', 'caption for test2', 'test3'])

        undersized_images = list(filter(lambda i: i.is_undersized, items))
        self.assertEqual(len(undersized_images), 1)

    def test_json_resolve_with_dict(self):
        data_root_spec = {
            'resolver': 'json',
            'path': JSON_ROOT_PATH,
        }

        items = resolver.resolve(data_root_spec, ASPECTS)
        image_paths = [item.pathname for item in items]
        image_captions = [item.caption for item in items]
        captions = [caption.get_caption() for caption in image_captions]

        self.assertEqual(len(items), 3)
        self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
        self.assertEqual(captions, ['caption for test1', 'caption for test2', 'test3'])

        undersized_images = list(filter(lambda i: i.is_undersized, items))
        self.assertEqual(len(undersized_images), 1)


# ---- test/test_image_train_item.py (new file in patch) ----
import unittest
import os
import pathlib
import PIL.Image as Image

from data.image_train_item import ImageCaption, ImageTrainItem

DATA_PATH = pathlib.Path('./test/data')


class TestImageCaption(unittest.TestCase):

    def setUp(self) -> None:
        with open(DATA_PATH / "test1.txt", encoding='utf-8', mode='w') as f:
            f.write("caption for test1")

        Image.new("RGB", (512, 512)).save(DATA_PATH / "test1.jpg")
        Image.new("RGB", (512, 512)).save(DATA_PATH / "test2.jpg")

        with open(DATA_PATH / "test_caption.caption", encoding='utf-8', mode='w') as f:
            f.write("caption for test2")

        return super().setUp()

    def tearDown(self) -> None:
        for file in DATA_PATH.glob("test*"):
            file.unlink()

        return super().tearDown()

    def test_constructor(self):
        caption = ImageCaption("hello world", 1.0, ["one", "two", "three"], [1.0]*3, 2048, False)
        self.assertEqual(caption.get_caption(), "hello world, one, two, three")

        caption = ImageCaption("hello world", 1.0, [], [], 2048, False)
        self.assertEqual(caption.get_caption(), "hello world")

    def test_parse(self):
        caption = ImageCaption.parse("hello world, one, two, three")

        self.assertEqual(caption.get_caption(), "hello world, one, two, three")

    def test_from_file_name(self):
        caption = ImageCaption.from_file_name("foo bar_1_2_3.jpg")
        self.assertEqual(caption.get_caption(), "foo bar")

    def test_from_text_file(self):
        caption = ImageCaption.from_text_file("test/data/test1.txt")
        self.assertEqual(caption.get_caption(), "caption for test1")

    def test_from_file(self):
        caption = ImageCaption.from_file("test/data/test1.txt")
        self.assertEqual(caption.get_caption(), "caption for test1")

        caption = ImageCaption.from_file("test/data/test_caption.caption")
        self.assertEqual(caption.get_caption(), "caption for test2")

    def test_resolve(self):
        caption = ImageCaption.resolve("test/data/test1.txt")
        self.assertEqual(caption.get_caption(), "caption for test1")

        caption = ImageCaption.resolve("test/data/test_caption.caption")
        self.assertEqual(caption.get_caption(), "caption for test2")

        caption = ImageCaption.resolve("hello world")
        self.assertEqual(caption.get_caption(), "hello world")

        caption = ImageCaption.resolve("test/data/test1.jpg")
        self.assertEqual(caption.get_caption(), "caption for test1")

        caption = ImageCaption.resolve("test/data/test2.jpg")
        self.assertEqual(caption.get_caption(), "test2")