Add support for enhanced dataset configuration

Add support for:
* flip_p.txt
* cond_dropout.txt
* local.yaml consolidated config (including default captions)
* global.yaml consolidated config which applies recursively to subfolders
* flip_p, and cond_dropout config per image
* manifest.json with full image-level configuration
This commit is contained in:
Augusto de la Torre 2023-03-08 15:02:14 +01:00
parent 2e3d044ba3
commit 0716c40ab6
9 changed files with 648 additions and 279 deletions

261
data/dataset.py Normal file
View File

@ -0,0 +1,261 @@
import os
import logging
import yaml
import json
from functools import total_ordering
from attrs import define, Factory
from data.image_train_item import ImageCaption, ImageTrainItem
from utils.fs_helpers import *
@define(frozen=True)
@total_ordering
class Tag:
value: str
weight: float = None
def __lt__(self, other):
return self.value < other.value
@define(frozen=True)
@total_ordering
class Caption:
main_prompt: str = None
rating: float = None
max_caption_length: int = None
tags: frozenset[Tag] = Factory(frozenset)
@classmethod
def from_dict(cls, data: dict):
main_prompt = data.get("main_prompt")
rating = data.get("rating")
max_caption_length = data.get("max_caption_length")
tags = frozenset([ Tag(value=t.get("tag"), weight=t.get("weight"))
for t in data.get("tags", [])
if "tag" in t and len(t.get("tag")) > 0])
if not main_prompt and not tags:
return None
return Caption(main_prompt=main_prompt, rating=rating, max_caption_length=max_caption_length, tags=tags)
@classmethod
def from_text(cls, text: str):
if text is None:
return Caption(main_prompt="")
split_caption = list(map(str.strip, text.split(",")))
main_prompt = split_caption[0]
tags = frozenset(Tag(value=t) for t in split_caption[1:])
return Caption(main_prompt=main_prompt, tags=tags)
@classmethod
def load(cls, input):
if isinstance(input, str):
if os.path.isfile(input):
return Caption.from_text(read_text(input))
else:
return Caption.from_text(input)
elif isinstance(input, dict):
return Caption.from_dict(input)
def __lt__(self, other):
self_str = ",".join([self.main_prompt] + sorted(repr(t) for t in self.tags))
other_str = ",".join([other.main_prompt] + sorted(repr(t) for t in other.tags))
return self_str < other_str
@define(frozen=True)
class ImageConfig:
image: str = None
captions: frozenset[Caption] = Factory(frozenset)
multiply: float = None
cond_dropout: float = None
flip_p: float = None
@classmethod
def fold(cls, configs):
acc = ImageConfig()
[acc := acc.merge(cfg) for cfg in configs]
return acc
def merge(self, other):
if other is None:
return self
if other.image and self.image:
logging(f"Found two images with different extensions and the same barename: {self.image} and {other.image}")
return ImageConfig(
image = other.image or self.image,
captions = other.captions.union(self.captions),
multiply = other.multiply if other.multiply is not None else self.multiply,
cond_dropout = other.cond_dropout if other.cond_dropout is not None else self.cond_dropout,
flip_p = other.flip_p if other.flip_p is not None else self.flip_p
)
def ensure_caption(self):
if not self.captions:
filename_caption = Caption.from_text(barename(self.image).split("_")[0])
return self.merge(ImageConfig(captions=frozenset([filename_caption])))
return self
@classmethod
def from_dict(cls, data: dict):
captions = set()
if "captions" in data:
captions.update(Caption.load(cap) for cap in data.get("captions"))
if "caption" in data:
captions.add(Caption.load(data.get("caption")))
if not captions:
# For backward compatibility with existing caption yaml
caption = Caption.load(data)
if caption:
captions.add(caption)
return ImageConfig(
image = data.get("image"),
captions=frozenset(captions),
multiply=data.get("multiply"),
cond_dropout=data.get("cond_dropout"),
flip_p=data.get("flip_p"))
@classmethod
def from_text(cls, text: str):
try:
if os.path.isfile(text):
return ImageConfig.from_file(text)
return ImageConfig(captions=frozenset({Caption.from_text(text)}))
except Exception as e:
logging.warning(f" *** Error parsing config from text {text}: \n{e}")
@classmethod
def from_file(cls, file: str):
try:
match ext(file):
case '.jpg' | '.jpeg' | '.png' | '.bmp' | '.webp' | '.jfif':
return ImageConfig(image=file)
case ".json":
return ImageConfig.from_dict(json.load(read_text(file)))
case ".yaml" | ".yml":
return ImageConfig.from_dict(yaml.safe_load(read_text(file)))
case ".txt" | ".caption":
return ImageConfig.from_text(read_text(file))
case _:
return logging.warning(" *** Unrecognized config extension {ext}")
except Exception as e:
logging.warning(f" *** Error parsing config from {file}: {e}")
@classmethod
def load(cls, input):
if isinstance(input, str):
return ImageConfig.from_text(input)
elif isinstance(input, dict):
return ImageConfig.from_dict(input)
@define()
class Dataset:
image_configs: set[ImageConfig]
def __global_cfg(files):
cfgs = []
for file in files:
match os.path.basename(file):
case 'global.yaml' | 'global.yml':
cfgs.append(ImageConfig.from_file(file))
return ImageConfig.fold(cfgs)
def __local_cfg(files):
cfgs = []
for file in files:
match os.path.basename(file):
case 'multiply.txt':
cfgs.append(ImageConfig(multiply=read_float(file)))
case 'cond_dropout.txt':
cfgs.append(ImageConfig(cond_dropout=read_float(file)))
case 'flip_p.txt':
cfgs.append(ImageConfig(flip_p=read_float(file)))
case 'local.yaml' | 'local.yml':
cfgs.append(ImageConfig.from_file(file))
return ImageConfig.fold(cfgs)
def __image_cfg(imagepath, files):
cfgs = [ImageConfig.from_file(imagepath)]
for file in files:
if same_barename(imagepath, file):
match ext(file):
case '.txt' | '.caption' | '.yml' | '.yaml':
cfgs.append(ImageConfig.from_file(file))
return ImageConfig.fold(cfgs)
@classmethod
def from_path(cls, data_root):
# Create a visitor that maintains global config stack
# and accumulates image configs as it traverses dataset
image_configs = set()
def process_dir(files, parent_globals):
globals = parent_globals.merge(Dataset.__global_cfg(files))
locals = Dataset.__local_cfg(files)
for img in filter(is_image, files):
img_cfg = Dataset.__image_cfg(img, files)
collapsed_cfg = ImageConfig.fold([globals, locals, img_cfg])
resolved_cfg = collapsed_cfg.ensure_caption()
image_configs.add(resolved_cfg)
return globals
walk_and_visit(data_root, process_dir, ImageConfig())
return Dataset(image_configs)
@classmethod
def from_json(cls, json_path):
"""
Import a dataset definition from a JSON file
"""
configs = set()
with open(json_path, encoding='utf-8', mode='r') as stream:
for data in json.load(stream):
cfg = ImageConfig.load(data).ensure_caption()
if not cfg or not cfg.image:
logging.warning(f" *** Error parsing json image entry in {json_path}: {data}")
continue
configs.add(cfg)
return Dataset(configs)
def image_train_items(self, aspects):
items = []
for config in self.image_configs:
caption = next(iter(sorted(config.captions)))
if len(config.captions) > 1:
logging.warning(f" *** Found multiple captions for image {config.image}, but only one will be applied: {config.captions}")
use_weights = len(set(t.weight or 1.0 for t in caption.tags)) > 1
tags = []
tag_weights = []
for tag in sorted(caption.tags):
tags.append(tag.value)
tag_weights.append(tag.weight or 1.0)
use_weights = len(set(tag_weights)) > 1
caption = ImageCaption(
main_prompt=caption.main_prompt,
rating=caption.rating,
tags=tags,
tag_weights=tag_weights,
max_target_length=caption.max_caption_length,
use_weights=use_weights)
item = ImageTrainItem(
image=None,
caption=caption,
aspects=aspects,
pathname=os.path.abspath(config.image),
flip_p=config.flip_p or 0.0,
multiplier=config.multiply or 1.0,
cond_dropout=config.cond_dropout
)
items.append(item)
return list(sorted(items, key=lambda ti: ti.pathname))

View File

@ -104,7 +104,7 @@ class EveryDreamBatch(Dataset):
example["image"] = image_transforms(train_item["image"])
if random.random() > self.conditional_dropout:
if random.random() > (train_item.cond_dropout or self.conditional_dropout):
example["tokens"] = self.tokenizer(example["caption"],
truncation=True,
padding="max_length",

View File

@ -113,136 +113,6 @@ class ImageCaption:
random.Random(seed).shuffle(tags)
return ", ".join(tags)
@staticmethod
def parse(string: str) -> 'ImageCaption':
"""
Parses a string to get the caption.
:param string: String to parse.
:return: `ImageCaption` object.
"""
split_caption = list(map(str.strip, string.split(",")))
main_prompt = split_caption[0]
tags = split_caption[1:]
tag_weights = [1.0] * len(tags)
return ImageCaption(main_prompt, 1.0, tags, tag_weights, DEFAULT_MAX_CAPTION_LENGTH, False)
@staticmethod
def from_file_name(file_path: str) -> 'ImageCaption':
"""
Parses the file name to get the caption.
:param file_path: Path to the image file.
:return: `ImageCaption` object.
"""
(file_name, _) = os.path.splitext(os.path.basename(file_path))
caption = file_name.split("_")[0]
return ImageCaption.parse(caption)
@staticmethod
def from_text_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
"""
Parses a text file to get the caption. Returns the default caption if
the file does not exist or is invalid.
:param file_path: Path to the text file.
:param default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid.
:return: `ImageCaption` object or `None`.
"""
try:
with open(file_path, encoding='utf-8', mode='r') as caption_file:
caption_text = caption_file.read()
return ImageCaption.parse(caption_text)
except:
logging.error(f" *** Error reading {file_path} to get caption")
return default_caption
@staticmethod
def from_yaml_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
"""
Parses a yaml file to get the caption. Returns the default caption if
the file does not exist or is invalid.
:param file_path: path to the yaml file
:param default_caption: caption to return if the file does not exist or is invalid
:return: `ImageCaption` object or `None`.
"""
try:
with open(file_path, "r") as stream:
file_content = yaml.safe_load(stream)
main_prompt = file_content.get("main_prompt", "")
rating = file_content.get("rating", 1.0)
unparsed_tags = file_content.get("tags", [])
max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH)
tags = []
tag_weights = []
last_weight = None
weights_differ = False
for unparsed_tag in unparsed_tags:
tag = unparsed_tag.get("tag", "").strip()
if len(tag) == 0:
continue
tags.append(tag)
tag_weight = unparsed_tag.get("weight", 1.0)
tag_weights.append(tag_weight)
if last_weight is not None and weights_differ is False:
weights_differ = last_weight != tag_weight
last_weight = tag_weight
return ImageCaption(main_prompt, rating, tags, tag_weights, max_caption_length, weights_differ)
except:
logging.error(f" *** Error reading {file_path} to get caption")
return default_caption
@staticmethod
def from_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
"""
Try to resolve a caption from a file path or return `default_caption`.
:string: The path to the file to parse.
:default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid.
:return: `ImageCaption` object or `None`.
"""
if os.path.exists(file_path):
(file_path_without_ext, ext) = os.path.splitext(file_path)
match ext:
case ".yaml" | ".yml":
return ImageCaption.from_yaml_file(file_path, default_caption)
case ".txt" | ".caption":
return ImageCaption.from_text_file(file_path, default_caption)
case '.jpg'| '.jpeg'| '.png'| '.bmp'| '.webp'| '.jfif':
for ext in [".yaml", ".yml", ".txt", ".caption"]:
file_path = file_path_without_ext + ext
image_caption = ImageCaption.from_file(file_path)
if image_caption is not None:
return image_caption
return ImageCaption.from_file_name(file_path)
case _:
return default_caption
else:
return default_caption
@staticmethod
def resolve(string: str) -> 'ImageCaption':
"""
Try to resolve a caption from a string. If the string is a file path,
the caption will be read from the file, otherwise the string will be
parsed as a caption.
:string: The string to resolve.
:return: `ImageCaption` object.
"""
return ImageCaption.from_file(string, None) or ImageCaption.parse(string)
class ImageTrainItem:
"""
@ -253,7 +123,7 @@ class ImageTrainItem:
flip_p: probability of flipping image (0.0 to 1.0)
rating: the relative rating of the images. The rating is measured in comparison to the other images.
"""
def __init__(self, image: PIL.Image, caption: ImageCaption, aspects: list[float], pathname: str, flip_p=0.0, multiplier: float=1.0):
def __init__(self, image: PIL.Image, caption: ImageCaption, aspects: list[float], pathname: str, flip_p=0.0, multiplier: float=1.0, cond_dropout=None):
self.caption = caption
self.aspects = aspects
self.pathname = pathname
@ -261,6 +131,7 @@ class ImageTrainItem:
self.cropped_img = None
self.runt_size = 0
self.multiplier = multiplier
self.cond_dropout = cond_dropout
self.image_size = None
if image is None or len(image) == 0:

View File

@ -4,6 +4,7 @@ import os
import typing
import zipfile
import argparse
from data.dataset import Dataset
import tqdm
from colorama import Fore, Style
@ -27,16 +28,6 @@ class DataResolver:
"""
raise NotImplementedError()
def image_train_item(self, image_path: str, caption: ImageCaption, multiplier: float=1) -> ImageTrainItem:
return ImageTrainItem(
image=None,
caption=caption,
aspects=self.aspects,
pathname=image_path,
flip_p=self.flip_p,
multiplier=multiplier
)
class JSONResolver(DataResolver):
def image_train_items(self, json_path: str) -> list[ImageTrainItem]:
"""
@ -45,62 +36,8 @@ class JSONResolver(DataResolver):
:param json_path: The path to the JSON file.
"""
items = []
with open(json_path, encoding='utf-8', mode='r') as f:
json_data = json.load(f)
for data in tqdm.tqdm(json_data):
caption = JSONResolver.image_caption(data)
if caption:
image_value = JSONResolver.get_image_value(data)
item = self.image_train_item(image_value, caption)
if item:
items.append(item)
return items
return Dataset.from_json(json_path).image_train_items(self.aspects)
@staticmethod
def get_image_value(json_data: dict) -> typing.Optional[str]:
"""
Get the image from the json data if possible.
:param json_data: The json data, a dict.
:return: The image, or None if not found.
"""
image_value = json_data.get("image", None)
if isinstance(image_value, str):
image_value = image_value.strip()
if os.path.exists(image_value):
return image_value
@staticmethod
def get_caption_value(json_data: dict) -> typing.Optional[str]:
"""
Get the caption from the json data if possible.
:param json_data: The json data, a dict.
:return: The caption, or None if not found.
"""
caption_value = json_data.get("caption", None)
if isinstance(caption_value, str):
return caption_value.strip()
@staticmethod
def image_caption(json_data: dict) -> typing.Optional[ImageCaption]:
"""
Get the caption from the json data if possible.
:param json_data: The json data, a dict.
:return: The `ImageCaption`, or None if not found.
"""
image_value = JSONResolver.get_image_value(json_data)
caption_value = JSONResolver.get_caption_value(json_data)
if image_value:
if caption_value:
return ImageCaption.resolve(caption_value)
return ImageCaption.from_file(image_value)
class DirectoryResolver(DataResolver):
def image_train_items(self, data_root: str) -> list[ImageTrainItem]:
"""
@ -111,32 +48,7 @@ class DirectoryResolver(DataResolver):
:param data_root: The root directory to recurse through
"""
DirectoryResolver.unzip_all(data_root)
image_paths = list(DirectoryResolver.recurse_data_root(data_root))
items = []
multipliers = {}
for pathname in tqdm.tqdm(image_paths):
current_dir = os.path.dirname(pathname)
if current_dir not in multipliers:
multiply_txt_path = os.path.join(current_dir, "multiply.txt")
if os.path.exists(multiply_txt_path):
try:
with open(multiply_txt_path, 'r') as f:
val = float(f.read().strip())
multipliers[current_dir] = val
logging.info(f" - multiply.txt in '{current_dir}' set to {val}")
except Exception as e:
logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
multipliers[current_dir] = 1.0
else:
multipliers[current_dir] = 1.0
caption = ImageCaption.resolve(pathname)
item = self.image_train_item(pathname, caption, multiplier=multipliers[current_dir])
items.append(item)
return items
return Dataset.from_path(data_root).image_train_items(self.aspects)
@staticmethod
def unzip_all(path):
@ -150,21 +62,6 @@ class DirectoryResolver(DataResolver):
except Exception as e:
logging.error(f"Error unzipping files {e}")
@staticmethod
def recurse_data_root(recurse_root):
for f in os.listdir(recurse_root):
current = os.path.join(recurse_root, f)
if os.path.isfile(current):
ext = os.path.splitext(f)[1].lower()
if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']:
yield current
for d in os.listdir(recurse_root):
current = os.path.join(recurse_root, d)
if os.path.isdir(current):
yield from DirectoryResolver.recurse_data_root(current)
def strategy(data_root: str) -> typing.Type[DataResolver]:
"""
Determine the strategy to use for resolving the data.

