EveryDream2trainer/data/resolver.py

111 lines
3.9 KiB
Python
Raw Normal View History

import json
import logging
import os
import typing
import zipfile
import argparse
from data.dataset import Dataset
import tqdm
from colorama import Fore, Style
from data.image_train_item import ImageCaption, ImageTrainItem
class DataResolver:
    """
    Base class for training-data resolution strategies.

    Subclasses override `image_train_items` to turn a data root into a list of
    `ImageTrainItem` objects.
    """

    def __init__(self, args: argparse.Namespace):
        """
        :param args: EveryDream configuration, an `argparse.Namespace` object.
            Only its `aspects` and `flip_p` attributes are read and stored.
        """
        self.aspects, self.flip_p = args.aspects, args.flip_p

    def image_train_items(self, data_root: str) -> list[ImageTrainItem]:
        """
        Build the list of `ImageTrainItem` for the given data root.

        The base implementation always raises; concrete resolvers override it.

        :param data_root: The data root: a directory, a file, etc.
        :return: The list of `ImageTrainItem`.
        :raises NotImplementedError: Always, on this base class.
        """
        raise NotImplementedError
class JSONResolver(DataResolver):
    """Resolve training data that is described by a JSON file."""

    def image_train_items(self, json_path: str) -> list[ImageTrainItem]:
        """
        Create `ImageTrainItem` objects with metadata for hydration later.

        The images and captions are read from the JSON file via
        `Dataset.from_json`, and the configured `aspects` are passed through to
        the dataset.

        :param json_path: The path to the JSON file.
        :return: The list of `ImageTrainItem` built from that file.
        """
        dataset = Dataset.from_json(json_path)
        return dataset.image_train_items(self.aspects)
class DirectoryResolver(DataResolver):
    """Resolve training data from a directory tree of images and captions."""

    def image_train_items(self, data_root: str) -> list[ImageTrainItem]:
        """
        Create `ImageTrainItem` objects with metadata for hydration later.

        Unzips all zip files found under `data_root`, then recursively searches
        `data_root` for images and captions.

        :param data_root: The root directory to recurse through.
        :return: The list of `ImageTrainItem` found under `data_root`.
        """
        DirectoryResolver.unzip_all(data_root)
        return Dataset.from_path(data_root).image_train_items(self.aspects)

    @staticmethod
    def unzip_all(path):
        """
        Extract every ``.zip`` archive found anywhere under `path`.

        Each archive is extracted into the directory that contains it. The
        operation is best-effort: a failure on one archive is logged and the
        remaining archives are still processed.

        :param path: The root directory to search for zip archives.
        """
        for root, _dirs, files in os.walk(path):
            for file in files:
                if not file.endswith('.zip'):
                    continue
                # Open the archive itself, not the walk root: `path` is a
                # directory and cannot be opened as a zip file.
                zip_path = os.path.join(root, file)
                logging.info(f"Unzipping {zip_path}")
                try:
                    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                        zip_ref.extractall(root)
                except Exception as e:
                    # Best-effort: report the bad archive and keep going.
                    logging.error(f"Error unzipping {zip_path}: {e}")
def strategy(data_root: str) -> typing.Type[DataResolver]:
    """
    Pick the resolver class appropriate for `data_root`.

    An existing file whose name ends in ``.json`` maps to `JSONResolver`; an
    existing directory maps to `DirectoryResolver`.

    :param data_root: The root directory or JSON file to resolve.
    :return: The resolver class (not an instance).
    :raises ValueError: If `data_root` is neither an existing directory nor an
        existing ``.json`` file.
    """
    is_json_file = os.path.isfile(data_root) and data_root.endswith('.json')
    if is_json_file:
        return JSONResolver
    elif os.path.isdir(data_root):
        return DirectoryResolver
    else:
        raise ValueError(f"data_root '{data_root}' is not a valid directory or JSON file.")
def resolve_root(path: str, args: argparse.Namespace) -> list[ImageTrainItem]:
    """
    Resolve the training data found at a single root path.

    The resolver class is chosen by `strategy`, instantiated with `args`, and
    asked for its items.

    :param path: The root path to resolve (a directory or a JSON file).
    :param args: EveryDream configuration, an `argparse.Namespace` object.
    :return: The list of `ImageTrainItem` for `path`.
    :raises ValueError: Propagated from `strategy` when `path` is not usable.
    """
    resolver_cls = strategy(path)
    resolver_instance = resolver_cls(args)
    items = resolver_instance.image_train_items(path)
    return items
def resolve(value: typing.Union[dict, str], args: argparse.Namespace) -> list[ImageTrainItem]:
"""
Resolve the training data from the value.
:param value: The value to resolve, either a dict, an array, or a string.
:param args: EveryDream configuration, an `argparse.Namespace` object.
"""
if isinstance(value, str):
return resolve_root(value, args)
if isinstance(value, dict):
resolver = value.get('resolver', None)
match resolver:
case 'directory' | 'json':
path = value.get('path', None)
return resolve_root(path, args)
case 'multi':
return resolve(value.get('resolvers', []), args)
case _:
raise ValueError(f"Cannot resolve training data for resolver value '{resolver}'")
if isinstance(value, list):
items = []
for item in value:
items += resolve(item, args)
return items