Add support for enhanced dataset configuration

Add support for:
* flip_p.txt
* cond_dropout.txt
* local.yaml consolidated config (including default captions)
* global.yaml consolidated config which applies recursively to subfolders
* flip_p and cond_dropout config per image
* manifest.json with full image-level configuration
This commit is contained in:
Augusto de la Torre 2023-03-08 15:02:14 +01:00
parent 2e3d044ba3
commit 0716c40ab6
9 changed files with 648 additions and 279 deletions

261
data/dataset.py Normal file
View File

@ -0,0 +1,261 @@
import os
import logging
import yaml
import json
from functools import total_ordering
from attrs import define, Factory
from data.image_train_item import ImageCaption, ImageTrainItem
from utils.fs_helpers import *
@define(frozen=True)
@total_ordering
class Tag:
    """A caption tag: its text plus an optional weight.

    The class is frozen so instances can be stored in a ``frozenset``.
    ``total_ordering`` derives the remaining comparison operators from
    ``__lt__``.
    """
    value: str
    weight: float = None  # stays None when no weight was supplied

    def __lt__(self, other):
        """Order tags by their text only; the weight plays no part."""
        return self.value < other.value
@define(frozen=True)
@total_ordering
class Caption:
    """Immutable caption: a main prompt plus an unordered set of tags.

    ``rating`` and ``max_caption_length`` are optional and stay ``None`` when
    the source did not provide them.
    """
    main_prompt: str = None
    rating: float = None
    max_caption_length: int = None
    tags: frozenset[Tag] = Factory(frozenset)

    @classmethod
    def from_dict(cls, data: dict):
        """Build a caption from a mapping.

        Reads the keys ``main_prompt``, ``rating``, ``max_caption_length`` and
        ``tags`` (a list of mappings with ``tag`` and optional ``weight``).

        :return: the caption, or ``None`` when the mapping holds neither a
                 main prompt nor any usable tag.
        """
        tags = frozenset(
            Tag(value=entry.get("tag"), weight=entry.get("weight"))
            # ``or []`` tolerates an explicit ``tags: null``.
            for entry in (data.get("tags") or [])
            # Truthiness skips missing, null and empty tag texts alike.
            if entry.get("tag")
        )
        main_prompt = data.get("main_prompt")
        if not main_prompt and not tags:
            return None
        return cls(
            main_prompt=main_prompt,
            rating=data.get("rating"),
            max_caption_length=data.get("max_caption_length"),
            tags=tags,
        )

    @classmethod
    def from_text(cls, text: str):
        """Parse comma-separated caption text.

        The part before the first comma is the main prompt; every following
        part becomes an unweighted tag. Surrounding whitespace is stripped
        from each part. ``None`` yields a caption with an empty main prompt.
        """
        if text is None:
            return cls(main_prompt="")
        main_prompt, *tag_texts = (part.strip() for part in text.split(","))
        return cls(
            main_prompt=main_prompt,
            tags=frozenset(Tag(value=tag_text) for tag_text in tag_texts),
        )

    @classmethod
    def load(cls, input):
        """Load a caption from a string or a mapping.

        A string naming an existing file is read and parsed as caption text;
        any other string is parsed directly. A dict goes through
        :meth:`from_dict`. Any other type yields ``None``.
        """
        if isinstance(input, str):
            text = read_text(input) if os.path.isfile(input) else input
            return cls.from_text(text)
        if isinstance(input, dict):
            return cls.from_dict(input)
        return None

    def _sort_key(self):
        """Return the string used for ordering: main prompt, then sorted tag reprs."""
        # A caption built from tags only has main_prompt=None; fall back to ""
        # so that joining (and therefore sorting captions) does not raise.
        return ",".join([self.main_prompt or ""] + sorted(repr(t) for t in self.tags))

    def __lt__(self, other):
        """Order captions by their main prompt followed by their sorted tags."""
        return self._sort_key() < other._sort_key()
@define(frozen=True)
class ImageConfig:
image: str = None
captions: frozenset[Caption] = Factory(frozenset)
multiply: float = None
cond_dropout: float = None
flip_p: float = None
@classmethod
def fold(cls, configs):
acc = ImageConfig()
[acc := acc.merge(cfg) for cfg in configs]
return acc
def merge(self, other):
if other is None:
return self
if other.image and self.image:
logging(f"Found two images with different extensions and the same barename: {self.image} and {other.image}")
return ImageConfig(
image = other.image or self.image,
captions = other.captions.union(self.captions),
multiply = other.multiply if other.multiply is not None else self.multiply,
cond_dropout = other.cond_dropout if other.cond_dropout is not None else self.cond_dropout,
flip_p = other.flip_p if other.flip_p is not None else self.flip_p
)
def ensure_caption(self):
if not self.captions:
filename_caption = Caption.from_text(barename(self.image).split("_")[0])
return self.merge(ImageConfig(captions=frozenset([filename_caption])))
return self
@classmethod
def from_dict(cls, data: dict):
captions = set()
if "captions" in data:
captions.update(Caption.load(cap) for cap in data.get("captions"))
if "caption" in data:
captions.add(Caption.load(data.get("caption")))
if not captions:
# For backward compatibility with existing caption yaml
caption = Caption.load(data)
if caption:
captions.add(caption)
return ImageConfig(
image = data.get("image"),
captions=frozenset(captions),
multiply=data.get("multiply"),
cond_dropout=data.get("cond_dropout"),
flip_p=data.get("flip_p"))
@classmethod
def from_text(cls, text: str):
try:
if os.path.isfile(text):
return ImageConfig.from_file(text)
return ImageConfig(captions=frozenset({Caption.from_text(text)}))
except Exception as e:
logging.warning(f" *** Error parsing config from text {text}: \n{e}")
@classmethod
def from_file(cls, file: str):
try:
match ext(file):
case '.jpg' | '.jpeg' | '.png' | '.bmp' | '.webp' | '.jfif':
return ImageConfig(image=file)
case ".json":
return ImageConfig.from_dict(json.load(read_text(file)))
case ".yaml" | ".yml":
return ImageConfig.from_dict(yaml.safe_load(read_text(file)))
case ".txt" | ".caption":
return ImageConfig.from_text(read_text(file))
case _:
return logging.warning(" *** Unrecognized config extension {ext}")
except Exception as e:
logging.warning(f" *** Error parsing config from {file}: {e}")
@classmethod
def load(cls, input):
if isinstance(input, str):
return ImageConfig.from_text(input)
elif isinstance(input, dict):
return ImageConfig.from_dict(input)
@define()
class Dataset:
image_configs: set[ImageConfig]
def __global_cfg(files):
cfgs = []
for file in files:
match os.path.basename(file):
case 'global.yaml' | 'global.yml':
cfgs.append(ImageConfig.from_file(file))
return ImageConfig.fold(cfgs)
def __local_cfg(files):
cfgs = []
for file in files:
match os.path.basename(file):
case 'multiply.txt':
cfgs.append(ImageConfig(multiply=read_float(file)))
case 'cond_dropout.txt':
cfgs.append(ImageConfig(cond_dropout=read_float(file)))
case 'flip_p.txt':
cfgs.append(ImageConfig(flip_p=read_float(file)))
case 'local.yaml' | 'local.yml':
cfgs.append(ImageConfig.from_file(file))
return ImageConfig.fold(cfgs)
def __image_cfg(imagepath, files):
cfgs = [ImageConfig.from_file(imagepath)]
for file in files:
if same_barename(imagepath, file):
match ext(file):
case '.txt' | '.caption' | '.yml' | '.yaml':
cfgs.append(ImageConfig.from_file(file))
return ImageConfig.fold(cfgs)
@classmethod
def from_path(cls, data_root):
# Create a visitor that maintains global config stack
# and accumulates image configs as it traverses dataset
image_configs = set()
def process_dir(files, parent_globals):
globals = parent_globals.merge(Dataset.__global_cfg(files))
locals = Dataset.__local_cfg(files)
for img in filter(is_image, files):
img_cfg = Dataset.__image_cfg(img, files)
collapsed_cfg = ImageConfig.fold([globals, locals, img_cfg])
resolved_cfg = collapsed_cfg.ensure_caption()
image_configs.add(resolved_cfg)
return globals
walk_and_visit(data_root, process_dir, ImageConfig())
return Dataset(image_configs)
@classmethod
def from_json(cls, json_path):
"""
Import a dataset definition from a JSON file
"""
configs = set()
with open(json_path, encoding='utf-8', mode='r') as stream:
for data in json.load(stream):
cfg = ImageConfig.load(data).ensure_caption()
if not cfg or not cfg.image:
logging.warning(f" *** Error parsing json image entry in {json_path}: {data}")
continue
configs.add(cfg)
return Dataset(configs)
def image_train_items(self, aspects):
items = []
for config in self.image_configs:
caption = next(iter(sorted(config.captions)))
if len(config.captions) > 1:
logging.warning(f" *** Found multiple captions for image {config.image}, but only one will be applied: {config.captions}")
use_weights = len(set(t.weight or 1.0 for t in caption.tags)) > 1
tags = []
tag_weights = []
for tag in sorted(caption.tags):
tags.append(tag.value)
tag_weights.append(tag.weight or 1.0)
use_weights = len(set(tag_weights)) > 1
caption = ImageCaption(
main_prompt=caption.main_prompt,
rating=caption.rating,
tags=tags,
tag_weights=tag_weights,
max_target_length=caption.max_caption_length,
use_weights=use_weights)
item = ImageTrainItem(
image=None,
caption=caption,
aspects=aspects,
pathname=os.path.abspath(config.image),
flip_p=config.flip_p or 0.0,
multiplier=config.multiply or 1.0,
cond_dropout=config.cond_dropout
)
items.append(item)
return list(sorted(items, key=lambda ti: ti.pathname))

View File

@ -104,7 +104,7 @@ class EveryDreamBatch(Dataset):
example["image"] = image_transforms(train_item["image"]) example["image"] = image_transforms(train_item["image"])
if random.random() > self.conditional_dropout: if random.random() > (train_item.cond_dropout or self.conditional_dropout):
example["tokens"] = self.tokenizer(example["caption"], example["tokens"] = self.tokenizer(example["caption"],
truncation=True, truncation=True,
padding="max_length", padding="max_length",

View File

@ -113,136 +113,6 @@ class ImageCaption:
random.Random(seed).shuffle(tags) random.Random(seed).shuffle(tags)
return ", ".join(tags) return ", ".join(tags)
@staticmethod
def parse(string: str) -> 'ImageCaption':
"""
Parses a string to get the caption.
:param string: String to parse.
:return: `ImageCaption` object.
"""
split_caption = list(map(str.strip, string.split(",")))
main_prompt = split_caption[0]
tags = split_caption[1:]
tag_weights = [1.0] * len(tags)
return ImageCaption(main_prompt, 1.0, tags, tag_weights, DEFAULT_MAX_CAPTION_LENGTH, False)
@staticmethod
def from_file_name(file_path: str) -> 'ImageCaption':
"""
Parses the file name to get the caption.
:param file_path: Path to the image file.
:return: `ImageCaption` object.
"""
(file_name, _) = os.path.splitext(os.path.basename(file_path))
caption = file_name.split("_")[0]
return ImageCaption.parse(caption)
@staticmethod
def from_text_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
"""
Parses a text file to get the caption. Returns the default caption if
the file does not exist or is invalid.
:param file_path: Path to the text file.
:param default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid.
:return: `ImageCaption` object or `None`.
"""
try:
with open(file_path, encoding='utf-8', mode='r') as caption_file:
caption_text = caption_file.read()
return ImageCaption.parse(caption_text)
except:
logging.error(f" *** Error reading {file_path} to get caption")
return default_caption
@staticmethod
def from_yaml_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
"""
Parses a yaml file to get the caption. Returns the default caption if
the file does not exist or is invalid.
:param file_path: path to the yaml file
:param default_caption: caption to return if the file does not exist or is invalid
:return: `ImageCaption` object or `None`.
"""
try:
with open(file_path, "r") as stream:
file_content = yaml.safe_load(stream)
main_prompt = file_content.get("main_prompt", "")
rating = file_content.get("rating", 1.0)
unparsed_tags = file_content.get("tags", [])
max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH)
tags = []
tag_weights = []
last_weight = None
weights_differ = False
for unparsed_tag in unparsed_tags:
tag = unparsed_tag.get("tag", "").strip()
if len(tag) == 0:
continue
tags.append(tag)
tag_weight = unparsed_tag.get("weight", 1.0)
tag_weights.append(tag_weight)
if last_weight is not None and weights_differ is False:
weights_differ = last_weight != tag_weight
last_weight = tag_weight
return ImageCaption(main_prompt, rating, tags, tag_weights, max_caption_length, weights_differ)
except:
logging.error(f" *** Error reading {file_path} to get caption")
return default_caption
@staticmethod
def from_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
"""
Try to resolve a caption from a file path or return `default_caption`.
:string: The path to the file to parse.
:default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid.
:return: `ImageCaption` object or `None`.
"""
if os.path.exists(file_path):
(file_path_without_ext, ext) = os.path.splitext(file_path)
match ext:
case ".yaml" | ".yml":
return ImageCaption.from_yaml_file(file_path, default_caption)
case ".txt" | ".caption":
return ImageCaption.from_text_file(file_path, default_caption)
case '.jpg'| '.jpeg'| '.png'| '.bmp'| '.webp'| '.jfif':
for ext in [".yaml", ".yml", ".txt", ".caption"]:
file_path = file_path_without_ext + ext
image_caption = ImageCaption.from_file(file_path)
if image_caption is not None:
return image_caption
return ImageCaption.from_file_name(file_path)
case _:
return default_caption
else:
return default_caption
@staticmethod
def resolve(string: str) -> 'ImageCaption':
"""
Try to resolve a caption from a string. If the string is a file path,
the caption will be read from the file, otherwise the string will be
parsed as a caption.
:string: The string to resolve.
:return: `ImageCaption` object.
"""
return ImageCaption.from_file(string, None) or ImageCaption.parse(string)
class ImageTrainItem: class ImageTrainItem:
""" """
@ -253,7 +123,7 @@ class ImageTrainItem:
flip_p: probability of flipping image (0.0 to 1.0) flip_p: probability of flipping image (0.0 to 1.0)
rating: the relative rating of the images. The rating is measured in comparison to the other images. rating: the relative rating of the images. The rating is measured in comparison to the other images.
""" """
def __init__(self, image: PIL.Image, caption: ImageCaption, aspects: list[float], pathname: str, flip_p=0.0, multiplier: float=1.0): def __init__(self, image: PIL.Image, caption: ImageCaption, aspects: list[float], pathname: str, flip_p=0.0, multiplier: float=1.0, cond_dropout=None):
self.caption = caption self.caption = caption
self.aspects = aspects self.aspects = aspects
self.pathname = pathname self.pathname = pathname
@ -261,6 +131,7 @@ class ImageTrainItem:
self.cropped_img = None self.cropped_img = None
self.runt_size = 0 self.runt_size = 0
self.multiplier = multiplier self.multiplier = multiplier
self.cond_dropout = cond_dropout
self.image_size = None self.image_size = None
if image is None or len(image) == 0: if image is None or len(image) == 0:

View File

@ -4,6 +4,7 @@ import os
import typing import typing
import zipfile import zipfile
import argparse import argparse
from data.dataset import Dataset
import tqdm import tqdm
from colorama import Fore, Style from colorama import Fore, Style
@ -27,16 +28,6 @@ class DataResolver:
""" """
raise NotImplementedError() raise NotImplementedError()
def image_train_item(self, image_path: str, caption: ImageCaption, multiplier: float=1) -> ImageTrainItem:
return ImageTrainItem(
image=None,
caption=caption,
aspects=self.aspects,
pathname=image_path,
flip_p=self.flip_p,
multiplier=multiplier
)
class JSONResolver(DataResolver): class JSONResolver(DataResolver):
def image_train_items(self, json_path: str) -> list[ImageTrainItem]: def image_train_items(self, json_path: str) -> list[ImageTrainItem]:
""" """
@ -45,61 +36,7 @@ class JSONResolver(DataResolver):
:param json_path: The path to the JSON file. :param json_path: The path to the JSON file.
""" """
items = [] return Dataset.from_json(json_path).image_train_items(self.aspects)
with open(json_path, encoding='utf-8', mode='r') as f:
json_data = json.load(f)
for data in tqdm.tqdm(json_data):
caption = JSONResolver.image_caption(data)
if caption:
image_value = JSONResolver.get_image_value(data)
item = self.image_train_item(image_value, caption)
if item:
items.append(item)
return items
@staticmethod
def get_image_value(json_data: dict) -> typing.Optional[str]:
"""
Get the image from the json data if possible.
:param json_data: The json data, a dict.
:return: The image, or None if not found.
"""
image_value = json_data.get("image", None)
if isinstance(image_value, str):
image_value = image_value.strip()
if os.path.exists(image_value):
return image_value
@staticmethod
def get_caption_value(json_data: dict) -> typing.Optional[str]:
"""
Get the caption from the json data if possible.
:param json_data: The json data, a dict.
:return: The caption, or None if not found.
"""
caption_value = json_data.get("caption", None)
if isinstance(caption_value, str):
return caption_value.strip()
@staticmethod
def image_caption(json_data: dict) -> typing.Optional[ImageCaption]:
"""
Get the caption from the json data if possible.
:param json_data: The json data, a dict.
:return: The `ImageCaption`, or None if not found.
"""
image_value = JSONResolver.get_image_value(json_data)
caption_value = JSONResolver.get_caption_value(json_data)
if image_value:
if caption_value:
return ImageCaption.resolve(caption_value)
return ImageCaption.from_file(image_value)
class DirectoryResolver(DataResolver): class DirectoryResolver(DataResolver):
def image_train_items(self, data_root: str) -> list[ImageTrainItem]: def image_train_items(self, data_root: str) -> list[ImageTrainItem]:
@ -111,32 +48,7 @@ class DirectoryResolver(DataResolver):
:param data_root: The root directory to recurse through :param data_root: The root directory to recurse through
""" """
DirectoryResolver.unzip_all(data_root) DirectoryResolver.unzip_all(data_root)
image_paths = list(DirectoryResolver.recurse_data_root(data_root)) return Dataset.from_path(data_root).image_train_items(self.aspects)
items = []
multipliers = {}
for pathname in tqdm.tqdm(image_paths):
current_dir = os.path.dirname(pathname)
if current_dir not in multipliers:
multiply_txt_path = os.path.join(current_dir, "multiply.txt")
if os.path.exists(multiply_txt_path):
try:
with open(multiply_txt_path, 'r') as f:
val = float(f.read().strip())
multipliers[current_dir] = val
logging.info(f" - multiply.txt in '{current_dir}' set to {val}")
except Exception as e:
logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
multipliers[current_dir] = 1.0
else:
multipliers[current_dir] = 1.0
caption = ImageCaption.resolve(pathname)
item = self.image_train_item(pathname, caption, multiplier=multipliers[current_dir])
items.append(item)
return items
@staticmethod @staticmethod
def unzip_all(path): def unzip_all(path):
@ -150,21 +62,6 @@ class DirectoryResolver(DataResolver):
except Exception as e: except Exception as e:
logging.error(f"Error unzipping files {e}") logging.error(f"Error unzipping files {e}")
@staticmethod
def recurse_data_root(recurse_root):
for f in os.listdir(recurse_root):
current = os.path.join(recurse_root, f)
if os.path.isfile(current):
ext = os.path.splitext(f)[1].lower()
if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']:
yield current
for d in os.listdir(recurse_root):
current = os.path.join(recurse_root, d)
if os.path.isdir(current):
yield from DirectoryResolver.recurse_data_root(current)
def strategy(data_root: str) -> typing.Type[DataResolver]: def strategy(data_root: str) -> typing.Type[DataResolver]:
""" """
Determine the strategy to use for resolving the data. Determine the strategy to use for resolving the data.

View File

@ -10,6 +10,7 @@ ninja
omegaconf==2.2.3 omegaconf==2.2.3
piexif==1.1.3 piexif==1.1.3
protobuf==3.20.3 protobuf==3.20.3
pyfakefs
pynvml==11.5.0 pynvml==11.5.0
pyre-extensions==0.0.30 pyre-extensions==0.0.30
pytorch-lightning==1.9.2 pytorch-lightning==1.9.2

View File

@ -58,13 +58,13 @@ class TestResolve(unittest.TestCase):
def test_directory_resolve_with_str(self): def test_directory_resolve_with_str(self):
items = resolver.resolve(DATA_PATH, ARGS) items = resolver.resolve(DATA_PATH, ARGS)
image_paths = [item.pathname for item in items] image_paths = set(item.pathname for item in items)
image_captions = [item.caption for item in items] image_captions = [item.caption for item in items]
captions = [caption.get_caption() for caption in image_captions] captions = set(caption.get_caption() for caption in image_captions)
self.assertEqual(len(items), 3) self.assertEqual(len(items), 3)
self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH]) self.assertEqual(image_paths, {IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH})
self.assertEqual(captions, ['caption for test1', 'test2', 'test3']) self.assertEqual(captions, {'caption for test1', 'test2', 'test3'})
undersized_images = list(filter(lambda i: i.is_undersized, items)) undersized_images = list(filter(lambda i: i.is_undersized, items))
self.assertEqual(len(undersized_images), 1) self.assertEqual(len(undersized_images), 1)

329
test/test_dataset.py Normal file
View File

@ -0,0 +1,329 @@
import os
from data.dataset import Dataset, ImageConfig, Caption, Tag
from textwrap import dedent
from pyfakefs.fake_filesystem_unittest import TestCase
class TestResolve(TestCase):
    """Dataset resolution tests running against a pyfakefs fake filesystem."""

    def setUp(self):
        self.setUpPyfakefs()

    def test_simple_image(self):
        """An image without sidecar files takes its caption from the file name."""
        self.fs.create_file("image, tag1, tag2.jpg")

        actual = Dataset.from_path(".").image_configs

        expected = {
            ImageConfig(
                image="./image, tag1, tag2.jpg",
                captions=frozenset([
                    Caption(main_prompt="image", tags=frozenset([Tag("tag1"), Tag("tag2")]))
                ]))
        }
        self.assertEqual(expected, actual)

    def test_image_types(self):
        """Every supported image extension is picked up, regardless of case."""
        self.fs.create_file("image_1.JPG")
        self.fs.create_file("image_2.jpeg")
        self.fs.create_file("image_3.png")
        self.fs.create_file("image_4.webp")
        self.fs.create_file("image_5.jfif")
        self.fs.create_file("image_6.bmp")

        actual = Dataset.from_path(".").image_configs

        captions = frozenset([Caption(main_prompt="image")])
        expected = {
            ImageConfig(image="./image_1.JPG", captions=captions),
            ImageConfig(image="./image_2.jpeg", captions=captions),
            ImageConfig(image="./image_3.png", captions=captions),
            ImageConfig(image="./image_4.webp", captions=captions),
            ImageConfig(image="./image_5.jfif", captions=captions),
            ImageConfig(image="./image_6.bmp", captions=captions),
        }
        self.assertEqual(expected, actual)

    def test_caption_file(self):
        """.txt and .caption sidecar files provide the caption."""
        self.fs.create_file("image_1.jpg")
        self.fs.create_file("image_1.txt", contents="an image, test, from .txt")
        self.fs.create_file("image_2.jpg")
        self.fs.create_file("image_2.caption", contents="an image, test, from .caption")

        actual = Dataset.from_path(".").image_configs

        expected = {
            ImageConfig(
                image="./image_1.jpg",
                captions=frozenset([
                    Caption(main_prompt="an image", tags=frozenset([Tag("test"), Tag("from .txt")]))
                ])),
            ImageConfig(
                image="./image_2.jpg",
                captions=frozenset([
                    Caption(main_prompt="an image", tags=frozenset([Tag("test"), Tag("from .caption")]))
                ]))
        }
        self.assertEqual(expected, actual)

    def test_image_yaml(self):
        """.yaml / .yml sidecar files provide settings and simple or structured captions."""
        self.fs.create_file("image_1.jpg")
        self.fs.create_file("image_1.yaml",
                            contents=dedent("""
                                multiply: 2
                                cond_dropout: 0.05
                                flip_p: 0.5
                                caption: "A simple caption, from .yaml"
                                """))
        self.fs.create_file("image_2.jpg")
        # The nesting matters here: main_prompt, rating, max_caption_length and
        # tags belong to ``caption``, and ``weight`` belongs to the second tag.
        self.fs.create_file("image_2.yml",
                            contents=dedent("""
                                flip_p: 0.0
                                caption:
                                    main_prompt: A complex caption
                                    rating: 1.1
                                    max_caption_length: 1024
                                    tags:
                                      - tag: from .yml
                                      - tag: with weight
                                        weight: 0.5
                                """))

        actual = Dataset.from_path(".").image_configs

        expected = {
            ImageConfig(
                image="./image_1.jpg",
                multiply=2,
                cond_dropout=0.05,
                flip_p=0.5,
                captions=frozenset([
                    Caption(main_prompt="A simple caption", tags=frozenset([Tag("from .yaml")]))
                ])),
            ImageConfig(
                image="./image_2.jpg",
                flip_p=0.0,
                captions=frozenset([
                    Caption(main_prompt="A complex caption", rating=1.1,
                            max_caption_length=1024,
                            tags=frozenset([
                                Tag("from .yml"),
                                Tag("with weight", weight=0.5)
                            ]))
                ]))
        }
        self.assertEqual(expected, actual)

    def test_multi_caption(self):
        """Captions from all sidecar files of one image are collected together."""
        self.fs.create_file("image_1.jpg")
        self.fs.create_file("image_1.yaml", contents=dedent("""
            caption: "A simple caption, from .yaml"
            captions:
              - "Another simple caption"
              - main_prompt: A complex caption
            """))
        self.fs.create_file("image_1.txt", contents="A .txt caption")
        self.fs.create_file("image_1.caption", contents="A .caption caption")

        actual = Dataset.from_path(".").image_configs

        expected = {
            ImageConfig(
                image="./image_1.jpg",
                captions=frozenset([
                    Caption(main_prompt="A simple caption", tags=frozenset([Tag("from .yaml")])),
                    Caption(main_prompt="Another simple caption", tags=frozenset()),
                    Caption(main_prompt="A complex caption", tags=frozenset()),
                    Caption(main_prompt="A .txt caption", tags=frozenset()),
                    Caption(main_prompt="A .caption caption", tags=frozenset())
                ])
            ),
        }
        self.assertEqual(expected, actual)

    def test_globals_and_locals(self):
        """Global config applies recursively; local config only to its own folder."""
        self.fs.create_file("./people/global.yaml", contents=dedent("""\
            multiply: 1.0
            cond_dropout: 0.0
            flip_p: 0.0
            """))
        self.fs.create_file("./people/alice/local.yaml", contents="multiply: 1.5")
        self.fs.create_file("./people/alice/alice_1.png")
        self.fs.create_file("./people/alice/alice_1.yaml", contents="multiply: 2")
        self.fs.create_file("./people/alice/alice_2.png")

        self.fs.create_file("./people/bob/multiply.txt", contents="3")
        self.fs.create_file("./people/bob/cond_dropout.txt", contents="0.05")
        self.fs.create_file("./people/bob/flip_p.txt", contents="0.05")
        self.fs.create_file("./people/bob/bob.png")

        self.fs.create_file("./people/cleo/cleo.png")
        self.fs.create_file("./people/dan.png")

        self.fs.create_file("./other/dog/local.yaml", contents="caption: spike")
        self.fs.create_file("./other/dog/xyz.png")

        actual = Dataset.from_path(".").image_configs

        expected = {
            ImageConfig(
                image="./people/alice/alice_1.png",
                captions=frozenset([Caption(main_prompt="alice")]),
                multiply=2,
                cond_dropout=0.0,
                flip_p=0.0
            ),
            ImageConfig(
                image="./people/alice/alice_2.png",
                captions=frozenset([Caption(main_prompt="alice")]),
                multiply=1.5,
                cond_dropout=0.0,
                flip_p=0.0
            ),
            ImageConfig(
                image="./people/bob/bob.png",
                captions=frozenset([Caption(main_prompt="bob")]),
                multiply=3,
                cond_dropout=0.05,
                flip_p=0.05
            ),
            ImageConfig(
                image="./people/cleo/cleo.png",
                captions=frozenset([Caption(main_prompt="cleo")]),
                multiply=1.0,
                cond_dropout=0.0,
                flip_p=0.0
            ),
            ImageConfig(
                image="./people/dan.png",
                captions=frozenset([Caption(main_prompt="dan")]),
                multiply=1.0,
                cond_dropout=0.0,
                flip_p=0.0
            ),
            ImageConfig(
                image="./other/dog/xyz.png",
                captions=frozenset([Caption(main_prompt="spike")]),
                multiply=None,
                cond_dropout=None,
                flip_p=None
            )
        }
        self.assertEqual(expected, actual)

    def test_json_manifest(self):
        """A JSON manifest supports caption files, plain text and structured captions."""
        self.fs.create_file("./stuff/image_1.jpg")
        self.fs.create_file("./stuff/default.caption", contents= "default caption")
        self.fs.create_file("./other/image_1.jpg")
        self.fs.create_file("./other/image_2.jpg")
        self.fs.create_file("./other/image_3.jpg")
        self.fs.create_file("./manifest.json", contents=dedent("""
            [
                { "image": "./stuff/image_1.jpg", "caption": "./stuff/default.caption" },
                { "image": "./other/image_1.jpg", "caption": "other caption" },
                {
                    "image": "./other/image_2.jpg",
                    "caption": {
                        "main_prompt": "complex caption",
                        "rating": 0.1,
                        "max_caption_length": 1000,
                        "tags": [
                            {"tag": "including"},
                            {"tag": "weighted tag", "weight": 999.9}
                        ]
                    }
                },
                {
                    "image": "./other/image_3.jpg",
                    "multiply": 2,
                    "flip_p": 0.5,
                    "cond_dropout": 0.01,
                    "captions": [
                        "first caption",
                        { "main_prompt": "second caption" }
                    ]
                }
            ]
            """))

        actual = Dataset.from_json("./manifest.json").image_configs

        expected = {
            ImageConfig(
                image="./stuff/image_1.jpg",
                captions=frozenset([Caption(main_prompt="default caption")])
            ),
            ImageConfig(
                image="./other/image_1.jpg",
                captions=frozenset([Caption(main_prompt="other caption")])
            ),
            ImageConfig(
                image="./other/image_2.jpg",
                captions=frozenset([
                    Caption(
                        main_prompt="complex caption",
                        rating=0.1,
                        max_caption_length=1000,
                        tags=frozenset([
                            Tag("including"),
                            Tag("weighted tag", 999.9)
                        ]))
                ])
            ),
            ImageConfig(
                image="./other/image_3.jpg",
                multiply=2,
                flip_p=0.5,
                cond_dropout=0.01,
                captions=frozenset([
                    Caption("first caption"),
                    Caption("second caption")
                ])
            )
        }
        self.assertEqual(expected, actual)

    def test_train_items(self):
        """Image configs are converted to train items, sorted by path name."""
        dataset = Dataset([
            ImageConfig(
                image="1.jpg",
                multiply=2,
                flip_p=0.1,
                cond_dropout=0.01,
                captions=frozenset([
                    Caption(
                        main_prompt="first caption",
                        rating = 1.1,
                        max_caption_length=1024,
                        tags=frozenset([
                            Tag("tag"),
                            Tag("tag_2", 2.0)
                        ])),
                    Caption(main_prompt="second_caption")
                ])),
            ImageConfig(
                image="2.jpg",
                captions=frozenset([Caption(main_prompt="single caption")])
            )
        ])

        aspects = []
        actual = dataset.image_train_items(aspects)

        self.assertEqual(len(actual), 2)

        self.assertEqual(actual[0].pathname, os.path.abspath('1.jpg'))
        self.assertEqual(actual[0].multiplier, 2.0)
        self.assertEqual(actual[0].flip.p, 0.1)
        self.assertEqual(actual[0].cond_dropout, 0.01)
        self.assertEqual(actual[0].caption.get_caption(), "first caption, tag, tag_2")
        # Can't test this
        # self.assertTrue(actual[0].caption.__use_weights)

        self.assertEqual(actual[1].pathname, os.path.abspath('2.jpg'))
        self.assertEqual(actual[1].multiplier, 1.0)
        self.assertEqual(actual[1].flip.p, 0.0)
        self.assertIsNone(actual[1].cond_dropout)
        self.assertEqual(actual[1].caption.get_caption(), "single caption")
        # Can't test this
        # self.assertFalse(actual[1].caption.__use_weights)

View File

@ -33,39 +33,3 @@ class TestImageCaption(unittest.TestCase):
caption = ImageCaption("hello world", 1.0, [], [], 2048, False) caption = ImageCaption("hello world", 1.0, [], [], 2048, False)
self.assertEqual(caption.get_caption(), "hello world") self.assertEqual(caption.get_caption(), "hello world")
def test_parse(self):
caption = ImageCaption.parse("hello world, one, two, three")
self.assertEqual(caption.get_caption(), "hello world, one, two, three")
def test_from_file_name(self):
caption = ImageCaption.from_file_name("foo bar_1_2_3.jpg")
self.assertEqual(caption.get_caption(), "foo bar")
def test_from_text_file(self):
caption = ImageCaption.from_text_file("test/data/test1.txt")
self.assertEqual(caption.get_caption(), "caption for test1")
def test_from_file(self):
caption = ImageCaption.from_file("test/data/test1.txt")
self.assertEqual(caption.get_caption(), "caption for test1")
caption = ImageCaption.from_file("test/data/test_caption.caption")
self.assertEqual(caption.get_caption(), "caption for test2")
def test_resolve(self):
caption = ImageCaption.resolve("test/data/test1.txt")
self.assertEqual(caption.get_caption(), "caption for test1")
caption = ImageCaption.resolve("test/data/test_caption.caption")
self.assertEqual(caption.get_caption(), "caption for test2")
caption = ImageCaption.resolve("hello world")
self.assertEqual(caption.get_caption(), "hello world")
caption = ImageCaption.resolve("test/data/test1.jpg")
self.assertEqual(caption.get_caption(), "caption for test1")
caption = ImageCaption.resolve("test/data/test2.jpg")
self.assertEqual(caption.get_caption(), "test2")