View File

@ -10,6 +10,7 @@ ninja
omegaconf==2.2.3
piexif==1.1.3
protobuf==3.20.3
pyfakefs
pynvml==11.5.0
pyre-extensions==0.0.30
pytorch-lightning==1.9.2

View File

@ -58,13 +58,13 @@ class TestResolve(unittest.TestCase):
def test_directory_resolve_with_str(self):
items = resolver.resolve(DATA_PATH, ARGS)
image_paths = [item.pathname for item in items]
image_paths = set(item.pathname for item in items)
image_captions = [item.caption for item in items]
captions = [caption.get_caption() for caption in image_captions]
captions = set(caption.get_caption() for caption in image_captions)
self.assertEqual(len(items), 3)
self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
self.assertEqual(captions, ['caption for test1', 'test2', 'test3'])
self.assertEqual(image_paths, {IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH})
self.assertEqual(captions, {'caption for test1', 'test2', 'test3'})
undersized_images = list(filter(lambda i: i.is_undersized, items))
self.assertEqual(len(undersized_images), 1)

329
test/test_dataset.py Normal file
View File

@ -0,0 +1,329 @@
import os
from data.dataset import Dataset, ImageConfig, Caption, Tag
from textwrap import dedent
from pyfakefs.fake_filesystem_unittest import TestCase
class TestResolve(TestCase):
def setUp(self):
self.setUpPyfakefs()
def test_simple_image(self):
self.fs.create_file("image, tag1, tag2.jpg")
actual = Dataset.from_path(".").image_configs
expected = {
ImageConfig(
image="./image, tag1, tag2.jpg",
captions=frozenset([
Caption(main_prompt="image", tags=frozenset([Tag("tag1"), Tag("tag2")]))
]))
}
self.assertEqual(expected, actual)
def test_image_types(self):
self.fs.create_file("image_1.JPG")
self.fs.create_file("image_2.jpeg")
self.fs.create_file("image_3.png")
self.fs.create_file("image_4.webp")
self.fs.create_file("image_5.jfif")
self.fs.create_file("image_6.bmp")
actual = Dataset.from_path(".").image_configs
captions = frozenset([Caption(main_prompt="image")])
expected = {
ImageConfig(image="./image_1.JPG", captions=captions),
ImageConfig(image="./image_2.jpeg", captions=captions),
ImageConfig(image="./image_3.png", captions=captions),
ImageConfig(image="./image_4.webp", captions=captions),
ImageConfig(image="./image_5.jfif", captions=captions),
ImageConfig(image="./image_6.bmp", captions=captions),
}
self.assertEqual(expected, actual)
def test_caption_file(self):
self.fs.create_file("image_1.jpg")
self.fs.create_file("image_1.txt", contents="an image, test, from .txt")
self.fs.create_file("image_2.jpg")
self.fs.create_file("image_2.caption", contents="an image, test, from .caption")
actual = Dataset.from_path(".").image_configs
expected = {
ImageConfig(
image="./image_1.jpg",
captions=frozenset([
Caption(main_prompt="an image", tags=frozenset([Tag("test"), Tag("from .txt")]))
])),
ImageConfig(
image="./image_2.jpg",
captions=frozenset([
Caption(main_prompt="an image", tags=frozenset([Tag("test"), Tag("from .caption")]))
]))
}
self.assertEqual(expected, actual)
def test_image_yaml(self):
self.fs.create_file("image_1.jpg")
self.fs.create_file("image_1.yaml",
contents=dedent("""
multiply: 2
cond_dropout: 0.05
flip_p: 0.5
caption: "A simple caption, from .yaml"
"""))
self.fs.create_file("image_2.jpg")
self.fs.create_file("image_2.yml",
contents=dedent("""
flip_p: 0.0
caption:
main_prompt: A complex caption
rating: 1.1
max_caption_length: 1024
tags:
- tag: from .yml
- tag: with weight
weight: 0.5
"""))
actual = Dataset.from_path(".").image_configs
expected = {
ImageConfig(
image="./image_1.jpg",
multiply=2,
cond_dropout=0.05,
flip_p=0.5,
captions=frozenset([
Caption(main_prompt="A simple caption", tags=frozenset([Tag("from .yaml")]))
])),
ImageConfig(
image="./image_2.jpg",
flip_p=0.0,
captions=frozenset([
Caption(main_prompt="A complex caption", rating=1.1,
max_caption_length=1024,
tags=frozenset([
Tag("from .yml"),
Tag("with weight", weight=0.5)
]))
]))
}
self.assertEqual(expected, actual)
def test_multi_caption(self):
self.fs.create_file("image_1.jpg")
self.fs.create_file("image_1.yaml", contents=dedent("""
caption: "A simple caption, from .yaml"
captions:
- "Another simple caption"
- main_prompt: A complex caption
"""))
self.fs.create_file("image_1.txt", contents="A .txt caption")
self.fs.create_file("image_1.caption", contents="A .caption caption")
actual = Dataset.from_path(".").image_configs
expected = {
ImageConfig(
image="./image_1.jpg",
captions=frozenset([
Caption(main_prompt="A simple caption", tags=frozenset([Tag("from .yaml")])),
Caption(main_prompt="Another simple caption", tags=frozenset()),
Caption(main_prompt="A complex caption", tags=frozenset()),
Caption(main_prompt="A .txt caption", tags=frozenset()),
Caption(main_prompt="A .caption caption", tags=frozenset())
])
),
}
self.assertEqual(expected, actual)
def test_globals_and_locals(self):
self.fs.create_file("./people/global.yaml", contents=dedent("""\
multiply: 1.0
cond_dropout: 0.0
flip_p: 0.0
"""))
self.fs.create_file("./people/alice/local.yaml", contents="multiply: 1.5")
self.fs.create_file("./people/alice/alice_1.png")
self.fs.create_file("./people/alice/alice_1.yaml", contents="multiply: 2")
self.fs.create_file("./people/alice/alice_2.png")
self.fs.create_file("./people/bob/multiply.txt", contents="3")
self.fs.create_file("./people/bob/cond_dropout.txt", contents="0.05")
self.fs.create_file("./people/bob/flip_p.txt", contents="0.05")
self.fs.create_file("./people/bob/bob.png")
self.fs.create_file("./people/cleo/cleo.png")
self.fs.create_file("./people/dan.png")
self.fs.create_file("./other/dog/local.yaml", contents="caption: spike")
self.fs.create_file("./other/dog/xyz.png")
actual = Dataset.from_path(".").image_configs
expected = {
ImageConfig(
image="./people/alice/alice_1.png",
captions=frozenset([Caption(main_prompt="alice")]),
multiply=2,
cond_dropout=0.0,
flip_p=0.0
),
ImageConfig(
image="./people/alice/alice_2.png",
captions=frozenset([Caption(main_prompt="alice")]),
multiply=1.5,
cond_dropout=0.0,
flip_p=0.0
),
ImageConfig(
image="./people/bob/bob.png",
captions=frozenset([Caption(main_prompt="bob")]),
multiply=3,
cond_dropout=0.05,
flip_p=0.05
),
ImageConfig(
image="./people/cleo/cleo.png",
captions=frozenset([Caption(main_prompt="cleo")]),
multiply=1.0,
cond_dropout=0.0,
flip_p=0.0
),
ImageConfig(
image="./people/dan.png",
captions=frozenset([Caption(main_prompt="dan")]),
multiply=1.0,
cond_dropout=0.0,
flip_p=0.0
),
ImageConfig(
image="./other/dog/xyz.png",
captions=frozenset([Caption(main_prompt="spike")]),
multiply=None,
cond_dropout=None,
flip_p=None
)
}
self.assertEqual(expected, actual)
def test_json_manifest(self):
self.fs.create_file("./stuff/image_1.jpg")
self.fs.create_file("./stuff/default.caption", contents= "default caption")
self.fs.create_file("./other/image_1.jpg")
self.fs.create_file("./other/image_2.jpg")
self.fs.create_file("./other/image_3.jpg")
self.fs.create_file("./manifest.json", contents=dedent("""
[
{ "image": "./stuff/image_1.jpg", "caption": "./stuff/default.caption" },
{ "image": "./other/image_1.jpg", "caption": "other caption" },
{
"image": "./other/image_2.jpg",
"caption": {
"main_prompt": "complex caption",
"rating": 0.1,
"max_caption_length": 1000,
"tags": [
{"tag": "including"},
{"tag": "weighted tag", "weight": 999.9}
]
}
},
{
"image": "./other/image_3.jpg",
"multiply": 2,
"flip_p": 0.5,
"cond_dropout": 0.01,
"captions": [
"first caption",
{ "main_prompt": "second caption" }
]
}
]
"""))
actual = Dataset.from_json("./manifest.json").image_configs
expected = {
ImageConfig(
image="./stuff/image_1.jpg",
captions=frozenset([Caption(main_prompt="default caption")])
),
ImageConfig(
image="./other/image_1.jpg",
captions=frozenset([Caption(main_prompt="other caption")])
),
ImageConfig(
image="./other/image_2.jpg",
captions=frozenset([
Caption(
main_prompt="complex caption",
rating=0.1,
max_caption_length=1000,
tags=frozenset([
Tag("including"),
Tag("weighted tag", 999.9)
]))
])
),
ImageConfig(
image="./other/image_3.jpg",
multiply=2,
flip_p=0.5,
cond_dropout=0.01,
captions=frozenset([
Caption("first caption"),
Caption("second caption")
])
)
}
self.assertEqual(expected, actual)
def test_train_items(self):
dataset = Dataset([
ImageConfig(
image="1.jpg",
multiply=2,
flip_p=0.1,
cond_dropout=0.01,
captions=frozenset([
Caption(
main_prompt="first caption",
rating = 1.1,
max_caption_length=1024,
tags=frozenset([
Tag("tag"),
Tag("tag_2", 2.0)
])),
Caption(main_prompt="second_caption")
])),
ImageConfig(
image="2.jpg",
captions=frozenset([Caption(main_prompt="single caption")])
)
])
aspects = []
actual = dataset.image_train_items(aspects)
self.assertEqual(len(actual), 2)
self.assertEqual(actual[0].pathname, os.path.abspath('1.jpg'))
self.assertEqual(actual[0].multiplier, 2.0)
self.assertEqual(actual[0].flip.p, 0.1)
self.assertEqual(actual[0].cond_dropout, 0.01)
self.assertEqual(actual[0].caption.get_caption(), "first caption, tag, tag_2")
# Can't test this
# self.assertTrue(actual[0].caption.__use_weights)
self.assertEqual(actual[1].pathname, os.path.abspath('2.jpg'))
self.assertEqual(actual[1].multiplier, 1.0)
self.assertEqual(actual[1].flip.p, 0.0)
self.assertIsNone(actual[1].cond_dropout)
self.assertEqual(actual[1].caption.get_caption(), "single caption")
# Can't test this
# self.assertFalse(actual[1].caption.__use_weights)

