EveryDream2trainer/data/resolver.py

249 lines
8.8 KiB
Python
Raw Normal View History

import json
import logging
import os
import typing
import zipfile
import PIL.Image as Image
import tqdm
from colorama import Fore, Style
from data.image_train_item import ImageCaption, ImageTrainItem
OptionalCallable = typing.Optional[typing.Callable]
class Event:
def __init__(self, name: str):
self.name = name
class UndersizedImageEvent(Event):
def __init__(self, image_path: str, image_size: typing.Tuple[int, int], target_size: typing.Tuple[int, int]):
super().__init__('undersized_image')
self.image_path = image_path
self.image_size = image_size
self.target_size = target_size
class DataResolver:
def __init__(self, aspects: list[typing.Tuple[int, int]], flip_p=0.0, on_event: OptionalCallable=None):
self.aspects = aspects
self.flip_p = flip_p
self.on_event = on_event or (lambda data: None)
def image_train_items(self, data_root: str) -> list[ImageTrainItem]:
"""
Get the list of `ImageTrainItem` for the given data root.
:param data_root: The data root, a directory, a file, etc..
:return: The list of `ImageTrainItem`.
"""
raise NotImplementedError()
def compute_target_width_height(self, image_path: str) -> typing.Optional[typing.Tuple[int, int]]:
# Compute the target width and height for the image based on the aspect ratio.
with Image.open(image_path) as image:
width, height = image.size
image_aspect = width / height
target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
if width * height < target_wh[0] * target_wh[1]:
event = UndersizedImageEvent(image_path, (width, height), target_wh)
self.on_event(event)
return target_wh
def image_train_item(self, image_path: str, caption: ImageCaption) -> ImageTrainItem:
#try:
target_wh = self.compute_target_width_height(image_path)
return ImageTrainItem(image=None, caption=caption, target_wh=target_wh, pathname=image_path, flip_p=self.flip_p)
# except Exception as e:
# logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{image_path}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
# logging.error(f" *** exception: {e}")
class JSONResolver(DataResolver):
def image_train_items(self, json_path: str) -> list[ImageTrainItem]:
items = []
with open(json_path, encoding='utf-8', mode='r') as f:
json_data = json.load(f)
for data in json_data:
caption = JSONResolver.image_caption(data)
if caption:
image_value = JSONResolver.get_image_value(data)
item = self.image_train_item(image_value, caption)
if item:
items.append(item)
return items
@staticmethod
def get_image_value(json_data: dict) -> typing.Optional[str]:
"""
Get the image from the json data if possible.
:param json_data: The json data, a dict.
:return: The image, or None if not found.
"""
image_value = json_data.get("image", None)
if isinstance(image_value, str):
image_value = image_value.strip()
if os.path.exists(image_value):
return image_value
@staticmethod
def get_caption_value(json_data: dict) -> typing.Optional[str]:
"""
Get the caption from the json data if possible.
:param json_data: The json data, a dict.
:return: The caption, or None if not found.
"""
caption_value = json_data.get("caption", None)
if isinstance(caption_value, str):
return caption_value.strip()
@staticmethod
def image_caption(json_data: dict) -> typing.Optional[ImageCaption]:
"""
Get the caption from the json data if possible.
:param json_data: The json data, a dict.
:return: The `ImageCaption`, or None if not found.
"""
image_value = JSONResolver.get_image_value(json_data)
caption_value = JSONResolver.get_caption_value(json_data)
if image_value:
if caption_value:
return ImageCaption.resolve(caption_value)
return ImageCaption.from_file(image_value)
class DirectoryResolver(DataResolver):
def image_train_items(self, data_root: str) -> list[ImageTrainItem]:
"""
Create `ImageTrainItem` objects with metadata for hydration later.
Unzips all zip files in `data_root` and then recursively searches the
`data_root` for images and captions.
:param data_root: The root directory to recurse through
"""
DirectoryResolver.unzip_all(data_root)
image_paths = list(DirectoryResolver.recurse_data_root(data_root))
items = []
for pathname in tqdm.tqdm(image_paths):
caption = ImageCaption.from_file(pathname)
item = self.image_train_item(pathname, caption)
if item:
items.append(item)
return items
@staticmethod
def unzip_all(path):
try:
for root, dirs, files in os.walk(path):
for file in files:
if file.endswith('.zip'):
logging.info(f"Unzipping {file}")
with zipfile.ZipFile(path, 'r') as zip_ref:
zip_ref.extractall(path)
except Exception as e:
logging.error(f"Error unzipping files {e}")
@staticmethod
def recurse_data_root(recurse_root):
for f in os.listdir(recurse_root):
current = os.path.join(recurse_root, f)
if os.path.isfile(current):
ext = os.path.splitext(f)[1].lower()
if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']:
yield current
sub_dirs = []
for d in os.listdir(recurse_root):
current = os.path.join(recurse_root, d)
if os.path.isdir(current):
sub_dirs.append(current)
for dir in sub_dirs:
DirectoryResolver.__recurse_data_root(dir)
def strategy(data_root: str):
if os.path.isfile(data_root) and data_root.endswith('.json'):
return JSONResolver
if os.path.isdir(data_root):
return DirectoryResolver
raise ValueError(f"data_root '{data_root}' is not a valid directory or JSON file.")
def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0, on_event: OptionalCallable=None) -> list[ImageTrainItem]:
"""
:param data_root: Directory or JSON file.
:param aspects: The list of aspect ratios to use
:param flip_p: The probability of flipping the image
"""
if os.path.isfile(path) and path.endswith('.json'):
resolver = JSONResolver(aspects, flip_p, on_event)
if os.path.isdir(path):
resolver = DirectoryResolver(aspects, flip_p, on_event)
if not resolver:
raise ValueError(f"data_root '{path}' is not a valid directory or JSON file.")
return resolver.image_train_items(path)
def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0, on_event: OptionalCallable=None) -> list[ImageTrainItem]:
"""
Resolve the training data from the value.
:param value: The value to resolve, either a dict or a string.
:param aspects: The list of aspect ratios to use
:param flip_p: The probability of flipping the image
:param on_event: The callback to call when an event occurs (e.g. undersized image detected)
"""
if isinstance(value, str):
return resolve_root(value, aspects, flip_p, on_event)
if isinstance(value, dict):
resolver = value.get('resolver', None)
match resolver:
case 'directory' | 'json':
path = value.get('path', None)
return resolve_root(path, aspects, flip_p, on_event)
case 'multi':
items = []
for resolver in value.get('resolvers', []):
items += resolve(resolver, aspects, flip_p, on_event)
return items
case _:
raise ValueError(f"Cannot resolve training data for resolver value '{resolver}'")
# example = {
# 'resolver': 'directory',
# 'data_root': 'data',
# }
# example = {
# 'resolver': 'json',
# 'data_root': 'data.json',
# }
# example = {
# 'resolver': 'multi',
# 'resolvers': [
# {
# 'resolver': 'directory',
# 'data_root': 'data',
# }, {
# 'resolver': 'json',
# 'data_root': 'data.json',
# },
# ]
# }