Use data_resolver.resolve for data loading in data_loader
This commit is contained in:
parent
0cf2cd71de
commit
316df2db7e
|
@ -18,22 +18,15 @@ import math
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import yaml
|
|
||||||
from PIL import Image
|
|
||||||
import random
|
import random
|
||||||
from data.image_train_item import ImageTrainItem, ImageCaption
|
from data.image_train_item import ImageTrainItem
|
||||||
import data.aspects as aspects
|
import data.aspects as aspects
|
||||||
import data.resolver as resolver
|
import data.resolver as resolver
|
||||||
from data.resolver import DirectoryResolver
|
|
||||||
from colorama import Fore, Style
|
from colorama import Fore, Style
|
||||||
import zipfile
|
|
||||||
import tqdm
|
|
||||||
import PIL
|
import PIL
|
||||||
|
|
||||||
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
|
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
|
||||||
|
|
||||||
DEFAULT_MAX_CAPTION_LENGTH = 2048
|
|
||||||
|
|
||||||
class DataLoaderMultiAspect():
|
class DataLoaderMultiAspect():
|
||||||
"""
|
"""
|
||||||
Data loader for multi-aspect-ratio training and bucketing
|
Data loader for multi-aspect-ratio training and bucketing
|
||||||
|
@ -43,25 +36,18 @@ class DataLoaderMultiAspect():
|
||||||
flip_p: probability of flipping image horizontally (i.e. 0-0.5)
|
flip_p: probability of flipping image horizontally (i.e. 0-0.5)
|
||||||
"""
|
"""
|
||||||
def __init__(self, data_root, seed=555, debug_level=0, batch_size=1, flip_p=0.0, resolution=512, log_folder=None):
|
def __init__(self, data_root, seed=555, debug_level=0, batch_size=1, flip_p=0.0, resolution=512, log_folder=None):
|
||||||
self.image_paths = []
|
self.data_root = data_root
|
||||||
self.debug_level = debug_level
|
self.debug_level = debug_level
|
||||||
self.flip_p = flip_p
|
self.flip_p = flip_p
|
||||||
self.log_folder = log_folder
|
self.log_folder = log_folder
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.has_scanned = False
|
self.has_scanned = False
|
||||||
|
|
||||||
self.aspects = aspects.get_aspect_buckets(resolution=resolution, square_only=False)
|
self.aspects = aspects.get_aspect_buckets(resolution=resolution, square_only=False)
|
||||||
|
|
||||||
logging.info(f"* DLMA resolution {resolution}, buckets: {self.aspects}")
|
logging.info(f"* DLMA resolution {resolution}, buckets: {self.aspects}")
|
||||||
logging.info(" Preloading images...")
|
logging.info(" Preloading images...")
|
||||||
|
self.__prepare_train_data()
|
||||||
DirectoryResolver.unzip_all(data_root)
|
|
||||||
|
|
||||||
for image_path in DirectoryResolver.recurse_data_root(data_root):
|
|
||||||
self.image_paths.append(image_path)
|
|
||||||
|
|
||||||
random.Random(seed).shuffle(self.image_paths)
|
|
||||||
self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p)
|
|
||||||
(self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings()
|
(self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings()
|
||||||
|
|
||||||
|
|
||||||
|
@ -160,150 +146,28 @@ class DataLoaderMultiAspect():
|
||||||
|
|
||||||
return rating_overall_sum, ratings_summed
|
return rating_overall_sum, ratings_summed
|
||||||
|
|
||||||
@staticmethod
|
def __prepare_train_data(self, flip_p=0.0) -> list[ImageTrainItem]:
|
||||||
def __read_caption_from_file(file_path, fallback_caption: ImageCaption) -> ImageCaption:
|
|
||||||
try:
|
|
||||||
with open(file_path, encoding='utf-8', mode='r') as caption_file:
|
|
||||||
caption_text = caption_file.read()
|
|
||||||
caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_text)
|
|
||||||
except:
|
|
||||||
logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
|
|
||||||
caption = fallback_caption
|
|
||||||
pass
|
|
||||||
return caption
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def __read_caption_from_yaml(file_path: str, fallback_caption: ImageCaption) -> ImageCaption:
|
|
||||||
with open(file_path, "r") as stream:
|
|
||||||
try:
|
|
||||||
file_content = yaml.safe_load(stream)
|
|
||||||
main_prompt = file_content.get("main_prompt", "")
|
|
||||||
rating = file_content.get("rating", 1.0)
|
|
||||||
unparsed_tags = file_content.get("tags", [])
|
|
||||||
|
|
||||||
max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH)
|
|
||||||
|
|
||||||
tags = []
|
|
||||||
tag_weights = []
|
|
||||||
last_weight = None
|
|
||||||
weights_differ = False
|
|
||||||
for unparsed_tag in unparsed_tags:
|
|
||||||
tag = unparsed_tag.get("tag", "").strip()
|
|
||||||
if len(tag) == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
tags.append(tag)
|
|
||||||
tag_weight = unparsed_tag.get("weight", 1.0)
|
|
||||||
tag_weights.append(tag_weight)
|
|
||||||
|
|
||||||
if last_weight is not None and weights_differ is False:
|
|
||||||
weights_differ = last_weight != tag_weight
|
|
||||||
|
|
||||||
last_weight = tag_weight
|
|
||||||
|
|
||||||
return ImageCaption(main_prompt, rating, tags, tag_weights, max_caption_length, weights_differ)
|
|
||||||
|
|
||||||
except:
|
|
||||||
logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
|
|
||||||
return fallback_caption
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def __split_caption_into_tags(caption_string: str) -> ImageCaption:
|
|
||||||
"""
|
|
||||||
Splits a string by "," into the main prompt and additional tags with equal weights
|
|
||||||
"""
|
|
||||||
split_caption = caption_string.split(",")
|
|
||||||
main_prompt = split_caption.pop(0).strip()
|
|
||||||
tags = []
|
|
||||||
for tag in split_caption:
|
|
||||||
tags.append(tag.strip())
|
|
||||||
|
|
||||||
return ImageCaption(main_prompt, 1.0, tags, [1.0] * len(tags), DEFAULT_MAX_CAPTION_LENGTH, False)
|
|
||||||
|
|
||||||
def __prescan_images(self, image_paths: list, flip_p=0.0) -> list[ImageTrainItem]:
|
|
||||||
"""
|
"""
|
||||||
Create ImageTrainItem objects with metadata for hydration later
|
Create ImageTrainItem objects with metadata for hydration later
|
||||||
"""
|
"""
|
||||||
decorated_image_train_items = []
|
|
||||||
|
|
||||||
if not self.has_scanned:
|
|
||||||
undersized_images = []
|
|
||||||
|
|
||||||
multipliers = {}
|
|
||||||
skip_folders = []
|
|
||||||
|
|
||||||
for pathname in tqdm.tqdm(image_paths):
|
|
||||||
caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
|
|
||||||
caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_from_filename)
|
|
||||||
|
|
||||||
file_path_without_ext = os.path.splitext(pathname)[0]
|
|
||||||
yaml_file_path = file_path_without_ext + ".yaml"
|
|
||||||
txt_file_path = file_path_without_ext + ".txt"
|
|
||||||
caption_file_path = file_path_without_ext + ".caption"
|
|
||||||
|
|
||||||
current_dir = os.path.dirname(pathname)
|
|
||||||
|
|
||||||
try:
|
|
||||||
if current_dir not in multipliers:
|
|
||||||
multiply_txt_path = os.path.join(current_dir, "multiply.txt")
|
|
||||||
#print(current_dir, multiply_txt_path)
|
|
||||||
if os.path.exists(multiply_txt_path):
|
|
||||||
with open(multiply_txt_path, 'r') as f:
|
|
||||||
val = float(f.read().strip())
|
|
||||||
multipliers[current_dir] = val
|
|
||||||
logging.info(f" * DLMA multiply.txt in {current_dir} set to {val}")
|
|
||||||
else:
|
|
||||||
skip_folders.append(current_dir)
|
|
||||||
multipliers[current_dir] = 1.0
|
|
||||||
except Exception as e:
|
|
||||||
logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
|
|
||||||
skip_folders.append(current_dir)
|
|
||||||
multipliers[current_dir] = 1.0
|
|
||||||
|
|
||||||
if os.path.exists(yaml_file_path):
|
|
||||||
caption = self.__read_caption_from_yaml(yaml_file_path, caption)
|
|
||||||
elif os.path.exists(txt_file_path):
|
|
||||||
caption = self.__read_caption_from_file(txt_file_path, caption)
|
|
||||||
elif os.path.exists(caption_file_path):
|
|
||||||
caption = self.__read_caption_from_file(caption_file_path, caption)
|
|
||||||
|
|
||||||
try:
|
|
||||||
image = Image.open(pathname)
|
|
||||||
width, height = image.size
|
|
||||||
image_aspect = width / height
|
|
||||||
|
|
||||||
target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
|
|
||||||
if not self.has_scanned:
|
|
||||||
if width * height < target_wh[0] * target_wh[1]:
|
|
||||||
undersized_images.append(f" {pathname}, size: {width},{height}, target size: {target_wh}")
|
|
||||||
|
|
||||||
image_train_item = ImageTrainItem(image=None, # image loaded at runtime to apply jitter
|
|
||||||
caption=caption,
|
|
||||||
target_wh=target_wh,
|
|
||||||
pathname=pathname,
|
|
||||||
flip_p=flip_p,
|
|
||||||
multiplier=multipliers[current_dir],
|
|
||||||
)
|
|
||||||
|
|
||||||
decorated_image_train_items.append(image_train_item)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
|
|
||||||
logging.error(f" *** exception: {e}")
|
|
||||||
pass
|
|
||||||
|
|
||||||
if not self.has_scanned:
|
if not self.has_scanned:
|
||||||
self.has_scanned = True
|
self.has_scanned = True
|
||||||
if len(undersized_images) > 0:
|
self.prepared_train_data, events = resolver.resolve(self.data_root, self.aspects, flip_p=flip_p)
|
||||||
underized_log_path = os.path.join(self.log_folder, "undersized_images.txt")
|
random.Random(self.seed).shuffle(self.prepared_train_data)
|
||||||
logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}")
|
self.__report_undersized_images(events)
|
||||||
logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}")
|
|
||||||
with open(underized_log_path, "w") as undersized_images_file:
|
def __report_undersized_images(self, events: list[resolver.Event]):
|
||||||
undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:")
|
events = [event for event in events if isinstance(event, resolver.UndersizedImageEvent)]
|
||||||
for undersized_image in undersized_images:
|
|
||||||
undersized_images_file.write(f"{undersized_image}\n")
|
if len(events) > 0:
|
||||||
|
underized_log_path = os.path.join(self.log_folder, "undersized_images.txt")
|
||||||
return decorated_image_train_items
|
logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}")
|
||||||
|
logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}")
|
||||||
|
with open(underized_log_path, "w") as undersized_images_file:
|
||||||
|
undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:")
|
||||||
|
for event in events:
|
||||||
|
message = f" *** {event.image_path} with size: {event.image_size} is smaller than target size: {event.target_size}, consider using larger images"
|
||||||
|
undersized_images_file.write(message)
|
||||||
|
|
||||||
def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
|
def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue