Add support for enhanced dataset configuration
Add support for:
* flip_p.txt
* cond_dropout.txt
* local.yaml consolidated config (including default captions)
* global.yaml consolidated config, which applies recursively to subfolders
* flip_p and cond_dropout config per image
* manifest.json with full image-level configuration
parent 2e3d044ba3 · commit 0716c40ab6
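For orientation, a minimal usage sketch of the new configuration resolution. The folder layout and the aspects list here are illustrative placeholders; the Dataset API itself is the one added in data/dataset.py below.

# Illustrative layout (not shipped with this commit):
#   people/global.yaml           applies to people/ and every subfolder
#   people/alice/local.yaml      applies only to people/alice/
#   people/alice/alice_1.png
#   people/alice/alice_1.yaml    per-image captions and overrides
from data.dataset import Dataset

aspects = [1.0]  # placeholder; the trainer supplies the real aspect-ratio buckets
items = Dataset.from_path("people").image_train_items(aspects)
for item in items:
    print(item.pathname, item.multiplier, item.flip.p, item.cond_dropout)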
data/dataset.py
@@ -0,0 +1,261 @@
import os
import logging
import yaml
import json

from functools import total_ordering
from attrs import define, Factory
from data.image_train_item import ImageCaption, ImageTrainItem
from utils.fs_helpers import *


@define(frozen=True)
@total_ordering
class Tag:
    value: str
    weight: float = None

    def __lt__(self, other):
        return self.value < other.value

@define(frozen=True)
@total_ordering
class Caption:
    main_prompt: str = None
    rating: float = None
    max_caption_length: int = None
    tags: frozenset[Tag] = Factory(frozenset)

    @classmethod
    def from_dict(cls, data: dict):
        main_prompt = data.get("main_prompt")
        rating = data.get("rating")
        max_caption_length = data.get("max_caption_length")

        tags = frozenset([Tag(value=t.get("tag"), weight=t.get("weight"))
                          for t in data.get("tags", [])
                          if "tag" in t and len(t.get("tag")) > 0])

        if not main_prompt and not tags:
            return None

        return Caption(main_prompt=main_prompt, rating=rating, max_caption_length=max_caption_length, tags=tags)

    @classmethod
    def from_text(cls, text: str):
        if text is None:
            return Caption(main_prompt="")
        split_caption = list(map(str.strip, text.split(",")))
        main_prompt = split_caption[0]
        tags = frozenset(Tag(value=t) for t in split_caption[1:])
        return Caption(main_prompt=main_prompt, tags=tags)

    @classmethod
    def load(cls, input):
        if isinstance(input, str):
            if os.path.isfile(input):
                return Caption.from_text(read_text(input))
            else:
                return Caption.from_text(input)
        elif isinstance(input, dict):
            return Caption.from_dict(input)

    def __lt__(self, other):
        self_str = ",".join([self.main_prompt] + sorted(repr(t) for t in self.tags))
        other_str = ",".join([other.main_prompt] + sorted(repr(t) for t in other.tags))
        return self_str < other_str
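# Illustrative example (not part of the committed module): Caption.from_text treats the
# first comma-separated chunk as the main prompt and the rest as unweighted tags, e.g.
#
#   cap = Caption.from_text("a photo of alice, smiling, outdoors")
#   cap.main_prompt  -> "a photo of alice"
#   cap.tags         -> frozenset({Tag("smiling"), Tag("outdoors")})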
@define(frozen=True)
class ImageConfig:
    image: str = None
    captions: frozenset[Caption] = Factory(frozenset)
    multiply: float = None
    cond_dropout: float = None
    flip_p: float = None

    @classmethod
    def fold(cls, configs):
        # Merge each config onto the accumulator in order; later configs take precedence
        acc = ImageConfig()
        [acc := acc.merge(cfg) for cfg in configs]
        return acc

    def merge(self, other):
        if other is None:
            return self

        if other.image and self.image:
            logging.warning(f"Found two images with different extensions and the same barename: {self.image} and {other.image}")

        return ImageConfig(
            image = other.image or self.image,
            captions = other.captions.union(self.captions),
            multiply = other.multiply if other.multiply is not None else self.multiply,
            cond_dropout = other.cond_dropout if other.cond_dropout is not None else self.cond_dropout,
            flip_p = other.flip_p if other.flip_p is not None else self.flip_p
        )

    def ensure_caption(self):
        if not self.captions:
            filename_caption = Caption.from_text(barename(self.image).split("_")[0])
            return self.merge(ImageConfig(captions=frozenset([filename_caption])))
        return self
    @classmethod
    def from_dict(cls, data: dict):
        captions = set()
        if "captions" in data:
            captions.update(Caption.load(cap) for cap in data.get("captions"))

        if "caption" in data:
            captions.add(Caption.load(data.get("caption")))

        if not captions:
            # For backward compatibility with existing caption yaml
            caption = Caption.load(data)
            if caption:
                captions.add(caption)

        return ImageConfig(
            image = data.get("image"),
            captions=frozenset(captions),
            multiply=data.get("multiply"),
            cond_dropout=data.get("cond_dropout"),
            flip_p=data.get("flip_p"))
    @classmethod
    def from_text(cls, text: str):
        try:
            if os.path.isfile(text):
                return ImageConfig.from_file(text)
            return ImageConfig(captions=frozenset({Caption.from_text(text)}))
        except Exception as e:
            logging.warning(f" *** Error parsing config from text {text}: \n{e}")

    @classmethod
    def from_file(cls, file: str):
        try:
            match ext(file):
                case '.jpg' | '.jpeg' | '.png' | '.bmp' | '.webp' | '.jfif':
                    return ImageConfig(image=file)
                case ".json":
                    # read_text returns a string, so parse it with json.loads
                    return ImageConfig.from_dict(json.loads(read_text(file)))
                case ".yaml" | ".yml":
                    return ImageConfig.from_dict(yaml.safe_load(read_text(file)))
                case ".txt" | ".caption":
                    return ImageConfig.from_text(read_text(file))
                case _:
                    return logging.warning(f" *** Unrecognized config extension {ext(file)}")
        except Exception as e:
            logging.warning(f" *** Error parsing config from {file}: {e}")
    @classmethod
    def load(cls, input):
        if isinstance(input, str):
            return ImageConfig.from_text(input)
        elif isinstance(input, dict):
            return ImageConfig.from_dict(input)
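# Illustrative example (not part of the committed module): ImageConfig.fold merges
# configs left to right, so later values win and unset fields fall through, e.g.
#
#   base     = ImageConfig(multiply=1.0, flip_p=0.0)
#   override = ImageConfig(multiply=2.0)
#   merged   = ImageConfig.fold([base, override])
#   merged.multiply  -> 2.0   (overridden)
#   merged.flip_p    -> 0.0   (inherited from base)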
@define()
class Dataset:
    image_configs: set[ImageConfig]

    def __global_cfg(files):
        cfgs = []
        for file in files:
            match os.path.basename(file):
                case 'global.yaml' | 'global.yml':
                    cfgs.append(ImageConfig.from_file(file))
        return ImageConfig.fold(cfgs)

    def __local_cfg(files):
        cfgs = []
        for file in files:
            match os.path.basename(file):
                case 'multiply.txt':
                    cfgs.append(ImageConfig(multiply=read_float(file)))
                case 'cond_dropout.txt':
                    cfgs.append(ImageConfig(cond_dropout=read_float(file)))
                case 'flip_p.txt':
                    cfgs.append(ImageConfig(flip_p=read_float(file)))
                case 'local.yaml' | 'local.yml':
                    cfgs.append(ImageConfig.from_file(file))
        return ImageConfig.fold(cfgs)

    def __image_cfg(imagepath, files):
        cfgs = [ImageConfig.from_file(imagepath)]
        for file in files:
            if same_barename(imagepath, file):
                match ext(file):
                    case '.txt' | '.caption' | '.yml' | '.yaml':
                        cfgs.append(ImageConfig.from_file(file))
        return ImageConfig.fold(cfgs)

    @classmethod
    def from_path(cls, data_root):
        # Create a visitor that maintains global config stack
        # and accumulates image configs as it traverses dataset
        image_configs = set()
        def process_dir(files, parent_globals):
            globals = parent_globals.merge(Dataset.__global_cfg(files))
            locals = Dataset.__local_cfg(files)
            for img in filter(is_image, files):
                img_cfg = Dataset.__image_cfg(img, files)
                collapsed_cfg = ImageConfig.fold([globals, locals, img_cfg])
                resolved_cfg = collapsed_cfg.ensure_caption()
                image_configs.add(resolved_cfg)
            return globals

        walk_and_visit(data_root, process_dir, ImageConfig())
        return Dataset(image_configs)

    @classmethod
    def from_json(cls, json_path):
        """
        Import a dataset definition from a JSON file
        """
        configs = set()
        with open(json_path, encoding='utf-8', mode='r') as stream:
            for data in json.load(stream):
                cfg = ImageConfig.load(data).ensure_caption()
                if not cfg or not cfg.image:
                    logging.warning(f" *** Error parsing json image entry in {json_path}: {data}")
                    continue
                configs.add(cfg)
        return Dataset(configs)
    def image_train_items(self, aspects):
        items = []
        for config in self.image_configs:
            caption = next(iter(sorted(config.captions)))
            if len(config.captions) > 1:
                logging.warning(f" *** Found multiple captions for image {config.image}, but only one will be applied: {config.captions}")

            tags = []
            tag_weights = []
            for tag in sorted(caption.tags):
                tags.append(tag.value)
                tag_weights.append(tag.weight or 1.0)
            use_weights = len(set(tag_weights)) > 1

            caption = ImageCaption(
                main_prompt=caption.main_prompt,
                rating=caption.rating,
                tags=tags,
                tag_weights=tag_weights,
                max_target_length=caption.max_caption_length,
                use_weights=use_weights)

            item = ImageTrainItem(
                image=None,
                caption=caption,
                aspects=aspects,
                pathname=os.path.abspath(config.image),
                flip_p=config.flip_p or 0.0,
                multiplier=config.multiply or 1.0,
                cond_dropout=config.cond_dropout
            )
            items.append(item)
        return list(sorted(items, key=lambda ti: ti.pathname))
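# Illustrative note (not part of the committed module): given the fold order in
# from_path above, per-image config overrides local.yaml / *.txt folder config, which
# overrides global.yaml, and a subfolder's global.yaml overrides its parent's for that
# subtree. For example, with these hypothetical files:
#
#   people/global.yaml          flip_p: 0.0
#   people/alice/local.yaml     multiply: 1.5
#   people/alice/alice_1.yaml   multiply: 2
#
# people/alice/alice_1.png resolves to multiply=2, flip_p=0.0.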
@@ -104,7 +104,7 @@ class EveryDreamBatch(Dataset):
         example["image"] = image_transforms(train_item["image"])

-        if random.random() > self.conditional_dropout:
+        if random.random() > (train_item.cond_dropout or self.conditional_dropout):
             example["tokens"] = self.tokenizer(example["caption"],
                                                truncation=True,
                                                padding="max_length",
data/image_train_item.py
@@ -113,136 +113,6 @@ class ImageCaption:
        random.Random(seed).shuffle(tags)
        return ", ".join(tags)

    @staticmethod
    def parse(string: str) -> 'ImageCaption':
        """
        Parses a string to get the caption.

        :param string: String to parse.
        :return: `ImageCaption` object.
        """
        split_caption = list(map(str.strip, string.split(",")))
        main_prompt = split_caption[0]
        tags = split_caption[1:]
        tag_weights = [1.0] * len(tags)

        return ImageCaption(main_prompt, 1.0, tags, tag_weights, DEFAULT_MAX_CAPTION_LENGTH, False)

    @staticmethod
    def from_file_name(file_path: str) -> 'ImageCaption':
        """
        Parses the file name to get the caption.

        :param file_path: Path to the image file.
        :return: `ImageCaption` object.
        """
        (file_name, _) = os.path.splitext(os.path.basename(file_path))
        caption = file_name.split("_")[0]
        return ImageCaption.parse(caption)

    @staticmethod
    def from_text_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
        """
        Parses a text file to get the caption. Returns the default caption if
        the file does not exist or is invalid.

        :param file_path: Path to the text file.
        :param default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid.
        :return: `ImageCaption` object or `None`.
        """
        try:
            with open(file_path, encoding='utf-8', mode='r') as caption_file:
                caption_text = caption_file.read()
                return ImageCaption.parse(caption_text)
        except:
            logging.error(f" *** Error reading {file_path} to get caption")
            return default_caption

    @staticmethod
    def from_yaml_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
        """
        Parses a yaml file to get the caption. Returns the default caption if
        the file does not exist or is invalid.

        :param file_path: path to the yaml file
        :param default_caption: caption to return if the file does not exist or is invalid
        :return: `ImageCaption` object or `None`.
        """
        try:
            with open(file_path, "r") as stream:
                file_content = yaml.safe_load(stream)
                main_prompt = file_content.get("main_prompt", "")
                rating = file_content.get("rating", 1.0)
                unparsed_tags = file_content.get("tags", [])

                max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH)

                tags = []
                tag_weights = []
                last_weight = None
                weights_differ = False
                for unparsed_tag in unparsed_tags:
                    tag = unparsed_tag.get("tag", "").strip()
                    if len(tag) == 0:
                        continue

                    tags.append(tag)
                    tag_weight = unparsed_tag.get("weight", 1.0)
                    tag_weights.append(tag_weight)

                    if last_weight is not None and weights_differ is False:
                        weights_differ = last_weight != tag_weight

                    last_weight = tag_weight

                return ImageCaption(main_prompt, rating, tags, tag_weights, max_caption_length, weights_differ)
        except:
            logging.error(f" *** Error reading {file_path} to get caption")
            return default_caption

    @staticmethod
    def from_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
        """
        Try to resolve a caption from a file path or return `default_caption`.

        :string: The path to the file to parse.
        :default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid.
        :return: `ImageCaption` object or `None`.
        """
        if os.path.exists(file_path):
            (file_path_without_ext, ext) = os.path.splitext(file_path)
            match ext:
                case ".yaml" | ".yml":
                    return ImageCaption.from_yaml_file(file_path, default_caption)

                case ".txt" | ".caption":
                    return ImageCaption.from_text_file(file_path, default_caption)

                case '.jpg' | '.jpeg' | '.png' | '.bmp' | '.webp' | '.jfif':
                    for ext in [".yaml", ".yml", ".txt", ".caption"]:
                        file_path = file_path_without_ext + ext
                        image_caption = ImageCaption.from_file(file_path)
                        if image_caption is not None:
                            return image_caption
                    return ImageCaption.from_file_name(file_path)

                case _:
                    return default_caption
        else:
            return default_caption

    @staticmethod
    def resolve(string: str) -> 'ImageCaption':
        """
        Try to resolve a caption from a string. If the string is a file path,
        the caption will be read from the file, otherwise the string will be
        parsed as a caption.

        :string: The string to resolve.
        :return: `ImageCaption` object.
        """
        return ImageCaption.from_file(string, None) or ImageCaption.parse(string)


class ImageTrainItem:
    """
@@ -253,7 +123,7 @@ class ImageTrainItem:
     flip_p: probability of flipping image (0.0 to 1.0)
     rating: the relative rating of the images. The rating is measured in comparison to the other images.
     """
-    def __init__(self, image: PIL.Image, caption: ImageCaption, aspects: list[float], pathname: str, flip_p=0.0, multiplier: float=1.0):
+    def __init__(self, image: PIL.Image, caption: ImageCaption, aspects: list[float], pathname: str, flip_p=0.0, multiplier: float=1.0, cond_dropout=None):
         self.caption = caption
         self.aspects = aspects
         self.pathname = pathname
@@ -261,6 +131,7 @@ class ImageTrainItem:
         self.cropped_img = None
         self.runt_size = 0
         self.multiplier = multiplier
+        self.cond_dropout = cond_dropout

         self.image_size = None
         if image is None or len(image) == 0:
data/resolver.py
@@ -4,6 +4,7 @@ import os
 import typing
 import zipfile
 import argparse
+from data.dataset import Dataset

 import tqdm
 from colorama import Fore, Style
@@ -27,16 +28,6 @@ class DataResolver:
        """
        raise NotImplementedError()

    def image_train_item(self, image_path: str, caption: ImageCaption, multiplier: float=1) -> ImageTrainItem:
        return ImageTrainItem(
            image=None,
            caption=caption,
            aspects=self.aspects,
            pathname=image_path,
            flip_p=self.flip_p,
            multiplier=multiplier
        )

class JSONResolver(DataResolver):
    def image_train_items(self, json_path: str) -> list[ImageTrainItem]:
        """
@@ -45,62 +36,8 @@ class JSONResolver(DataResolver):

        :param json_path: The path to the JSON file.
        """
        items = []
        with open(json_path, encoding='utf-8', mode='r') as f:
            json_data = json.load(f)

        for data in tqdm.tqdm(json_data):
            caption = JSONResolver.image_caption(data)
            if caption:
                image_value = JSONResolver.get_image_value(data)
                item = self.image_train_item(image_value, caption)
                if item:
                    items.append(item)

        return items
        return Dataset.from_json(json_path).image_train_items(self.aspects)

    @staticmethod
    def get_image_value(json_data: dict) -> typing.Optional[str]:
        """
        Get the image from the json data if possible.

        :param json_data: The json data, a dict.
        :return: The image, or None if not found.
        """
        image_value = json_data.get("image", None)
        if isinstance(image_value, str):
            image_value = image_value.strip()
            if os.path.exists(image_value):
                return image_value

    @staticmethod
    def get_caption_value(json_data: dict) -> typing.Optional[str]:
        """
        Get the caption from the json data if possible.

        :param json_data: The json data, a dict.
        :return: The caption, or None if not found.
        """
        caption_value = json_data.get("caption", None)
        if isinstance(caption_value, str):
            return caption_value.strip()

    @staticmethod
    def image_caption(json_data: dict) -> typing.Optional[ImageCaption]:
        """
        Get the caption from the json data if possible.

        :param json_data: The json data, a dict.
        :return: The `ImageCaption`, or None if not found.
        """
        image_value = JSONResolver.get_image_value(json_data)
        caption_value = JSONResolver.get_caption_value(json_data)
        if image_value:
            if caption_value:
                return ImageCaption.resolve(caption_value)
            return ImageCaption.from_file(image_value)


class DirectoryResolver(DataResolver):
    def image_train_items(self, data_root: str) -> list[ImageTrainItem]:
        """
@@ -111,32 +48,7 @@ class DirectoryResolver(DataResolver):
        :param data_root: The root directory to recurse through
        """
        DirectoryResolver.unzip_all(data_root)
        image_paths = list(DirectoryResolver.recurse_data_root(data_root))
        items = []
        multipliers = {}

        for pathname in tqdm.tqdm(image_paths):
            current_dir = os.path.dirname(pathname)

            if current_dir not in multipliers:
                multiply_txt_path = os.path.join(current_dir, "multiply.txt")
                if os.path.exists(multiply_txt_path):
                    try:
                        with open(multiply_txt_path, 'r') as f:
                            val = float(f.read().strip())
                            multipliers[current_dir] = val
                            logging.info(f" - multiply.txt in '{current_dir}' set to {val}")
                    except Exception as e:
                        logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
                        multipliers[current_dir] = 1.0
                else:
                    multipliers[current_dir] = 1.0

            caption = ImageCaption.resolve(pathname)
            item = self.image_train_item(pathname, caption, multiplier=multipliers[current_dir])
            items.append(item)

        return items
        return Dataset.from_path(data_root).image_train_items(self.aspects)

    @staticmethod
    def unzip_all(path):
@@ -150,21 +62,6 @@ class DirectoryResolver(DataResolver):
        except Exception as e:
            logging.error(f"Error unzipping files {e}")

    @staticmethod
    def recurse_data_root(recurse_root):
        for f in os.listdir(recurse_root):
            current = os.path.join(recurse_root, f)

            if os.path.isfile(current):
                ext = os.path.splitext(f)[1].lower()
                if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']:
                    yield current

        for d in os.listdir(recurse_root):
            current = os.path.join(recurse_root, d)
            if os.path.isdir(current):
                yield from DirectoryResolver.recurse_data_root(current)

def strategy(data_root: str) -> typing.Type[DataResolver]:
    """
    Determine the strategy to use for resolving the data.
@@ -10,6 +10,7 @@ ninja
 omegaconf==2.2.3
 piexif==1.1.3
 protobuf==3.20.3
+pyfakefs
 pynvml==11.5.0
 pyre-extensions==0.0.30
 pytorch-lightning==1.9.2
@@ -58,13 +58,13 @@ class TestResolve(unittest.TestCase):

     def test_directory_resolve_with_str(self):
         items = resolver.resolve(DATA_PATH, ARGS)
-        image_paths = [item.pathname for item in items]
+        image_paths = set(item.pathname for item in items)
         image_captions = [item.caption for item in items]
-        captions = [caption.get_caption() for caption in image_captions]
+        captions = set(caption.get_caption() for caption in image_captions)

         self.assertEqual(len(items), 3)
-        self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
-        self.assertEqual(captions, ['caption for test1', 'test2', 'test3'])
+        self.assertEqual(image_paths, {IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH})
+        self.assertEqual(captions, {'caption for test1', 'test2', 'test3'})

         undersized_images = list(filter(lambda i: i.is_undersized, items))
         self.assertEqual(len(undersized_images), 1)
@@ -0,0 +1,329 @@
import os
from data.dataset import Dataset, ImageConfig, Caption, Tag

from textwrap import dedent
from pyfakefs.fake_filesystem_unittest import TestCase

class TestResolve(TestCase):
    def setUp(self):
        self.setUpPyfakefs()

    def test_simple_image(self):
        self.fs.create_file("image, tag1, tag2.jpg")

        actual = Dataset.from_path(".").image_configs

        expected = {
            ImageConfig(
                image="./image, tag1, tag2.jpg",
                captions=frozenset([
                    Caption(main_prompt="image", tags=frozenset([Tag("tag1"), Tag("tag2")]))
                ]))
        }
        self.assertEqual(expected, actual)

    def test_image_types(self):
        self.fs.create_file("image_1.JPG")
        self.fs.create_file("image_2.jpeg")
        self.fs.create_file("image_3.png")
        self.fs.create_file("image_4.webp")
        self.fs.create_file("image_5.jfif")
        self.fs.create_file("image_6.bmp")

        actual = Dataset.from_path(".").image_configs

        captions = frozenset([Caption(main_prompt="image")])
        expected = {
            ImageConfig(image="./image_1.JPG", captions=captions),
            ImageConfig(image="./image_2.jpeg", captions=captions),
            ImageConfig(image="./image_3.png", captions=captions),
            ImageConfig(image="./image_4.webp", captions=captions),
            ImageConfig(image="./image_5.jfif", captions=captions),
            ImageConfig(image="./image_6.bmp", captions=captions),
        }
        self.assertEqual(expected, actual)

    def test_caption_file(self):
        self.fs.create_file("image_1.jpg")
        self.fs.create_file("image_1.txt", contents="an image, test, from .txt")
        self.fs.create_file("image_2.jpg")
        self.fs.create_file("image_2.caption", contents="an image, test, from .caption")

        actual = Dataset.from_path(".").image_configs

        expected = {
            ImageConfig(
                image="./image_1.jpg",
                captions=frozenset([
                    Caption(main_prompt="an image", tags=frozenset([Tag("test"), Tag("from .txt")]))
                ])),
            ImageConfig(
                image="./image_2.jpg",
                captions=frozenset([
                    Caption(main_prompt="an image", tags=frozenset([Tag("test"), Tag("from .caption")]))
                ]))
        }
        self.assertEqual(expected, actual)
    def test_image_yaml(self):
        self.fs.create_file("image_1.jpg")
        self.fs.create_file("image_1.yaml",
                            contents=dedent("""
                                multiply: 2
                                cond_dropout: 0.05
                                flip_p: 0.5
                                caption: "A simple caption, from .yaml"
                                """))
        self.fs.create_file("image_2.jpg")
        self.fs.create_file("image_2.yml",
                            contents=dedent("""
                                flip_p: 0.0
                                caption:
                                    main_prompt: A complex caption
                                    rating: 1.1
                                    max_caption_length: 1024
                                    tags:
                                      - tag: from .yml
                                      - tag: with weight
                                        weight: 0.5
                                """))

        actual = Dataset.from_path(".").image_configs

        expected = {
            ImageConfig(
                image="./image_1.jpg",
                multiply=2,
                cond_dropout=0.05,
                flip_p=0.5,
                captions=frozenset([
                    Caption(main_prompt="A simple caption", tags=frozenset([Tag("from .yaml")]))
                ])),
            ImageConfig(
                image="./image_2.jpg",
                flip_p=0.0,
                captions=frozenset([
                    Caption(main_prompt="A complex caption", rating=1.1,
                            max_caption_length=1024,
                            tags=frozenset([
                                Tag("from .yml"),
                                Tag("with weight", weight=0.5)
                            ]))
                ]))
        }
        self.assertEqual(expected, actual)


    def test_multi_caption(self):
        self.fs.create_file("image_1.jpg")
        self.fs.create_file("image_1.yaml", contents=dedent("""
                                caption: "A simple caption, from .yaml"
                                captions:
                                    - "Another simple caption"
                                    - main_prompt: A complex caption
                                """))
        self.fs.create_file("image_1.txt", contents="A .txt caption")
        self.fs.create_file("image_1.caption", contents="A .caption caption")

        actual = Dataset.from_path(".").image_configs

        expected = {
            ImageConfig(
                image="./image_1.jpg",
                captions=frozenset([
                    Caption(main_prompt="A simple caption", tags=frozenset([Tag("from .yaml")])),
                    Caption(main_prompt="Another simple caption", tags=frozenset()),
                    Caption(main_prompt="A complex caption", tags=frozenset()),
                    Caption(main_prompt="A .txt caption", tags=frozenset()),
                    Caption(main_prompt="A .caption caption", tags=frozenset())
                ])
            ),
        }
        self.assertEqual(expected, actual)
    def test_globals_and_locals(self):
        self.fs.create_file("./people/global.yaml", contents=dedent("""\
            multiply: 1.0
            cond_dropout: 0.0
            flip_p: 0.0
            """))
        self.fs.create_file("./people/alice/local.yaml", contents="multiply: 1.5")
        self.fs.create_file("./people/alice/alice_1.png")
        self.fs.create_file("./people/alice/alice_1.yaml", contents="multiply: 2")
        self.fs.create_file("./people/alice/alice_2.png")

        self.fs.create_file("./people/bob/multiply.txt", contents="3")
        self.fs.create_file("./people/bob/cond_dropout.txt", contents="0.05")
        self.fs.create_file("./people/bob/flip_p.txt", contents="0.05")
        self.fs.create_file("./people/bob/bob.png")

        self.fs.create_file("./people/cleo/cleo.png")
        self.fs.create_file("./people/dan.png")

        self.fs.create_file("./other/dog/local.yaml", contents="caption: spike")
        self.fs.create_file("./other/dog/xyz.png")

        actual = Dataset.from_path(".").image_configs

        expected = {
            ImageConfig(
                image="./people/alice/alice_1.png",
                captions=frozenset([Caption(main_prompt="alice")]),
                multiply=2,
                cond_dropout=0.0,
                flip_p=0.0
            ),
            ImageConfig(
                image="./people/alice/alice_2.png",
                captions=frozenset([Caption(main_prompt="alice")]),
                multiply=1.5,
                cond_dropout=0.0,
                flip_p=0.0
            ),
            ImageConfig(
                image="./people/bob/bob.png",
                captions=frozenset([Caption(main_prompt="bob")]),
                multiply=3,
                cond_dropout=0.05,
                flip_p=0.05
            ),
            ImageConfig(
                image="./people/cleo/cleo.png",
                captions=frozenset([Caption(main_prompt="cleo")]),
                multiply=1.0,
                cond_dropout=0.0,
                flip_p=0.0
            ),
            ImageConfig(
                image="./people/dan.png",
                captions=frozenset([Caption(main_prompt="dan")]),
                multiply=1.0,
                cond_dropout=0.0,
                flip_p=0.0
            ),
            ImageConfig(
                image="./other/dog/xyz.png",
                captions=frozenset([Caption(main_prompt="spike")]),
                multiply=None,
                cond_dropout=None,
                flip_p=None
            )
        }
        self.assertEqual(expected, actual)
    def test_json_manifest(self):
        self.fs.create_file("./stuff/image_1.jpg")
        self.fs.create_file("./stuff/default.caption", contents="default caption")
        self.fs.create_file("./other/image_1.jpg")
        self.fs.create_file("./other/image_2.jpg")
        self.fs.create_file("./other/image_3.jpg")
        self.fs.create_file("./manifest.json", contents=dedent("""
            [
                { "image": "./stuff/image_1.jpg", "caption": "./stuff/default.caption" },
                { "image": "./other/image_1.jpg", "caption": "other caption" },
                {
                    "image": "./other/image_2.jpg",
                    "caption": {
                        "main_prompt": "complex caption",
                        "rating": 0.1,
                        "max_caption_length": 1000,
                        "tags": [
                            {"tag": "including"},
                            {"tag": "weighted tag", "weight": 999.9}
                        ]
                    }
                },
                {
                    "image": "./other/image_3.jpg",
                    "multiply": 2,
                    "flip_p": 0.5,
                    "cond_dropout": 0.01,
                    "captions": [
                        "first caption",
                        { "main_prompt": "second caption" }
                    ]
                }
            ]
            """))

        actual = Dataset.from_json("./manifest.json").image_configs
        expected = {
            ImageConfig(
                image="./stuff/image_1.jpg",
                captions=frozenset([Caption(main_prompt="default caption")])
            ),
            ImageConfig(
                image="./other/image_1.jpg",
                captions=frozenset([Caption(main_prompt="other caption")])
            ),
            ImageConfig(
                image="./other/image_2.jpg",
                captions=frozenset([
                    Caption(
                        main_prompt="complex caption",
                        rating=0.1,
                        max_caption_length=1000,
                        tags=frozenset([
                            Tag("including"),
                            Tag("weighted tag", 999.9)
                        ]))
                ])
            ),
            ImageConfig(
                image="./other/image_3.jpg",
                multiply=2,
                flip_p=0.5,
                cond_dropout=0.01,
                captions=frozenset([
                    Caption("first caption"),
                    Caption("second caption")
                ])
            )
        }
        self.assertEqual(expected, actual)
    def test_train_items(self):
        dataset = Dataset([
            ImageConfig(
                image="1.jpg",
                multiply=2,
                flip_p=0.1,
                cond_dropout=0.01,
                captions=frozenset([
                    Caption(
                        main_prompt="first caption",
                        rating=1.1,
                        max_caption_length=1024,
                        tags=frozenset([
                            Tag("tag"),
                            Tag("tag_2", 2.0)
                        ])),
                    Caption(main_prompt="second_caption")
                ])),
            ImageConfig(
                image="2.jpg",
                captions=frozenset([Caption(main_prompt="single caption")])
            )
        ])

        aspects = []
        actual = dataset.image_train_items(aspects)

        self.assertEqual(len(actual), 2)

        self.assertEqual(actual[0].pathname, os.path.abspath('1.jpg'))
        self.assertEqual(actual[0].multiplier, 2.0)
        self.assertEqual(actual[0].flip.p, 0.1)
        self.assertEqual(actual[0].cond_dropout, 0.01)
        self.assertEqual(actual[0].caption.get_caption(), "first caption, tag, tag_2")
        # Can't test this
        # self.assertTrue(actual[0].caption.__use_weights)

        self.assertEqual(actual[1].pathname, os.path.abspath('2.jpg'))
        self.assertEqual(actual[1].multiplier, 1.0)
        self.assertEqual(actual[1].flip.p, 0.0)
        self.assertIsNone(actual[1].cond_dropout)
        self.assertEqual(actual[1].caption.get_caption(), "single caption")
        # Can't test this
        # self.assertFalse(actual[1].caption.__use_weights)
@@ -32,40 +32,4 @@ class TestImageCaption(unittest.TestCase):
        self.assertEqual(caption.get_caption(), "hello world, one, two, three")

        caption = ImageCaption("hello world", 1.0, [], [], 2048, False)
        self.assertEqual(caption.get_caption(), "hello world")

    def test_parse(self):
        caption = ImageCaption.parse("hello world, one, two, three")

        self.assertEqual(caption.get_caption(), "hello world, one, two, three")

    def test_from_file_name(self):
        caption = ImageCaption.from_file_name("foo bar_1_2_3.jpg")
        self.assertEqual(caption.get_caption(), "foo bar")

    def test_from_text_file(self):
        caption = ImageCaption.from_text_file("test/data/test1.txt")
        self.assertEqual(caption.get_caption(), "caption for test1")

    def test_from_file(self):
        caption = ImageCaption.from_file("test/data/test1.txt")
        self.assertEqual(caption.get_caption(), "caption for test1")

        caption = ImageCaption.from_file("test/data/test_caption.caption")
        self.assertEqual(caption.get_caption(), "caption for test2")

    def test_resolve(self):
        caption = ImageCaption.resolve("test/data/test1.txt")
        self.assertEqual(caption.get_caption(), "caption for test1")

        caption = ImageCaption.resolve("test/data/test_caption.caption")
        self.assertEqual(caption.get_caption(), "caption for test2")

        caption = ImageCaption.resolve("hello world")
        self.assertEqual(caption.get_caption(), "hello world")

        caption = ImageCaption.resolve("test/data/test1.jpg")
        self.assertEqual(caption.get_caption(), "caption for test1")

        caption = ImageCaption.resolve("test/data/test2.jpg")
        self.assertEqual(caption.get_caption(), "test2")
        self.assertEqual(caption.get_caption(), "hello world")
utils/fs_helpers.py
@@ -0,0 +1,46 @@
import os
import logging


def barename(file):
    (val, _) = os.path.splitext(os.path.basename(file))
    return val

def ext(file):
    (_, val) = os.path.splitext(os.path.basename(file))
    return val.lower()

def same_barename(lhs, rhs):
    return barename(lhs) == barename(rhs)

def is_image(file):
    return ext(file) in {'.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif'}

def read_text(file):
    try:
        with open(file, encoding='utf-8', mode='r') as stream:
            return stream.read().strip()
    except Exception as e:
        logging.warning(f" *** Error reading text file {file}: {e}")

def read_float(file):
    # Read the file once so the error message can report what was actually read
    text = read_text(file)
    try:
        return float(text)
    except Exception as e:
        logging.warning(f" *** Could not parse '{text}' to float in file {file}: {e}")

def walk_and_visit(path, visit_fn, context=None):
    names = [entry.name for entry in os.scandir(path)]

    dirs = []
    files = []
    for name in names:
        fullname = os.path.join(path, name)
        if os.path.isdir(fullname) and not str(name).startswith('.'):
            dirs.append(fullname)
        else:
            files.append(fullname)

    subcontext = visit_fn(files, context)

    for subdir in dirs:
        walk_and_visit(subdir, visit_fn, subcontext)
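# Illustrative example (not part of the committed module): walk_and_visit calls visit_fn
# once per directory with that directory's files and the parent directory's context, and
# whatever visit_fn returns becomes the context for its subdirectories. The path below is
# hypothetical.
#
#   def count_depth(files, depth):
#       print(depth, len(files))
#       return depth + 1
#
#   walk_and_visit("./my_dataset", count_depth, 0)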