46
utils/fs_helpers.py Normal file
View File

@ -0,0 +1,46 @@
def barename(file):
    """Return the file's base name without its directory and final extension."""
    return os.path.splitext(os.path.basename(file))[0]
def ext(file):
    """Return the file's final extension, lower-cased and including the dot."""
    extension = os.path.splitext(os.path.basename(file))[1]
    return extension.lower()
def same_barename(lhs, rhs):
    """Tell whether two paths share the same bare name (see ``barename``)."""
    left = barename(lhs)
    right = barename(rhs)
    return left == right
def is_image(file):
    """Tell whether the file's extension is one of the supported image types."""
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif'}
    return ext(file) in image_extensions
def read_text(file):
    """Read a UTF-8 text file and return its content with surrounding whitespace stripped.

    :param file: path of the file to read.
    :return: the stripped content, or ``None`` (after logging a warning) when
             the file cannot be read.
    """
    try:
        with open(file, encoding='utf-8', mode='r') as stream:
            return stream.read().strip()
    except Exception as e:
        # Best effort by design: callers treat None as "no content".
        logging.warning(f" *** Error reading text file {file}: {e}")
        return None
def read_float(file):
    """Read a file and parse its content as a float.

    :param file: path of the file to read.
    :return: the parsed value, or ``None`` when the file cannot be read or its
             content is not a valid float (a warning is logged in both cases).
    """
    data = read_text(file)
    if data is None:
        # read_text has already logged why the file could not be read.
        return None
    try:
        return float(data)
    except ValueError as e:
        logging.warning(f" *** Could not parse '{data}' to float in file {file}: {e}")
        return None
import logging
import os
def walk_and_visit(path, visit_fn, context=None):
    """Recursively walk ``path``, calling ``visit_fn`` once per directory.

    :param path: directory to start from.
    :param visit_fn: called as ``visit_fn(files, context)`` with the full paths
                     of the non-directory entries of the current directory; its
                     return value becomes the context for the subdirectories.
    :param context: context passed to ``visit_fn`` for ``path`` itself.

    Directories whose name starts with ``.`` are neither descended into nor
    reported to ``visit_fn``.
    """
    dirs = []
    files = []
    with os.scandir(path) as entries:
        for entry in entries:
            fullname = os.path.join(path, entry.name)
            if os.path.isdir(fullname):
                # Hidden directories are skipped entirely instead of being
                # handed to the visitor as if they were files.
                if not entry.name.startswith('.'):
                    dirs.append(fullname)
            else:
                files.append(fullname)

    subcontext = visit_fn(files, context)

    for subdir in dirs:
        walk_and_visit(subdir, visit_fn, subcontext)