Use data_resolver.resolve for data loading in data_loader

Joel Holdbrooks 2023-01-23 00:15:32 -08:00
parent 0cf2cd71de
commit 316df2db7e
1 changed file with 21 additions and 157 deletions


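For orientation, a minimal sketch of the loading path after this change. Only `DataLoaderMultiAspect`, its constructor arguments, and the internal `resolver.resolve(...)` call are taken from the diff below; the module path, the example filesystem paths, and the batch size are assumptions for illustration, not part of this commit.

# Illustrative sketch only -- module path and paths below are assumptions.
from data.data_loader import DataLoaderMultiAspect

dlma = DataLoaderMultiAspect(
    data_root="/workspace/training_data",  # hypothetical dataset root
    seed=555,
    batch_size=4,
    flip_p=0.0,
    resolution=512,
    log_folder="/workspace/logs",          # hypothetical log folder
)

# Per the diff, __init__ now delegates discovery, caption parsing, and
# aspect bucketing to data.resolver instead of scanning files inline:
#   self.prepared_train_data, events = resolver.resolve(self.data_root, self.aspects, flip_p=flip_p)
#   random.Random(self.seed).shuffle(self.prepared_train_data)
#   self.__report_undersized_images(events)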
@@ -18,22 +18,15 @@ import math
 import os
 import logging
-import yaml
-from PIL import Image
 import random
-from data.image_train_item import ImageTrainItem, ImageCaption
+from data.image_train_item import ImageTrainItem
 import data.aspects as aspects
 import data.resolver as resolver
-from data.resolver import DirectoryResolver
 from colorama import Fore, Style
-import zipfile
-import tqdm
 import PIL

 PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default

-DEFAULT_MAX_CAPTION_LENGTH = 2048
-
 class DataLoaderMultiAspect():
     """
     Data loader for multi-aspect-ratio training and bucketing
@@ -43,25 +36,18 @@ class DataLoaderMultiAspect():
     flip_p: probability of flipping image horizontally (i.e. 0-0.5)
     """
     def __init__(self, data_root, seed=555, debug_level=0, batch_size=1, flip_p=0.0, resolution=512, log_folder=None):
-        self.image_paths = []
+        self.data_root = data_root
         self.debug_level = debug_level
         self.flip_p = flip_p
         self.log_folder = log_folder
         self.seed = seed
         self.batch_size = batch_size
         self.has_scanned = False
         self.aspects = aspects.get_aspect_buckets(resolution=resolution, square_only=False)
         logging.info(f"* DLMA resolution {resolution}, buckets: {self.aspects}")
         logging.info(" Preloading images...")

-        DirectoryResolver.unzip_all(data_root)
-        for image_path in DirectoryResolver.recurse_data_root(data_root):
-            self.image_paths.append(image_path)
-        random.Random(seed).shuffle(self.image_paths)
-        self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p)
+        self.__prepare_train_data()

         (self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings()
@@ -160,150 +146,28 @@ class DataLoaderMultiAspect():
         return rating_overall_sum, ratings_summed

-    @staticmethod
-    def __read_caption_from_file(file_path, fallback_caption: ImageCaption) -> ImageCaption:
-        try:
-            with open(file_path, encoding='utf-8', mode='r') as caption_file:
-                caption_text = caption_file.read()
-                caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_text)
-        except:
-            logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
-            caption = fallback_caption
-            pass
-        return caption
-
-    @staticmethod
-    def __read_caption_from_yaml(file_path: str, fallback_caption: ImageCaption) -> ImageCaption:
-        with open(file_path, "r") as stream:
-            try:
-                file_content = yaml.safe_load(stream)
-                main_prompt = file_content.get("main_prompt", "")
-                rating = file_content.get("rating", 1.0)
-                unparsed_tags = file_content.get("tags", [])
-
-                max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH)
-
-                tags = []
-                tag_weights = []
-                last_weight = None
-                weights_differ = False
-                for unparsed_tag in unparsed_tags:
-                    tag = unparsed_tag.get("tag", "").strip()
-                    if len(tag) == 0:
-                        continue
-
-                    tags.append(tag)
-                    tag_weight = unparsed_tag.get("weight", 1.0)
-                    tag_weights.append(tag_weight)
-
-                    if last_weight is not None and weights_differ is False:
-                        weights_differ = last_weight != tag_weight
-                    last_weight = tag_weight
-
-                return ImageCaption(main_prompt, rating, tags, tag_weights, max_caption_length, weights_differ)
-            except:
-                logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
-                return fallback_caption
-
-    @staticmethod
-    def __split_caption_into_tags(caption_string: str) -> ImageCaption:
-        """
-        Splits a string by "," into the main prompt and additional tags with equal weights
-        """
-        split_caption = caption_string.split(",")
-        main_prompt = split_caption.pop(0).strip()
-        tags = []
-        for tag in split_caption:
-            tags.append(tag.strip())
-
-        return ImageCaption(main_prompt, 1.0, tags, [1.0] * len(tags), DEFAULT_MAX_CAPTION_LENGTH, False)
-
-    def __prescan_images(self, image_paths: list, flip_p=0.0) -> list[ImageTrainItem]:
+    def __prepare_train_data(self, flip_p=0.0) -> list[ImageTrainItem]:
         """
         Create ImageTrainItem objects with metadata for hydration later
         """
-        decorated_image_train_items = []
-
-        if not self.has_scanned:
-            undersized_images = []
-
-        multipliers = {}
-        skip_folders = []
-
-        for pathname in tqdm.tqdm(image_paths):
-            caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
-            caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_from_filename)
-
-            file_path_without_ext = os.path.splitext(pathname)[0]
-            yaml_file_path = file_path_without_ext + ".yaml"
-            txt_file_path = file_path_without_ext + ".txt"
-            caption_file_path = file_path_without_ext + ".caption"
-
-            current_dir = os.path.dirname(pathname)
-
-            try:
-                if current_dir not in multipliers:
-                    multiply_txt_path = os.path.join(current_dir, "multiply.txt")
-                    #print(current_dir, multiply_txt_path)
-                    if os.path.exists(multiply_txt_path):
-                        with open(multiply_txt_path, 'r') as f:
-                            val = float(f.read().strip())
-                            multipliers[current_dir] = val
-                            logging.info(f" * DLMA multiply.txt in {current_dir} set to {val}")
-                    else:
-                        skip_folders.append(current_dir)
-                        multipliers[current_dir] = 1.0
-            except Exception as e:
-                logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
-                skip_folders.append(current_dir)
-                multipliers[current_dir] = 1.0
-
-            if os.path.exists(yaml_file_path):
-                caption = self.__read_caption_from_yaml(yaml_file_path, caption)
-            elif os.path.exists(txt_file_path):
-                caption = self.__read_caption_from_file(txt_file_path, caption)
-            elif os.path.exists(caption_file_path):
-                caption = self.__read_caption_from_file(caption_file_path, caption)
-
-            try:
-                image = Image.open(pathname)
-                width, height = image.size
-                image_aspect = width / height
-
-                target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
-
-                if not self.has_scanned:
-                    if width * height < target_wh[0] * target_wh[1]:
-                        undersized_images.append(f" {pathname}, size: {width},{height}, target size: {target_wh}")
-
-                image_train_item = ImageTrainItem(image=None, # image loaded at runtime to apply jitter
-                                                  caption=caption,
-                                                  target_wh=target_wh,
-                                                  pathname=pathname,
-                                                  flip_p=flip_p,
-                                                  multiplier=multipliers[current_dir],
-                                                  )
-
-                decorated_image_train_items.append(image_train_item)
-            except Exception as e:
-                logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
-                logging.error(f" *** exception: {e}")
-                pass
-
         if not self.has_scanned:
             self.has_scanned = True
-
-            if len(undersized_images) > 0:
-                underized_log_path = os.path.join(self.log_folder, "undersized_images.txt")
-                logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}")
-                logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}")
-                with open(underized_log_path, "w") as undersized_images_file:
-                    undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:")
-                    for undersized_image in undersized_images:
-                        undersized_images_file.write(f"{undersized_image}\n")
-
-        return decorated_image_train_items
+            self.prepared_train_data, events = resolver.resolve(self.data_root, self.aspects, flip_p=flip_p)
+            random.Random(self.seed).shuffle(self.prepared_train_data)
+            self.__report_undersized_images(events)
+
+    def __report_undersized_images(self, events: list[resolver.Event]):
+        events = [event for event in events if isinstance(event, resolver.UndersizedImageEvent)]
+
+        if len(events) > 0:
+            underized_log_path = os.path.join(self.log_folder, "undersized_images.txt")
+            logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}")
+            logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}")
+            with open(underized_log_path, "w") as undersized_images_file:
+                undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:")
+                for event in events:
+                    message = f" *** {event.image_path} with size: {event.image_size} is smaller than target size: {event.target_size}, consider using larger images"
+                    undersized_images_file.write(message)

     def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
         """