Merge pull request #24 from noprompt/refactor-data-resolution
Refactor data resolution
commit eea899f4a0
@@ -19,20 +19,15 @@ import os
import logging
import copy

import yaml
from PIL import Image
import random
from data.image_train_item import ImageTrainItem, ImageCaption
from data.image_train_item import ImageTrainItem
import data.aspects as aspects
import data.resolver as resolver
from colorama import Fore, Style
import zipfile
import tqdm
import PIL

PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default

DEFAULT_MAX_CAPTION_LENGTH = 2048

class DataLoaderMultiAspect():
    """
    Data loader for multi-aspect-ratio training and bucketing

@@ -42,24 +37,17 @@ class DataLoaderMultiAspect():
    flip_p: probability of flipping image horizontally (i.e. 0-0.5)
    """
    def __init__(self, data_root, seed=555, debug_level=0, batch_size=1, flip_p=0.0, resolution=512, log_folder=None):
        self.image_paths = []
        self.data_root = data_root
        self.debug_level = debug_level
        self.flip_p = flip_p
        self.log_folder = log_folder
        self.seed = seed
        self.batch_size = batch_size
        self.has_scanned = False

        self.aspects = aspects.get_aspect_buckets(resolution=resolution, square_only=False)

        logging.info(f"* DLMA resolution {resolution}, buckets: {self.aspects}")
        logging.info(" Preloading images...")

        self.unzip_all(data_root)

        self.__recurse_data_root(self=self, recurse_root=data_root)
        random.Random(seed).shuffle(self.image_paths)
        self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p)
        print(f"DLMA Loaded {len(self.prepared_train_data)} images")
        self.__prepare_train_data()
        (self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings()

@@ -152,18 +140,6 @@ class DataLoaderMultiAspect():

        return image_caption_pairs

    @staticmethod
    def unzip_all(path):
        try:
            for root, dirs, files in os.walk(path):
                for file in files:
                    if file.endswith('.zip'):
                        logging.info(f"Unzipping {file}")
                        with zipfile.ZipFile(path, 'r') as zip_ref:
                            zip_ref.extractall(path)
        except Exception as e:
            logging.error(f"Error unzipping files {e}")

    def __sort_and_precalc_image_ratings(self) -> tuple[float, list[float]]:
        self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())

@@ -175,161 +151,44 @@ class DataLoaderMultiAspect():

        return rating_overall_sum, ratings_summed

    @staticmethod
    def __read_caption_from_file(file_path, fallback_caption: ImageCaption) -> ImageCaption:
        try:
            with open(file_path, encoding='utf-8', mode='r') as caption_file:
                caption_text = caption_file.read()
                caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_text)
        except:
            logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
            caption = fallback_caption
            pass
        return caption

    @staticmethod
    def __read_caption_from_yaml(file_path: str, fallback_caption: ImageCaption) -> ImageCaption:
        with open(file_path, "r") as stream:
            try:
                file_content = yaml.safe_load(stream)
                main_prompt = file_content.get("main_prompt", "")
                rating = file_content.get("rating", 1.0)
                unparsed_tags = file_content.get("tags", [])

                max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH)

                tags = []
                tag_weights = []
                last_weight = None
                weights_differ = False
                for unparsed_tag in unparsed_tags:
                    tag = unparsed_tag.get("tag", "").strip()
                    if len(tag) == 0:
                        continue

                    tags.append(tag)
                    tag_weight = unparsed_tag.get("weight", 1.0)
                    tag_weights.append(tag_weight)

                    if last_weight is not None and weights_differ is False:
                        weights_differ = last_weight != tag_weight

                    last_weight = tag_weight

                return ImageCaption(main_prompt, rating, tags, tag_weights, max_caption_length, weights_differ)

            except:
                logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
                return fallback_caption

    @staticmethod
    def __split_caption_into_tags(caption_string: str) -> ImageCaption:
        """
        Splits a string by "," into the main prompt and additional tags with equal weights
        """
        split_caption = caption_string.split(",")
        main_prompt = split_caption.pop(0).strip()
        tags = []
        for tag in split_caption:
            tags.append(tag.strip())

        return ImageCaption(main_prompt, 1.0, tags, [1.0] * len(tags), DEFAULT_MAX_CAPTION_LENGTH, False)

    def __prescan_images(self, image_paths: list, flip_p=0.0) -> list[ImageTrainItem]:
    def __prepare_train_data(self, flip_p=0.0) -> list[ImageTrainItem]:
        """
        Create ImageTrainItem objects with metadata for hydration later
        """
        decorated_image_train_items = []

        if not self.has_scanned:
            undersized_images = []

        multipliers = {}
        skip_folders = []
        randomizer = random.Random(self.seed)

        for pathname in tqdm.tqdm(image_paths):
            caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
            caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_from_filename)

            file_path_without_ext = os.path.splitext(pathname)[0]
            yaml_file_path = file_path_without_ext + ".yaml"
            txt_file_path = file_path_without_ext + ".txt"
            caption_file_path = file_path_without_ext + ".caption"

            current_dir = os.path.dirname(pathname)

            try:
                if current_dir not in multipliers:
                    multiply_txt_path = os.path.join(current_dir, "multiply.txt")
                    #print(current_dir, multiply_txt_path)
                    if os.path.exists(multiply_txt_path):
                        with open(multiply_txt_path, 'r') as f:
                            val = float(f.read().strip())
                            multipliers[current_dir] = val
                            logging.info(f" * DLMA multiply.txt in {current_dir} set to {val}")
                    else:
                        skip_folders.append(current_dir)
                        multipliers[current_dir] = 1.0
            except Exception as e:
                logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
                skip_folders.append(current_dir)
                multipliers[current_dir] = 1.0

            if os.path.exists(yaml_file_path):
                caption = self.__read_caption_from_yaml(yaml_file_path, caption)
            elif os.path.exists(txt_file_path):
                caption = self.__read_caption_from_file(txt_file_path, caption)
            elif os.path.exists(caption_file_path):
                caption = self.__read_caption_from_file(caption_file_path, caption)

            try:
                image = Image.open(pathname)
                width, height = image.size
                image_aspect = width / height

                target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
                if not self.has_scanned:
                    if width * height < target_wh[0] * target_wh[1]:
                        undersized_images.append(f" {pathname}, size: {width},{height}, target size: {target_wh}")

                image_train_item = ImageTrainItem(image=None, # image loaded at runtime to apply jitter
                                                  caption=caption,
                                                  target_wh=target_wh,
                                                  pathname=pathname,
                                                  flip_p=flip_p,
                                                  multiplier=multipliers[current_dir],
                                                  )

                cur_file_multiplier = multipliers[current_dir]

                while cur_file_multiplier >= 1.0:
                    decorated_image_train_items.append(image_train_item)
                    cur_file_multiplier -= 1

                if cur_file_multiplier > 0:
                    if randomizer.random() < cur_file_multiplier:
                        decorated_image_train_items.append(image_train_item)

            except Exception as e:
                logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
                logging.error(f" *** exception: {e}")
                pass

        if not self.has_scanned:
            self.has_scanned = True
            if len(undersized_images) > 0:
                underized_log_path = os.path.join(self.log_folder, "undersized_images.txt")
                logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}")
                logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}")
                with open(underized_log_path, "w") as undersized_images_file:
                    undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:")
                    for undersized_image in undersized_images:
                        undersized_images_file.write(f"{undersized_image}\n")

        print (f" * DLMA: {len(decorated_image_train_items)} images loaded from {len(image_paths)} files")

        logging.info(" Preloading images...")

        items = resolver.resolve(self.data_root, self.aspects, flip_p=flip_p, seed=self.seed)
        image_paths = set(map(lambda item: item.pathname, items))

        print (f" * DLMA: {len(items)} images loaded from {len(image_paths)} files")

        self.prepared_train_data = items
        random.Random(self.seed).shuffle(self.prepared_train_data)
        self.__report_errors(items)

    def __report_errors(self, items: list[ImageTrainItem]):
        for item in items:
            if item.error is not None:
                logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
                logging.error(f" *** exception: {item.error}")

        undersized_items = [item for item in items if item.is_undersized]

        if len(undersized_items) > 0:
            underized_log_path = os.path.join(self.log_folder, "undersized_images.txt")
            logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}")
            logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}")
            with open(underized_log_path, "w") as undersized_images_file:
                undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:")
                for undersized_item in undersized_items:
                    message = f" *** {undersized_item.pathname} with size: {undersized_item.image_size} is smaller than target size: {undersized_item.target_wh}, consider using larger images"
                    undersized_images_file.write(message)

        return decorated_image_train_items

    def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
        """
@@ -367,23 +226,3 @@ class DataLoaderMultiAspect():
            prepared_train_data.pop(pos)

        return picked_images

    @staticmethod
    def __recurse_data_root(self, recurse_root):
        for f in os.listdir(recurse_root):
            current = os.path.join(recurse_root, f)

            if os.path.isfile(current):
                ext = os.path.splitext(f)[1].lower()
                if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']:
                    self.image_paths.append(current)

        sub_dirs = []

        for d in os.listdir(recurse_root):
            current = os.path.join(recurse_root, d)
            if os.path.isdir(current):
                sub_dirs.append(current)

        for dir in sub_dirs:
            self.__recurse_data_root(self=self, recurse_root=dir)

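Below is a minimal usage sketch of the loader after this refactor; it is not part of the diff. The module path data.data_loader and the ./train_data and ./logs paths are assumptions for illustration only.

# Hedged sketch, not part of the commit: module path and directories below are placeholders.
from data.data_loader import DataLoaderMultiAspect   # assumed module location for the class above

dlma = DataLoaderMultiAspect(
    data_root="./train_data",   # directory of images; caption/metadata resolution is now delegated to data.resolver
    seed=555,
    batch_size=1,
    flip_p=0.0,
    resolution=512,
    log_folder="./logs",        # __report_errors writes undersized_images.txt here if needed
)
print(len(dlma.prepared_train_data))  # ImageTrainItem objects gathered via resolver.resolve
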
@@ -18,13 +18,19 @@ import logging
import math
import os
import random
import typing
import yaml

import PIL
import PIL.Image as Image
import numpy as np
from torchvision import transforms

_RANDOM_TRIM = 0.04

DEFAULT_MAX_CAPTION_LENGTH = 2048

OptionalImageCaption = typing.Optional['ImageCaption']

class ImageCaption:
    """
@@ -60,17 +66,21 @@ class ImageCaption:
        :param seed used to initialize the randomizer
        :return: generated caption string
        """
        max_target_tag_length = self.__max_target_length - len(self.__main_prompt)
        if self.__tags:
            max_target_tag_length = self.__max_target_length - len(self.__main_prompt)

        if self.__use_weights:
            tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length)
        else:
            tags_caption = self.__get_shuffled_tags(seed, self.__tags)
            if self.__use_weights:
                tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length)
            else:
                tags_caption = self.__get_shuffled_tags(seed, self.__tags)

        return self.__main_prompt + ", " + tags_caption
            return self.__main_prompt + ", " + tags_caption
        return self.__main_prompt

    def get_caption(self) -> str:
        return self.__main_prompt + ", " + ", ".join(self.__tags)
        if self.__tags:
            return self.__main_prompt + ", " + ", ".join(self.__tags)
        return self.__main_prompt

    @staticmethod
    def __get_weighted_shuffled_tags(seed: int, tags: list[str], weights: list[float], max_target_tag_length: int) -> str:
@@ -91,7 +101,14 @@ class ImageCaption:

            weights_copy.pop(pos)
            tag = tags_copy.pop(pos)
            caption += ", " + tag

            if caption:
                caption += ", "
            caption += tag

            if caption:
                caption += ", "
            caption += tag

        return caption

@@ -100,6 +117,136 @@ class ImageCaption:
        random.Random(seed).shuffle(tags)
        return ", ".join(tags)

    @staticmethod
    def parse(string: str) -> 'ImageCaption':
        """
        Parses a string to get the caption.

        :param string: String to parse.
        :return: `ImageCaption` object.
        """
        split_caption = list(map(str.strip, string.split(",")))
        main_prompt = split_caption[0]
        tags = split_caption[1:]
        tag_weights = [1.0] * len(tags)

        return ImageCaption(main_prompt, 1.0, tags, tag_weights, DEFAULT_MAX_CAPTION_LENGTH, False)

    @staticmethod
    def from_file_name(file_path: str) -> 'ImageCaption':
        """
        Parses the file name to get the caption.

        :param file_path: Path to the image file.
        :return: `ImageCaption` object.
        """
        (file_name, _) = os.path.splitext(os.path.basename(file_path))
        caption = file_name.split("_")[0]
        return ImageCaption.parse(caption)

    @staticmethod
    def from_text_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
        """
        Parses a text file to get the caption. Returns the default caption if
        the file does not exist or is invalid.

        :param file_path: Path to the text file.
        :param default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid.
        :return: `ImageCaption` object or `None`.
        """
        try:
            with open(file_path, encoding='utf-8', mode='r') as caption_file:
                caption_text = caption_file.read()
                return ImageCaption.parse(caption_text)
        except:
            logging.error(f" *** Error reading {file_path} to get caption")
            return default_caption

    @staticmethod
    def from_yaml_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
        """
        Parses a yaml file to get the caption. Returns the default caption if
        the file does not exist or is invalid.

        :param file_path: path to the yaml file
        :param default_caption: caption to return if the file does not exist or is invalid
        :return: `ImageCaption` object or `None`.
        """
        try:
            with open(file_path, "r") as stream:
                file_content = yaml.safe_load(stream)
                main_prompt = file_content.get("main_prompt", "")
                rating = file_content.get("rating", 1.0)
                unparsed_tags = file_content.get("tags", [])

                max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH)

                tags = []
                tag_weights = []
                last_weight = None
                weights_differ = False
                for unparsed_tag in unparsed_tags:
                    tag = unparsed_tag.get("tag", "").strip()
                    if len(tag) == 0:
                        continue

                    tags.append(tag)
                    tag_weight = unparsed_tag.get("weight", 1.0)
                    tag_weights.append(tag_weight)

                    if last_weight is not None and weights_differ is False:
                        weights_differ = last_weight != tag_weight

                    last_weight = tag_weight

                return ImageCaption(main_prompt, rating, tags, tag_weights, max_caption_length, weights_differ)
        except:
            logging.error(f" *** Error reading {file_path} to get caption")
            return default_caption

    @staticmethod
    def from_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
        """
        Try to resolve a caption from a file path or return `default_caption`.

        :string: The path to the file to parse.
        :default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid.
        :return: `ImageCaption` object or `None`.
        """
        if os.path.exists(file_path):
            (file_path_without_ext, ext) = os.path.splitext(file_path)
            match ext:
                case ".yaml" | ".yml":
                    return ImageCaption.from_yaml_file(file_path, default_caption)

                case ".txt" | ".caption":
                    return ImageCaption.from_text_file(file_path, default_caption)

                case '.jpg'| '.jpeg'| '.png'| '.bmp'| '.webp'| '.jfif':
                    for ext in [".yaml", ".yml", ".txt", ".caption"]:
                        file_path = file_path_without_ext + ext
                        image_caption = ImageCaption.from_file(file_path)
                        if image_caption is not None:
                            return image_caption
                    return ImageCaption.from_file_name(file_path)

                case _:
                    return default_caption
        else:
            return default_caption

    @staticmethod
    def resolve(string: str) -> 'ImageCaption':
        """
        Try to resolve a caption from a string. If the string is a file path,
        the caption will be read from the file, otherwise the string will be
        parsed as a caption.

        :string: The string to resolve.
        :return: `ImageCaption` object.
        """
        return ImageCaption.from_file(string, None) or ImageCaption.parse(string)


class ImageTrainItem:
    """
@@ -110,19 +257,26 @@ class ImageTrainItem:
    flip_p: probability of flipping image (0.0 to 1.0)
    rating: the relative rating of the images. The rating is measured in comparison to the other images.
    """
    def __init__(self, image: PIL.Image, caption: ImageCaption, target_wh: list, pathname: str, flip_p=0.0, multiplier: float=1.0):
    def __init__(self, image: PIL.Image, caption: ImageCaption, aspects: list[float], pathname: str, flip_p=0.0, multiplier: float=1.0):
        self.caption = caption
        self.target_wh = target_wh
        self.aspects = aspects
        self.pathname = pathname
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
        self.cropped_img = None
        self.runt_size = 0
        self.multiplier = multiplier

        self.image_size = None
        if image is None:
            self.image = []
        else:
            self.image = image
            self.image_size = image.size
        self.target_size = None

        self.is_undersized = False
        self.error = None
        self.__compute_target_width_height()

    def hydrate(self, crop=False, save=False, crop_jitter=20):
        """
@@ -199,6 +353,18 @@ class ImageTrainItem:
        # print(self.image.shape)

        return self

    def __compute_target_width_height(self):
        try:
            with Image.open(self.pathname) as image:
                width, height = image.size
                image_aspect = width / height
                target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))

                self.is_undersized = width * height < target_wh[0] * target_wh[1]
                self.target_wh = target_wh
        except Exception as e:
            self.error = e

    @staticmethod
    def __autocrop(image: PIL.Image, q=.404):
@@ -229,4 +395,4 @@ class ImageTrainItem:
        min_xy = min(x, y)
        image = image.crop((x_crop, y_crop, x_crop + min_xy, y_crop + min_xy))

        return image
        return image

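Below is a small, hedged sketch of the new ImageCaption/ImageTrainItem API shown above; it is not part of the diff, and the tmp_example directory and file name are invented for illustration (they mirror what the tests further down do).

# Hedged sketch, not part of the commit: the path below is a placeholder created on the fly.
import os
import PIL.Image as Image

import data.aspects as aspects
from data.image_train_item import ImageCaption, ImageTrainItem

os.makedirs("tmp_example", exist_ok=True)
Image.new("RGB", (512, 512)).save("tmp_example/a red fox_01.jpg")

buckets = aspects.get_aspect_buckets(512)                       # same bucket list the resolvers pass in
caption = ImageCaption.resolve("tmp_example/a red fox_01.jpg")  # no .yaml/.txt/.caption sidecar, so falls back to the file name
item = ImageTrainItem(image=None, caption=caption, aspects=buckets,
                      pathname="tmp_example/a red fox_01.jpg", flip_p=0.0)

print(caption.get_caption())               # "a red fox" (file-name fallback)
print(item.target_wh, item.is_undersized)  # bucket chosen from `aspects`; False for a 512x512 source
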
@@ -0,0 +1,230 @@
import json
import logging
import os
import random
import typing
import zipfile

import PIL.Image as Image
import tqdm
from colorama import Fore, Style

from data.image_train_item import ImageCaption, ImageTrainItem

class DataResolver:
    def __init__(self, aspects: list[typing.Tuple[int, int]], flip_p=0.0, seed=555):
        self.seed = seed
        self.aspects = aspects
        self.flip_p = flip_p

    def image_train_items(self, data_root: str) -> list[ImageTrainItem]:
        """
        Get the list of `ImageTrainItem` for the given data root.

        :param data_root: The data root, a directory, a file, etc..
        :return: The list of `ImageTrainItem`.
        """
        raise NotImplementedError()

    def image_train_item(self, image_path: str, caption: ImageCaption, multiplier: float=1) -> ImageTrainItem:
        return ImageTrainItem(
            image=None,
            caption=caption,
            aspects=self.aspects,
            pathname=image_path,
            flip_p=self.flip_p,
            multiplier=multiplier
        )

class JSONResolver(DataResolver):
    def image_train_items(self, json_path: str) -> list[ImageTrainItem]:
        """
        Create `ImageTrainItem` objects with metadata for hydration later.
        Extracts images and captions from a JSON file.

        :param json_path: The path to the JSON file.
        """
        items = []
        with open(json_path, encoding='utf-8', mode='r') as f:
            json_data = json.load(f)

        for data in tqdm.tqdm(json_data):
            caption = JSONResolver.image_caption(data)
            if caption:
                image_value = JSONResolver.get_image_value(data)
                item = self.image_train_item(image_value, caption)
                if item:
                    items.append(item)

        return items

    @staticmethod
    def get_image_value(json_data: dict) -> typing.Optional[str]:
        """
        Get the image from the json data if possible.

        :param json_data: The json data, a dict.
        :return: The image, or None if not found.
        """
        image_value = json_data.get("image", None)
        if isinstance(image_value, str):
            image_value = image_value.strip()
            if os.path.exists(image_value):
                return image_value

    @staticmethod
    def get_caption_value(json_data: dict) -> typing.Optional[str]:
        """
        Get the caption from the json data if possible.

        :param json_data: The json data, a dict.
        :return: The caption, or None if not found.
        """
        caption_value = json_data.get("caption", None)
        if isinstance(caption_value, str):
            return caption_value.strip()

    @staticmethod
    def image_caption(json_data: dict) -> typing.Optional[ImageCaption]:
        """
        Get the caption from the json data if possible.

        :param json_data: The json data, a dict.
        :return: The `ImageCaption`, or None if not found.
        """
        image_value = JSONResolver.get_image_value(json_data)
        caption_value = JSONResolver.get_caption_value(json_data)
        if image_value:
            if caption_value:
                return ImageCaption.resolve(caption_value)
            return ImageCaption.from_file(image_value)


class DirectoryResolver(DataResolver):
    def image_train_items(self, data_root: str) -> list[ImageTrainItem]:
        """
        Create `ImageTrainItem` objects with metadata for hydration later.
        Unzips all zip files in `data_root` and then recursively searches the
        `data_root` for images and captions.

        :param data_root: The root directory to recurse through
        """
        DirectoryResolver.unzip_all(data_root)
        image_paths = list(DirectoryResolver.recurse_data_root(data_root))
        items = []
        multipliers = {}
        skip_folders = []
        randomizer = random.Random(self.seed)

        for pathname in tqdm.tqdm(image_paths):
            current_dir = os.path.dirname(pathname)

            if current_dir not in multipliers:
                multiply_txt_path = os.path.join(current_dir, "multiply.txt")
                if os.path.exists(multiply_txt_path):
                    try:
                        with open(multiply_txt_path, 'r') as f:
                            val = float(f.read().strip())
                            multipliers[current_dir] = val
                            logging.info(f" * DLMA multiply.txt in {current_dir} set to {val}")
                    except Exception as e:
                        logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
                        skip_folders.append(current_dir)
                        multipliers[current_dir] = 1.0
                else:
                    skip_folders.append(current_dir)
                    multipliers[current_dir] = 1.0

            caption = ImageCaption.resolve(pathname)
            item = self.image_train_item(pathname, caption, multiplier=multipliers[current_dir])

            cur_file_multiplier = multipliers[current_dir]

            while cur_file_multiplier >= 1.0:
                items.append(item)
                cur_file_multiplier -= 1

            if cur_file_multiplier > 0:
                if randomizer.random() < cur_file_multiplier:
                    items.append(item)
        return items

    @staticmethod
    def unzip_all(path):
        try:
            for root, dirs, files in os.walk(path):
                for file in files:
                    if file.endswith('.zip'):
                        logging.info(f"Unzipping {file}")
                        with zipfile.ZipFile(path, 'r') as zip_ref:
                            zip_ref.extractall(path)
        except Exception as e:
            logging.error(f"Error unzipping files {e}")

    @staticmethod
    def recurse_data_root(recurse_root):
        for f in os.listdir(recurse_root):
            current = os.path.join(recurse_root, f)

            if os.path.isfile(current):
                ext = os.path.splitext(f)[1].lower()
                if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']:
                    yield current

        for d in os.listdir(recurse_root):
            current = os.path.join(recurse_root, d)
            if os.path.isdir(current):
                yield from DirectoryResolver.recurse_data_root(current)


def strategy(data_root: str):
    if os.path.isfile(data_root) and data_root.endswith('.json'):
        return JSONResolver

    if os.path.isdir(data_root):
        return DirectoryResolver

    raise ValueError(f"data_root '{data_root}' is not a valid directory or JSON file.")


def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0, seed=555) -> list[ImageTrainItem]:
    """
    :param data_root: Directory or JSON file.
    :param aspects: The list of aspect ratios to use
    :param flip_p: The probability of flipping the image
    """
    if os.path.isfile(path) and path.endswith('.json'):
        resolver = JSONResolver(aspects, flip_p, seed)

    if os.path.isdir(path):
        resolver = DirectoryResolver(aspects, flip_p, seed)

    if not resolver:
        raise ValueError(f"data_root '{path}' is not a valid directory or JSON file.")

    items = resolver.image_train_items(path)
    return items

def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0, seed=555) -> list[ImageTrainItem]:
    """
    Resolve the training data from the value.
    :param value: The value to resolve, either a dict or a string.
    :param aspects: The list of aspect ratios to use
    :param flip_p: The probability of flipping the image
    """
    if isinstance(value, str):
        return resolve_root(value, aspects, flip_p)

    if isinstance(value, dict):
        resolver = value.get('resolver', None)
        match resolver:
            case 'directory' | 'json':
                path = value.get('path', None)
                return resolve_root(path, aspects, flip_p, seed)
            case 'multi':
                items = []
                for resolver in value.get('resolvers', []):
                    items += resolve(resolver, aspects, flip_p, seed)
                return items
            case _:
                raise ValueError(f"Cannot resolve training data for resolver value '{resolver}'")

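One hedged sketch of driving the new resolver module through the dict form accepted by resolve(); this is not part of the diff, and the two paths are placeholders that would need to exist on disk.

# Hedged sketch, not part of the commit: './train_data' and './train_manifest.json' are placeholder inputs.
import data.aspects as aspects
import data.resolver as resolver

buckets = aspects.get_aspect_buckets(512)

spec = {
    'resolver': 'multi',
    'resolvers': [
        {'resolver': 'directory', 'path': './train_data'},      # DirectoryResolver: unzips, walks, honours multiply.txt
        {'resolver': 'json', 'path': './train_manifest.json'},  # JSONResolver: list of {"image": ..., "caption": ...} entries
    ],
}

items = resolver.resolve(spec, buckets, flip_p=0.0, seed=555)   # flat list of ImageTrainItem
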
@@ -0,0 +1,113 @@
import json
import glob
import os
import unittest

import PIL.Image as Image

import data.aspects as aspects
import data.resolver as resolver

DATA_PATH = os.path.abspath('./test/data')
JSON_ROOT_PATH = os.path.join(DATA_PATH, 'test_root.json')
ASPECTS = aspects.get_aspect_buckets(512)

IMAGE_1_PATH = os.path.join(DATA_PATH, 'test1.jpg')
CAPTION_1_PATH = os.path.join(DATA_PATH, 'test1.txt')
IMAGE_2_PATH = os.path.join(DATA_PATH, 'test2.jpg')
IMAGE_3_PATH = os.path.join(DATA_PATH, 'test3.jpg')

class TestResolve(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        Image.new('RGB', (512, 512)).save(IMAGE_1_PATH)
        with open(CAPTION_1_PATH, 'w') as f:
            f.write('caption for test1')

        Image.new('RGB', (512, 512)).save(IMAGE_2_PATH)
        # Undersized image
        Image.new('RGB', (256, 256)).save(IMAGE_3_PATH)

        json_data = [
            {
                'image': IMAGE_1_PATH,
                'caption': CAPTION_1_PATH
            },
            {
                'image': IMAGE_2_PATH,
                'caption': 'caption for test2'
            },
            {
                'image': IMAGE_3_PATH,
            }
        ]

        with open(JSON_ROOT_PATH, 'w') as f:
            json.dump(json_data, f, indent=4)

    @classmethod
    def tearDownClass(cls):
        for file in glob.glob(os.path.join(DATA_PATH, 'test*')):
            os.remove(file)

    def test_directory_resolve_with_str(self):
        items = resolver.resolve(DATA_PATH, ASPECTS)
        image_paths = [item.pathname for item in items]
        image_captions = [item.caption for item in items]
        captions = [caption.get_caption() for caption in image_captions]

        self.assertEqual(len(items), 3)
        self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
        self.assertEqual(captions, ['caption for test1', 'test2', 'test3'])

        undersized_images = list(filter(lambda i: i.is_undersized, items))
        self.assertEqual(len(undersized_images), 1)

    def test_directory_resolve_with_dict(self):
        data_root_spec = {
            'resolver': 'directory',
            'path': DATA_PATH,
        }

        items = resolver.resolve(data_root_spec, ASPECTS)
        image_paths = [item.pathname for item in items]
        image_captions = [item.caption for item in items]
        captions = [caption.get_caption() for caption in image_captions]

        self.assertEqual(len(items), 3)
        self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
        self.assertEqual(captions, ['caption for test1', 'test2', 'test3'])

        undersized_images = list(filter(lambda i: i.is_undersized, items))
        self.assertEqual(len(undersized_images), 1)

    def test_json_resolve_with_str(self):
        items = resolver.resolve(JSON_ROOT_PATH, ASPECTS)
        image_paths = [item.pathname for item in items]
        image_captions = [item.caption for item in items]
        captions = [caption.get_caption() for caption in image_captions]

        self.assertEqual(len(items), 3)
        self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
        self.assertEqual(captions, ['caption for test1', 'caption for test2', 'test3'])

        undersized_images = list(filter(lambda i: i.is_undersized, items))
        self.assertEqual(len(undersized_images), 1)

    def test_json_resolve_with_dict(self):
        data_root_spec = {
            'resolver': 'json',
            'path': JSON_ROOT_PATH,
        }

        items = resolver.resolve(data_root_spec, ASPECTS)
        image_paths = [item.pathname for item in items]
        image_captions = [item.caption for item in items]
        captions = [caption.get_caption() for caption in image_captions]

        self.assertEqual(len(items), 3)
        self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
        self.assertEqual(captions, ['caption for test1', 'caption for test2', 'test3'])

        undersized_images = list(filter(lambda i: i.is_undersized, items))
        self.assertEqual(len(undersized_images), 1)

@@ -0,0 +1,71 @@
import unittest
import os
import pathlib
import PIL.Image as Image

from data.image_train_item import ImageCaption, ImageTrainItem

DATA_PATH = pathlib.Path('./test/data')

class TestImageCaption(unittest.TestCase):

    def setUp(self) -> None:
        with open(DATA_PATH / "test1.txt", encoding='utf-8', mode='w') as f:
            f.write("caption for test1")

        Image.new("RGB", (512,512)).save(DATA_PATH / "test1.jpg")
        Image.new("RGB", (512,512)).save(DATA_PATH / "test2.jpg")

        with open(DATA_PATH / "test_caption.caption", encoding='utf-8', mode='w') as f:
            f.write("caption for test2")

        return super().setUp()

    def tearDown(self) -> None:
        for file in DATA_PATH.glob("test*"):
            file.unlink()

        return super().tearDown()

    def test_constructor(self):
        caption = ImageCaption("hello world", 1.0, ["one", "two", "three"], [1.0]*3, 2048, False)
        self.assertEqual(caption.get_caption(), "hello world, one, two, three")

        caption = ImageCaption("hello world", 1.0, [], [], 2048, False)
        self.assertEqual(caption.get_caption(), "hello world")

    def test_parse(self):
        caption = ImageCaption.parse("hello world, one, two, three")

        self.assertEqual(caption.get_caption(), "hello world, one, two, three")

    def test_from_file_name(self):
        caption = ImageCaption.from_file_name("foo bar_1_2_3.jpg")
        self.assertEqual(caption.get_caption(), "foo bar")

    def test_from_text_file(self):
        caption = ImageCaption.from_text_file("test/data/test1.txt")
        self.assertEqual(caption.get_caption(), "caption for test1")

    def test_from_file(self):
        caption = ImageCaption.from_file("test/data/test1.txt")
        self.assertEqual(caption.get_caption(), "caption for test1")

        caption = ImageCaption.from_file("test/data/test_caption.caption")
        self.assertEqual(caption.get_caption(), "caption for test2")

    def test_resolve(self):
        caption = ImageCaption.resolve("test/data/test1.txt")
        self.assertEqual(caption.get_caption(), "caption for test1")

        caption = ImageCaption.resolve("test/data/test_caption.caption")
        self.assertEqual(caption.get_caption(), "caption for test2")

        caption = ImageCaption.resolve("hello world")
        self.assertEqual(caption.get_caption(), "hello world")

        caption = ImageCaption.resolve("test/data/test1.jpg")
        self.assertEqual(caption.get_caption(), "caption for test1")

        caption = ImageCaption.resolve("test/data/test2.jpg")
        self.assertEqual(caption.get_caption(), "test2")