View File

@ -32,40 +32,4 @@ class TestImageCaption(unittest.TestCase):
self.assertEqual(caption.get_caption(), "hello world, one, two, three")
caption = ImageCaption("hello world", 1.0, [], [], 2048, False)
self.assertEqual(caption.get_caption(), "hello world")
def test_parse(self):
caption = ImageCaption.parse("hello world, one, two, three")
self.assertEqual(caption.get_caption(), "hello world, one, two, three")
def test_from_file_name(self):
caption = ImageCaption.from_file_name("foo bar_1_2_3.jpg")
self.assertEqual(caption.get_caption(), "foo bar")
def test_from_text_file(self):
caption = ImageCaption.from_text_file("test/data/test1.txt")
self.assertEqual(caption.get_caption(), "caption for test1")
def test_from_file(self):
caption = ImageCaption.from_file("test/data/test1.txt")
self.assertEqual(caption.get_caption(), "caption for test1")
caption = ImageCaption.from_file("test/data/test_caption.caption")
self.assertEqual(caption.get_caption(), "caption for test2")
def test_resolve(self):
caption = ImageCaption.resolve("test/data/test1.txt")
self.assertEqual(caption.get_caption(), "caption for test1")
caption = ImageCaption.resolve("test/data/test_caption.caption")
self.assertEqual(caption.get_caption(), "caption for test2")
caption = ImageCaption.resolve("hello world")
self.assertEqual(caption.get_caption(), "hello world")
caption = ImageCaption.resolve("test/data/test1.jpg")
self.assertEqual(caption.get_caption(), "caption for test1")
caption = ImageCaption.resolve("test/data/test2.jpg")
self.assertEqual(caption.get_caption(), "test2")
self.assertEqual(caption.get_caption(), "hello world")

46
utils/fs_helpers.py Normal file
View File

@ -0,0 +1,46 @@
def barename(file):
(val, _) = os.path.splitext(os.path.basename(file))
return val
def ext(file):
(_, val) = os.path.splitext(os.path.basename(file))
return val.lower()
def same_barename(lhs, rhs):
return barename(lhs) == barename(rhs)
def is_image(file):
return ext(file) in {'.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif'}
def read_text(file):
try:
with open(file, encoding='utf-8', mode='r') as stream:
return stream.read().strip()
except Exception as e:
logging.warning(f" *** Error reading text file {file}: {e}")
def read_float(file):
try:
return float(read_text(file))
except Exception as e:
logging.warning(f" *** Could not parse '{data}' to float in file {file}: {e}")
import os
def walk_and_visit(path, visit_fn, context=None):
names = [entry.name for entry in os.scandir(path)]
dirs = []
files = []
for name in names:
fullname = os.path.join(path, name)
if os.path.isdir(fullname) and not str(name).startswith('.'):
dirs.append(fullname)
else:
files.append(fullname)
subcontext = visit_fn(files, context)
for subdir in dirs:
walk_and_visit(subdir, visit_fn, subcontext)