diff --git a/data/data_loader.py b/data/data_loader.py
index ae276bd..0c0fb3c 100644
--- a/data/data_loader.py
+++ b/data/data_loader.py
@@ -18,22 +18,15 @@
 import math
 import os
 import logging
-import yaml
-from PIL import Image
 import random
-from data.image_train_item import ImageTrainItem, ImageCaption
+from data.image_train_item import ImageTrainItem
 import data.aspects as aspects
 import data.resolver as resolver
-from data.resolver import DirectoryResolver
 from colorama import Fore, Style
-import zipfile
-import tqdm
 import PIL
 
 PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
 
-DEFAULT_MAX_CAPTION_LENGTH = 2048
-
 class DataLoaderMultiAspect():
     """
     Data loader for multi-aspect-ratio training and bucketing
@@ -43,25 +36,18 @@ class DataLoaderMultiAspect():
     flip_p: probability of flipping image horizontally (i.e. 0-0.5)
     """
     def __init__(self, data_root, seed=555, debug_level=0, batch_size=1, flip_p=0.0, resolution=512, log_folder=None):
-        self.image_paths = []
+        self.data_root = data_root
         self.debug_level = debug_level
         self.flip_p = flip_p
        self.log_folder = log_folder
         self.seed = seed
         self.batch_size = batch_size
         self.has_scanned = False
-        self.aspects = aspects.get_aspect_buckets(resolution=resolution, square_only=False)
+
+        self.aspects = aspects.get_aspect_buckets(resolution=resolution, square_only=False)
         logging.info(f"* DLMA resolution {resolution}, buckets: {self.aspects}")
         logging.info(" Preloading images...")
-
-        DirectoryResolver.unzip_all(data_root)
-
-        for image_path in DirectoryResolver.recurse_data_root(data_root):
-            self.image_paths.append(image_path)
-
-        random.Random(seed).shuffle(self.image_paths)
-        self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p)
+        self.__prepare_train_data(flip_p=flip_p)
 
         (self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings()
@@ -160,150 +146,28 @@ class DataLoaderMultiAspect():
 
         return rating_overall_sum, ratings_summed
 
-    @staticmethod
-    def __read_caption_from_file(file_path, fallback_caption: ImageCaption) -> ImageCaption:
-        try:
-            with open(file_path, encoding='utf-8', mode='r') as caption_file:
-                caption_text = caption_file.read()
-                caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_text)
-        except:
-            logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
-            caption = fallback_caption
-            pass
-        return caption
-
-    @staticmethod
-    def __read_caption_from_yaml(file_path: str, fallback_caption: ImageCaption) -> ImageCaption:
-        with open(file_path, "r") as stream:
-            try:
-                file_content = yaml.safe_load(stream)
-                main_prompt = file_content.get("main_prompt", "")
-                rating = file_content.get("rating", 1.0)
-                unparsed_tags = file_content.get("tags", [])
-
-                max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH)
-
-                tags = []
-                tag_weights = []
-                last_weight = None
-                weights_differ = False
-                for unparsed_tag in unparsed_tags:
-                    tag = unparsed_tag.get("tag", "").strip()
-                    if len(tag) == 0:
-                        continue
-
-                    tags.append(tag)
-                    tag_weight = unparsed_tag.get("weight", 1.0)
-                    tag_weights.append(tag_weight)
-
-                    if last_weight is not None and weights_differ is False:
-                        weights_differ = last_weight != tag_weight
-
-                    last_weight = tag_weight
-
-                return ImageCaption(main_prompt, rating, tags, tag_weights, max_caption_length, weights_differ)
-
-            except:
-                logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
-                return fallback_caption
-
-    @staticmethod
-    def __split_caption_into_tags(caption_string: str) -> ImageCaption:
-        """
-        Splits a string by "," into the main prompt and additional tags with equal weights
-        """
-        split_caption = caption_string.split(",")
-        main_prompt = split_caption.pop(0).strip()
-        tags = []
-        for tag in split_caption:
-            tags.append(tag.strip())
-
-        return ImageCaption(main_prompt, 1.0, tags, [1.0] * len(tags), DEFAULT_MAX_CAPTION_LENGTH, False)
-
-    def __prescan_images(self, image_paths: list, flip_p=0.0) -> list[ImageTrainItem]:
+    def __prepare_train_data(self, flip_p=0.0):
         """
         Create ImageTrainItem objects with metadata for hydration later
         """
-        decorated_image_train_items = []
-
-        if not self.has_scanned:
-            undersized_images = []
-
-        multipliers = {}
-        skip_folders = []
-
-        for pathname in tqdm.tqdm(image_paths):
-            caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
-            caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_from_filename)
-
-            file_path_without_ext = os.path.splitext(pathname)[0]
-            yaml_file_path = file_path_without_ext + ".yaml"
-            txt_file_path = file_path_without_ext + ".txt"
-            caption_file_path = file_path_without_ext + ".caption"
-
-            current_dir = os.path.dirname(pathname)
-
-            try:
-                if current_dir not in multipliers:
-                    multiply_txt_path = os.path.join(current_dir, "multiply.txt")
-                    #print(current_dir, multiply_txt_path)
-                    if os.path.exists(multiply_txt_path):
-                        with open(multiply_txt_path, 'r') as f:
-                            val = float(f.read().strip())
-                            multipliers[current_dir] = val
-                            logging.info(f" * DLMA multiply.txt in {current_dir} set to {val}")
-                    else:
-                        skip_folders.append(current_dir)
-                        multipliers[current_dir] = 1.0
-            except Exception as e:
-                logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
-                skip_folders.append(current_dir)
-                multipliers[current_dir] = 1.0
-
-            if os.path.exists(yaml_file_path):
-                caption = self.__read_caption_from_yaml(yaml_file_path, caption)
-            elif os.path.exists(txt_file_path):
-                caption = self.__read_caption_from_file(txt_file_path, caption)
-            elif os.path.exists(caption_file_path):
-                caption = self.__read_caption_from_file(caption_file_path, caption)
-
-            try:
-                image = Image.open(pathname)
-                width, height = image.size
-                image_aspect = width / height
-
-                target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
-                if not self.has_scanned:
-                    if width * height < target_wh[0] * target_wh[1]:
-                        undersized_images.append(f" {pathname}, size: {width},{height}, target size: {target_wh}")
-
-                image_train_item = ImageTrainItem(image=None, # image loaded at runtime to apply jitter
-                                                  caption=caption,
-                                                  target_wh=target_wh,
-                                                  pathname=pathname,
-                                                  flip_p=flip_p,
-                                                  multiplier=multipliers[current_dir],
-                                                  )
-
-                decorated_image_train_items.append(image_train_item)
-
-            except Exception as e:
-                logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
-                logging.error(f"    *** exception: {e}")
-                pass
-
         if not self.has_scanned:
             self.has_scanned = True
-            if len(undersized_images) > 0:
-                underized_log_path = os.path.join(self.log_folder, "undersized_images.txt")
-                logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}")
-                logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}")
-                with open(underized_log_path, "w") as undersized_images_file:
-                    undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:")
-                    for undersized_image in undersized_images:
-                        undersized_images_file.write(f"{undersized_image}\n")
-
-        return decorated_image_train_items
+        self.prepared_train_data, events = resolver.resolve(self.data_root, self.aspects, flip_p=flip_p)
+        random.Random(self.seed).shuffle(self.prepared_train_data)
+        self.__report_undersized_images(events)
+
+    def __report_undersized_images(self, events: list[resolver.Event]):
+        events = [event for event in events if isinstance(event, resolver.UndersizedImageEvent)]
+
+        if len(events) > 0:
+            undersized_log_path = os.path.join(self.log_folder, "undersized_images.txt")
+            logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}")
+            logging.warning(f"{Fore.LIGHTRED_EX} ** Check {undersized_log_path} for more information.{Style.RESET_ALL}")
+            with open(undersized_log_path, "w") as undersized_images_file:
+                undersized_images_file.write(" The following images are smaller than the target size, consider removing or sourcing a larger copy:\n")
+                for event in events:
+                    message = f" *** {event.image_path} with size: {event.image_size} is smaller than target size: {event.target_size}, consider using larger images\n"
+                    undersized_images_file.write(message)
 
     def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
         """
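
This diff only shows the loader side of the refactor; data/resolver.py, which now owns directory walking, caption parsing, and undersized-image detection, is not included in the hunk. The sketch below is a minimal, assumed shape of that interface, inferred only from how the new loader code calls it: resolve() returning the prepared ImageTrainItem list together with a list of events, and UndersizedImageEvent exposing the image_path, image_size, and target_size fields read in __report_undersized_images. Everything else in the sketch is illustrative, not the actual implementation.

# data/resolver.py -- illustrative sketch only; the real module is not shown in this diff.
from dataclasses import dataclass
from typing import List, Tuple

from data.image_train_item import ImageTrainItem


class Event:
    """Base class for findings reported while resolving training data."""


@dataclass
class UndersizedImageEvent(Event):
    image_path: str               # image smaller than its target bucket
    image_size: Tuple[int, int]   # actual (width, height)
    target_size: Tuple[int, int]  # (width, height) of the chosen aspect bucket


def resolve(data_root: str, aspects: list, flip_p: float = 0.0) -> Tuple[List[ImageTrainItem], List[Event]]:
    """Walk data_root, build one ImageTrainItem per image, and collect any events along the way."""
    items: List[ImageTrainItem] = []
    events: List[Event] = []
    # ... directory walking, .yaml/.txt/.caption parsing, multiply.txt handling, and
    # undersized-image checks (previously inline in __prescan_images) would go here ...
    return items, events

Returning diagnostics as events rather than writing the log file inside the resolver is what lets DataLoaderMultiAspect keep the reporting step (__report_undersized_images) while shedding all filesystem and caption-format concerns.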