diff --git a/data/dataset.py b/data/dataset.py
new file mode 100644
index 0000000..61b6ab9
--- /dev/null
+++ b/data/dataset.py
@@ -0,0 +1,261 @@
+import os
+import logging
+import yaml
+import json
+
+from functools import total_ordering
+from attrs import define, Factory
+from data.image_train_item import ImageCaption, ImageTrainItem
+from utils.fs_helpers import *
+
+
+@define(frozen=True)
+@total_ordering
+class Tag:
+    value: str
+    weight: float = None
+
+    def __lt__(self, other):
+        return self.value < other.value
+
+@define(frozen=True)
+@total_ordering
+class Caption:
+    main_prompt: str = None
+    rating: float = None
+    max_caption_length: int = None
+    tags: frozenset[Tag] = Factory(frozenset)
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        main_prompt = data.get("main_prompt")
+        rating = data.get("rating")
+        max_caption_length = data.get("max_caption_length")
+
+        tags = frozenset([Tag(value=t.get("tag"), weight=t.get("weight"))
+                          for t in data.get("tags", [])
+                          if "tag" in t and len(t.get("tag")) > 0])
+
+        if not main_prompt and not tags:
+            return None
+
+        return Caption(main_prompt=main_prompt, rating=rating, max_caption_length=max_caption_length, tags=tags)
+
+    @classmethod
+    def from_text(cls, text: str):
+        if text is None:
+            return Caption(main_prompt="")
+        split_caption = list(map(str.strip, text.split(",")))
+        main_prompt = split_caption[0]
+        tags = frozenset(Tag(value=t) for t in split_caption[1:])
+        return Caption(main_prompt=main_prompt, tags=tags)
+
+    @classmethod
+    def load(cls, input):
+        if isinstance(input, str):
+            if os.path.isfile(input):
+                return Caption.from_text(read_text(input))
+            else:
+                return Caption.from_text(input)
+        elif isinstance(input, dict):
+            return Caption.from_dict(input)
+
+    def __lt__(self, other):
+        self_str = ",".join([self.main_prompt] + sorted(repr(t) for t in self.tags))
+        other_str = ",".join([other.main_prompt] + sorted(repr(t) for t in other.tags))
+        return self_str < other_str
+
+@define(frozen=True)
+class ImageConfig:
+    image: str = None
+    captions: frozenset[Caption] = Factory(frozenset)
+    multiply: float = None
+    cond_dropout: float = None
+    flip_p: float = None
+
+    @classmethod
+    def fold(cls, configs):
+        acc = ImageConfig()
+        [acc := acc.merge(cfg) for cfg in configs]
+        return acc
+
+    def merge(self, other):
+        if other is None:
+            return self
+
+        if other.image and self.image:
+            logging.warning(f"Found two images with different extensions and the same barename: {self.image} and {other.image}")
+
+        return ImageConfig(
+            image = other.image or self.image,
+            captions = other.captions.union(self.captions),
+            multiply = other.multiply if other.multiply is not None else self.multiply,
+            cond_dropout = other.cond_dropout if other.cond_dropout is not None else self.cond_dropout,
+            flip_p = other.flip_p if other.flip_p is not None else self.flip_p
+        )
+
+    def ensure_caption(self):
+        if not self.captions:
+            filename_caption = Caption.from_text(barename(self.image).split("_")[0])
+            return self.merge(ImageConfig(captions=frozenset([filename_caption])))
+        return self
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        captions = set()
+        if "captions" in data:
+            captions.update(Caption.load(cap) for cap in data.get("captions"))
+
+        if "caption" in data:
+            captions.add(Caption.load(data.get("caption")))
+
+        if not captions:
+            # For backward compatibility with existing caption yaml
+            caption = Caption.load(data)
+            if caption:
+                captions.add(caption)
+
+        return ImageConfig(
+            image = data.get("image"),
+            captions=frozenset(captions),
+            multiply=data.get("multiply"),
+            cond_dropout=data.get("cond_dropout"),
+            flip_p=data.get("flip_p"))
+
+    @classmethod
+    def from_text(cls, text: str):
+        try:
+            if os.path.isfile(text):
+                return ImageConfig.from_file(text)
+            return ImageConfig(captions=frozenset({Caption.from_text(text)}))
+        except Exception as e:
+            logging.warning(f" *** Error parsing config from text {text}: \n{e}")
+
+    @classmethod
+    def from_file(cls, file: str):
+        try:
+            match ext(file):
+                case '.jpg' | '.jpeg' | '.png' | '.bmp' | '.webp' | '.jfif':
+                    return ImageConfig(image=file)
+                case ".json":
+                    return ImageConfig.from_dict(json.loads(read_text(file)))
+                case ".yaml" | ".yml":
+                    return ImageConfig.from_dict(yaml.safe_load(read_text(file)))
+                case ".txt" | ".caption":
+                    return ImageConfig.from_text(read_text(file))
+                case _:
+                    return logging.warning(f" *** Unrecognized config extension {ext(file)}")
+        except Exception as e:
+            logging.warning(f" *** Error parsing config from {file}: {e}")
+
+    @classmethod
+    def load(cls, input):
+        if isinstance(input, str):
+            return ImageConfig.from_text(input)
+        elif isinstance(input, dict):
+            return ImageConfig.from_dict(input)
+
+@define()
+class Dataset:
+    image_configs: set[ImageConfig]
+
+    def __global_cfg(files):
+        cfgs = []
+        for file in files:
+            match os.path.basename(file):
+                case 'global.yaml' | 'global.yml':
+                    cfgs.append(ImageConfig.from_file(file))
+        return ImageConfig.fold(cfgs)
+
+    def __local_cfg(files):
+        cfgs = []
+        for file in files:
+            match os.path.basename(file):
+                case 'multiply.txt':
+                    cfgs.append(ImageConfig(multiply=read_float(file)))
+                case 'cond_dropout.txt':
+                    cfgs.append(ImageConfig(cond_dropout=read_float(file)))
+                case 'flip_p.txt':
+                    cfgs.append(ImageConfig(flip_p=read_float(file)))
+                case 'local.yaml' | 'local.yml':
+                    cfgs.append(ImageConfig.from_file(file))
+        return ImageConfig.fold(cfgs)
+
+    def __image_cfg(imagepath, files):
+        cfgs = [ImageConfig.from_file(imagepath)]
+        for file in files:
+            if same_barename(imagepath, file):
+                match ext(file):
+                    case '.txt' | '.caption' | '.yml' | '.yaml':
+                        cfgs.append(ImageConfig.from_file(file))
+        return ImageConfig.fold(cfgs)
+
+    @classmethod
+    def from_path(cls, data_root):
+        # Create a visitor that maintains global config stack
+        # and accumulates image configs as it traverses dataset
+        image_configs = set()
+        def process_dir(files, parent_globals):
+            globals = parent_globals.merge(Dataset.__global_cfg(files))
+            locals = Dataset.__local_cfg(files)
+            for img in filter(is_image, files):
+                img_cfg = Dataset.__image_cfg(img, files)
+                collapsed_cfg = ImageConfig.fold([globals, locals, img_cfg])
+                resolved_cfg = collapsed_cfg.ensure_caption()
+                image_configs.add(resolved_cfg)
+            return globals
+
+        walk_and_visit(data_root, process_dir, ImageConfig())
+        return Dataset(image_configs)
+
+    @classmethod
+    def from_json(cls, json_path):
+        """
+        Import a dataset definition from a JSON file
+        """
+        configs = set()
+        with open(json_path, encoding='utf-8', mode='r') as stream:
+            for data in json.load(stream):
+                cfg = ImageConfig.load(data)
+                if not cfg or not cfg.image:
+                    logging.warning(f" *** Error parsing json image entry in {json_path}: {data}")
+                    continue
+                configs.add(cfg.ensure_caption())
+        return Dataset(configs)
+
+    def image_train_items(self, aspects):
+        items = []
+        for config in self.image_configs:
+            caption = next(iter(sorted(config.captions)))
+            if len(config.captions) > 1:
+                logging.warning(f" *** Found multiple captions for image {config.image}, but only one will be applied: {config.captions}")
+
+            use_weights = 
len(set(t.weight or 1.0 for t in caption.tags)) > 1 + tags = [] + tag_weights = [] + for tag in sorted(caption.tags): + tags.append(tag.value) + tag_weights.append(tag.weight or 1.0) + use_weights = len(set(tag_weights)) > 1 + + caption = ImageCaption( + main_prompt=caption.main_prompt, + rating=caption.rating, + tags=tags, + tag_weights=tag_weights, + max_target_length=caption.max_caption_length, + use_weights=use_weights) + + item = ImageTrainItem( + image=None, + caption=caption, + aspects=aspects, + pathname=os.path.abspath(config.image), + flip_p=config.flip_p or 0.0, + multiplier=config.multiply or 1.0, + cond_dropout=config.cond_dropout + ) + items.append(item) + return list(sorted(items, key=lambda ti: ti.pathname)) + + diff --git a/data/every_dream.py b/data/every_dream.py index e4f3908..f5df7d9 100644 --- a/data/every_dream.py +++ b/data/every_dream.py @@ -104,7 +104,7 @@ class EveryDreamBatch(Dataset): example["image"] = image_transforms(train_item["image"]) - if random.random() > self.conditional_dropout: + if random.random() > (train_item.cond_dropout or self.conditional_dropout): example["tokens"] = self.tokenizer(example["caption"], truncation=True, padding="max_length", diff --git a/data/image_train_item.py b/data/image_train_item.py index 8e88612..42f3016 100644 --- a/data/image_train_item.py +++ b/data/image_train_item.py @@ -113,136 +113,6 @@ class ImageCaption: random.Random(seed).shuffle(tags) return ", ".join(tags) - @staticmethod - def parse(string: str) -> 'ImageCaption': - """ - Parses a string to get the caption. - - :param string: String to parse. - :return: `ImageCaption` object. - """ - split_caption = list(map(str.strip, string.split(","))) - main_prompt = split_caption[0] - tags = split_caption[1:] - tag_weights = [1.0] * len(tags) - - return ImageCaption(main_prompt, 1.0, tags, tag_weights, DEFAULT_MAX_CAPTION_LENGTH, False) - - @staticmethod - def from_file_name(file_path: str) -> 'ImageCaption': - """ - Parses the file name to get the caption. - - :param file_path: Path to the image file. - :return: `ImageCaption` object. - """ - (file_name, _) = os.path.splitext(os.path.basename(file_path)) - caption = file_name.split("_")[0] - return ImageCaption.parse(caption) - - @staticmethod - def from_text_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption: - """ - Parses a text file to get the caption. Returns the default caption if - the file does not exist or is invalid. - - :param file_path: Path to the text file. - :param default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid. - :return: `ImageCaption` object or `None`. - """ - try: - with open(file_path, encoding='utf-8', mode='r') as caption_file: - caption_text = caption_file.read() - return ImageCaption.parse(caption_text) - except: - logging.error(f" *** Error reading {file_path} to get caption") - return default_caption - - @staticmethod - def from_yaml_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption: - """ - Parses a yaml file to get the caption. Returns the default caption if - the file does not exist or is invalid. - - :param file_path: path to the yaml file - :param default_caption: caption to return if the file does not exist or is invalid - :return: `ImageCaption` object or `None`. 
- """ - try: - with open(file_path, "r") as stream: - file_content = yaml.safe_load(stream) - main_prompt = file_content.get("main_prompt", "") - rating = file_content.get("rating", 1.0) - unparsed_tags = file_content.get("tags", []) - - max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH) - - tags = [] - tag_weights = [] - last_weight = None - weights_differ = False - for unparsed_tag in unparsed_tags: - tag = unparsed_tag.get("tag", "").strip() - if len(tag) == 0: - continue - - tags.append(tag) - tag_weight = unparsed_tag.get("weight", 1.0) - tag_weights.append(tag_weight) - - if last_weight is not None and weights_differ is False: - weights_differ = last_weight != tag_weight - - last_weight = tag_weight - - return ImageCaption(main_prompt, rating, tags, tag_weights, max_caption_length, weights_differ) - except: - logging.error(f" *** Error reading {file_path} to get caption") - return default_caption - - @staticmethod - def from_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption: - """ - Try to resolve a caption from a file path or return `default_caption`. - - :string: The path to the file to parse. - :default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid. - :return: `ImageCaption` object or `None`. - """ - if os.path.exists(file_path): - (file_path_without_ext, ext) = os.path.splitext(file_path) - match ext: - case ".yaml" | ".yml": - return ImageCaption.from_yaml_file(file_path, default_caption) - - case ".txt" | ".caption": - return ImageCaption.from_text_file(file_path, default_caption) - - case '.jpg'| '.jpeg'| '.png'| '.bmp'| '.webp'| '.jfif': - for ext in [".yaml", ".yml", ".txt", ".caption"]: - file_path = file_path_without_ext + ext - image_caption = ImageCaption.from_file(file_path) - if image_caption is not None: - return image_caption - return ImageCaption.from_file_name(file_path) - - case _: - return default_caption - else: - return default_caption - - @staticmethod - def resolve(string: str) -> 'ImageCaption': - """ - Try to resolve a caption from a string. If the string is a file path, - the caption will be read from the file, otherwise the string will be - parsed as a caption. - - :string: The string to resolve. - :return: `ImageCaption` object. - """ - return ImageCaption.from_file(string, None) or ImageCaption.parse(string) - class ImageTrainItem: """ @@ -253,7 +123,7 @@ class ImageTrainItem: flip_p: probability of flipping image (0.0 to 1.0) rating: the relative rating of the images. The rating is measured in comparison to the other images. 
""" - def __init__(self, image: PIL.Image, caption: ImageCaption, aspects: list[float], pathname: str, flip_p=0.0, multiplier: float=1.0): + def __init__(self, image: PIL.Image, caption: ImageCaption, aspects: list[float], pathname: str, flip_p=0.0, multiplier: float=1.0, cond_dropout=None): self.caption = caption self.aspects = aspects self.pathname = pathname @@ -261,6 +131,7 @@ class ImageTrainItem: self.cropped_img = None self.runt_size = 0 self.multiplier = multiplier + self.cond_dropout = cond_dropout self.image_size = None if image is None or len(image) == 0: diff --git a/data/resolver.py b/data/resolver.py index b31043f..eae5df8 100644 --- a/data/resolver.py +++ b/data/resolver.py @@ -4,6 +4,7 @@ import os import typing import zipfile import argparse +from data.dataset import Dataset import tqdm from colorama import Fore, Style @@ -27,16 +28,6 @@ class DataResolver: """ raise NotImplementedError() - def image_train_item(self, image_path: str, caption: ImageCaption, multiplier: float=1) -> ImageTrainItem: - return ImageTrainItem( - image=None, - caption=caption, - aspects=self.aspects, - pathname=image_path, - flip_p=self.flip_p, - multiplier=multiplier - ) - class JSONResolver(DataResolver): def image_train_items(self, json_path: str) -> list[ImageTrainItem]: """ @@ -45,62 +36,8 @@ class JSONResolver(DataResolver): :param json_path: The path to the JSON file. """ - items = [] - with open(json_path, encoding='utf-8', mode='r') as f: - json_data = json.load(f) - - for data in tqdm.tqdm(json_data): - caption = JSONResolver.image_caption(data) - if caption: - image_value = JSONResolver.get_image_value(data) - item = self.image_train_item(image_value, caption) - if item: - items.append(item) - - return items + return Dataset.from_json(json_path).image_train_items(self.aspects) - @staticmethod - def get_image_value(json_data: dict) -> typing.Optional[str]: - """ - Get the image from the json data if possible. - - :param json_data: The json data, a dict. - :return: The image, or None if not found. - """ - image_value = json_data.get("image", None) - if isinstance(image_value, str): - image_value = image_value.strip() - if os.path.exists(image_value): - return image_value - - @staticmethod - def get_caption_value(json_data: dict) -> typing.Optional[str]: - """ - Get the caption from the json data if possible. - - :param json_data: The json data, a dict. - :return: The caption, or None if not found. - """ - caption_value = json_data.get("caption", None) - if isinstance(caption_value, str): - return caption_value.strip() - - @staticmethod - def image_caption(json_data: dict) -> typing.Optional[ImageCaption]: - """ - Get the caption from the json data if possible. - - :param json_data: The json data, a dict. - :return: The `ImageCaption`, or None if not found. 
- """ - image_value = JSONResolver.get_image_value(json_data) - caption_value = JSONResolver.get_caption_value(json_data) - if image_value: - if caption_value: - return ImageCaption.resolve(caption_value) - return ImageCaption.from_file(image_value) - - class DirectoryResolver(DataResolver): def image_train_items(self, data_root: str) -> list[ImageTrainItem]: """ @@ -111,32 +48,7 @@ class DirectoryResolver(DataResolver): :param data_root: The root directory to recurse through """ DirectoryResolver.unzip_all(data_root) - image_paths = list(DirectoryResolver.recurse_data_root(data_root)) - items = [] - multipliers = {} - - for pathname in tqdm.tqdm(image_paths): - current_dir = os.path.dirname(pathname) - - if current_dir not in multipliers: - multiply_txt_path = os.path.join(current_dir, "multiply.txt") - if os.path.exists(multiply_txt_path): - try: - with open(multiply_txt_path, 'r') as f: - val = float(f.read().strip()) - multipliers[current_dir] = val - logging.info(f" - multiply.txt in '{current_dir}' set to {val}") - except Exception as e: - logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}") - multipliers[current_dir] = 1.0 - else: - multipliers[current_dir] = 1.0 - - caption = ImageCaption.resolve(pathname) - item = self.image_train_item(pathname, caption, multiplier=multipliers[current_dir]) - items.append(item) - - return items + return Dataset.from_path(data_root).image_train_items(self.aspects) @staticmethod def unzip_all(path): @@ -150,21 +62,6 @@ class DirectoryResolver(DataResolver): except Exception as e: logging.error(f"Error unzipping files {e}") - @staticmethod - def recurse_data_root(recurse_root): - for f in os.listdir(recurse_root): - current = os.path.join(recurse_root, f) - - if os.path.isfile(current): - ext = os.path.splitext(f)[1].lower() - if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']: - yield current - - for d in os.listdir(recurse_root): - current = os.path.join(recurse_root, d) - if os.path.isdir(current): - yield from DirectoryResolver.recurse_data_root(current) - def strategy(data_root: str) -> typing.Type[DataResolver]: """ Determine the strategy to use for resolving the data. 
diff --git a/docker/requirements.txt b/docker/requirements.txt index 064a00a..0677981 100644 --- a/docker/requirements.txt +++ b/docker/requirements.txt @@ -10,6 +10,7 @@ ninja omegaconf==2.2.3 piexif==1.1.3 protobuf==3.20.3 +pyfakefs pynvml==11.5.0 pyre-extensions==0.0.30 pytorch-lightning==1.9.2 diff --git a/test/test_data_resolver.py b/test/test_data_resolver.py index f668974..c0607e1 100644 --- a/test/test_data_resolver.py +++ b/test/test_data_resolver.py @@ -58,13 +58,13 @@ class TestResolve(unittest.TestCase): def test_directory_resolve_with_str(self): items = resolver.resolve(DATA_PATH, ARGS) - image_paths = [item.pathname for item in items] + image_paths = set(item.pathname for item in items) image_captions = [item.caption for item in items] - captions = [caption.get_caption() for caption in image_captions] + captions = set(caption.get_caption() for caption in image_captions) self.assertEqual(len(items), 3) - self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH]) - self.assertEqual(captions, ['caption for test1', 'test2', 'test3']) + self.assertEqual(image_paths, {IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH}) + self.assertEqual(captions, {'caption for test1', 'test2', 'test3'}) undersized_images = list(filter(lambda i: i.is_undersized, items)) self.assertEqual(len(undersized_images), 1) diff --git a/test/test_dataset.py b/test/test_dataset.py new file mode 100644 index 0000000..4fb3c41 --- /dev/null +++ b/test/test_dataset.py @@ -0,0 +1,329 @@ +import os +from data.dataset import Dataset, ImageConfig, Caption, Tag + +from textwrap import dedent +from pyfakefs.fake_filesystem_unittest import TestCase + +class TestResolve(TestCase): + def setUp(self): + self.setUpPyfakefs() + + def test_simple_image(self): + self.fs.create_file("image, tag1, tag2.jpg") + + actual = Dataset.from_path(".").image_configs + + expected = { + ImageConfig( + image="./image, tag1, tag2.jpg", + captions=frozenset([ + Caption(main_prompt="image", tags=frozenset([Tag("tag1"), Tag("tag2")])) + ])) + } + self.assertEqual(expected, actual) + + def test_image_types(self): + self.fs.create_file("image_1.JPG") + self.fs.create_file("image_2.jpeg") + self.fs.create_file("image_3.png") + self.fs.create_file("image_4.webp") + self.fs.create_file("image_5.jfif") + self.fs.create_file("image_6.bmp") + + actual = Dataset.from_path(".").image_configs + + captions = frozenset([Caption(main_prompt="image")]) + expected = { + ImageConfig(image="./image_1.JPG", captions=captions), + ImageConfig(image="./image_2.jpeg", captions=captions), + ImageConfig(image="./image_3.png", captions=captions), + ImageConfig(image="./image_4.webp", captions=captions), + ImageConfig(image="./image_5.jfif", captions=captions), + ImageConfig(image="./image_6.bmp", captions=captions), + } + self.assertEqual(expected, actual) + + def test_caption_file(self): + self.fs.create_file("image_1.jpg") + self.fs.create_file("image_1.txt", contents="an image, test, from .txt") + self.fs.create_file("image_2.jpg") + self.fs.create_file("image_2.caption", contents="an image, test, from .caption") + + actual = Dataset.from_path(".").image_configs + + expected = { + ImageConfig( + image="./image_1.jpg", + captions=frozenset([ + Caption(main_prompt="an image", tags=frozenset([Tag("test"), Tag("from .txt")])) + ])), + ImageConfig( + image="./image_2.jpg", + captions=frozenset([ + Caption(main_prompt="an image", tags=frozenset([Tag("test"), Tag("from .caption")])) + ])) + } + self.assertEqual(expected, actual) + + + def test_image_yaml(self): + 
self.fs.create_file("image_1.jpg") + self.fs.create_file("image_1.yaml", + contents=dedent(""" + multiply: 2 + cond_dropout: 0.05 + flip_p: 0.5 + caption: "A simple caption, from .yaml" + """)) + self.fs.create_file("image_2.jpg") + self.fs.create_file("image_2.yml", + contents=dedent(""" + flip_p: 0.0 + caption: + main_prompt: A complex caption + rating: 1.1 + max_caption_length: 1024 + tags: + - tag: from .yml + - tag: with weight + weight: 0.5 + """)) + + actual = Dataset.from_path(".").image_configs + + expected = { + ImageConfig( + image="./image_1.jpg", + multiply=2, + cond_dropout=0.05, + flip_p=0.5, + captions=frozenset([ + Caption(main_prompt="A simple caption", tags=frozenset([Tag("from .yaml")])) + ])), + ImageConfig( + image="./image_2.jpg", + flip_p=0.0, + captions=frozenset([ + Caption(main_prompt="A complex caption", rating=1.1, + max_caption_length=1024, + tags=frozenset([ + Tag("from .yml"), + Tag("with weight", weight=0.5) + ])) + ])) + } + self.assertEqual(expected, actual) + + + def test_multi_caption(self): + self.fs.create_file("image_1.jpg") + self.fs.create_file("image_1.yaml", contents=dedent(""" + caption: "A simple caption, from .yaml" + captions: + - "Another simple caption" + - main_prompt: A complex caption + """)) + self.fs.create_file("image_1.txt", contents="A .txt caption") + self.fs.create_file("image_1.caption", contents="A .caption caption") + + actual = Dataset.from_path(".").image_configs + + expected = { + ImageConfig( + image="./image_1.jpg", + captions=frozenset([ + Caption(main_prompt="A simple caption", tags=frozenset([Tag("from .yaml")])), + Caption(main_prompt="Another simple caption", tags=frozenset()), + Caption(main_prompt="A complex caption", tags=frozenset()), + Caption(main_prompt="A .txt caption", tags=frozenset()), + Caption(main_prompt="A .caption caption", tags=frozenset()) + ]) + ), + } + self.assertEqual(expected, actual) + + def test_globals_and_locals(self): + self.fs.create_file("./people/global.yaml", contents=dedent("""\ + multiply: 1.0 + cond_dropout: 0.0 + flip_p: 0.0 + """)) + self.fs.create_file("./people/alice/local.yaml", contents="multiply: 1.5") + self.fs.create_file("./people/alice/alice_1.png") + self.fs.create_file("./people/alice/alice_1.yaml", contents="multiply: 2") + self.fs.create_file("./people/alice/alice_2.png") + + self.fs.create_file("./people/bob/multiply.txt", contents="3") + self.fs.create_file("./people/bob/cond_dropout.txt", contents="0.05") + self.fs.create_file("./people/bob/flip_p.txt", contents="0.05") + self.fs.create_file("./people/bob/bob.png") + + self.fs.create_file("./people/cleo/cleo.png") + self.fs.create_file("./people/dan.png") + + self.fs.create_file("./other/dog/local.yaml", contents="caption: spike") + self.fs.create_file("./other/dog/xyz.png") + + actual = Dataset.from_path(".").image_configs + + expected = { + ImageConfig( + image="./people/alice/alice_1.png", + captions=frozenset([Caption(main_prompt="alice")]), + multiply=2, + cond_dropout=0.0, + flip_p=0.0 + ), + ImageConfig( + image="./people/alice/alice_2.png", + captions=frozenset([Caption(main_prompt="alice")]), + multiply=1.5, + cond_dropout=0.0, + flip_p=0.0 + ), + ImageConfig( + image="./people/bob/bob.png", + captions=frozenset([Caption(main_prompt="bob")]), + multiply=3, + cond_dropout=0.05, + flip_p=0.05 + ), + ImageConfig( + image="./people/cleo/cleo.png", + captions=frozenset([Caption(main_prompt="cleo")]), + multiply=1.0, + cond_dropout=0.0, + flip_p=0.0 + ), + ImageConfig( + image="./people/dan.png", + 
captions=frozenset([Caption(main_prompt="dan")]), + multiply=1.0, + cond_dropout=0.0, + flip_p=0.0 + ), + ImageConfig( + image="./other/dog/xyz.png", + captions=frozenset([Caption(main_prompt="spike")]), + multiply=None, + cond_dropout=None, + flip_p=None + ) + } + self.assertEqual(expected, actual) + + def test_json_manifest(self): + self.fs.create_file("./stuff/image_1.jpg") + self.fs.create_file("./stuff/default.caption", contents= "default caption") + self.fs.create_file("./other/image_1.jpg") + self.fs.create_file("./other/image_2.jpg") + self.fs.create_file("./other/image_3.jpg") + self.fs.create_file("./manifest.json", contents=dedent(""" + [ + { "image": "./stuff/image_1.jpg", "caption": "./stuff/default.caption" }, + { "image": "./other/image_1.jpg", "caption": "other caption" }, + { + "image": "./other/image_2.jpg", + "caption": { + "main_prompt": "complex caption", + "rating": 0.1, + "max_caption_length": 1000, + "tags": [ + {"tag": "including"}, + {"tag": "weighted tag", "weight": 999.9} + ] + } + }, + { + "image": "./other/image_3.jpg", + "multiply": 2, + "flip_p": 0.5, + "cond_dropout": 0.01, + "captions": [ + "first caption", + { "main_prompt": "second caption" } + ] + } + ] + """)) + + actual = Dataset.from_json("./manifest.json").image_configs + expected = { + ImageConfig( + image="./stuff/image_1.jpg", + captions=frozenset([Caption(main_prompt="default caption")]) + ), + ImageConfig( + image="./other/image_1.jpg", + captions=frozenset([Caption(main_prompt="other caption")]) + ), + ImageConfig( + image="./other/image_2.jpg", + captions=frozenset([ + Caption( + main_prompt="complex caption", + rating=0.1, + max_caption_length=1000, + tags=frozenset([ + Tag("including"), + Tag("weighted tag", 999.9) + ])) + ]) + ), + ImageConfig( + image="./other/image_3.jpg", + multiply=2, + flip_p=0.5, + cond_dropout=0.01, + captions=frozenset([ + Caption("first caption"), + Caption("second caption") + ]) + ) + } + self.assertEqual(expected, actual) + + def test_train_items(self): + dataset = Dataset([ + ImageConfig( + image="1.jpg", + multiply=2, + flip_p=0.1, + cond_dropout=0.01, + captions=frozenset([ + Caption( + main_prompt="first caption", + rating = 1.1, + max_caption_length=1024, + tags=frozenset([ + Tag("tag"), + Tag("tag_2", 2.0) + ])), + Caption(main_prompt="second_caption") + ])), + ImageConfig( + image="2.jpg", + captions=frozenset([Caption(main_prompt="single caption")]) + ) + ]) + + aspects = [] + actual = dataset.image_train_items(aspects) + + self.assertEqual(len(actual), 2) + + self.assertEqual(actual[0].pathname, os.path.abspath('1.jpg')) + self.assertEqual(actual[0].multiplier, 2.0) + self.assertEqual(actual[0].flip.p, 0.1) + self.assertEqual(actual[0].cond_dropout, 0.01) + self.assertEqual(actual[0].caption.get_caption(), "first caption, tag, tag_2") + # Can't test this + # self.assertTrue(actual[0].caption.__use_weights) + + self.assertEqual(actual[1].pathname, os.path.abspath('2.jpg')) + self.assertEqual(actual[1].multiplier, 1.0) + self.assertEqual(actual[1].flip.p, 0.0) + self.assertIsNone(actual[1].cond_dropout) + self.assertEqual(actual[1].caption.get_caption(), "single caption") + # Can't test this + # self.assertFalse(actual[1].caption.__use_weights) \ No newline at end of file diff --git a/test/test_image_train_item.py b/test/test_image_train_item.py index bf12b43..75ffb41 100644 --- a/test/test_image_train_item.py +++ b/test/test_image_train_item.py @@ -32,40 +32,4 @@ class TestImageCaption(unittest.TestCase): self.assertEqual(caption.get_caption(), "hello 
world, one, two, three")
 
         caption = ImageCaption("hello world", 1.0, [], [], 2048, False)
-        self.assertEqual(caption.get_caption(), "hello world")
-
-    def test_parse(self):
-        caption = ImageCaption.parse("hello world, one, two, three")
-
-        self.assertEqual(caption.get_caption(), "hello world, one, two, three")
-
-    def test_from_file_name(self):
-        caption = ImageCaption.from_file_name("foo bar_1_2_3.jpg")
-        self.assertEqual(caption.get_caption(), "foo bar")
-
-    def test_from_text_file(self):
-        caption = ImageCaption.from_text_file("test/data/test1.txt")
-        self.assertEqual(caption.get_caption(), "caption for test1")
-
-    def test_from_file(self):
-        caption = ImageCaption.from_file("test/data/test1.txt")
-        self.assertEqual(caption.get_caption(), "caption for test1")
-
-        caption = ImageCaption.from_file("test/data/test_caption.caption")
-        self.assertEqual(caption.get_caption(), "caption for test2")
-
-    def test_resolve(self):
-        caption = ImageCaption.resolve("test/data/test1.txt")
-        self.assertEqual(caption.get_caption(), "caption for test1")
-
-        caption = ImageCaption.resolve("test/data/test_caption.caption")
-        self.assertEqual(caption.get_caption(), "caption for test2")
-
-        caption = ImageCaption.resolve("hello world")
-        self.assertEqual(caption.get_caption(), "hello world")
-
-        caption = ImageCaption.resolve("test/data/test1.jpg")
-        self.assertEqual(caption.get_caption(), "caption for test1")
-
-        caption = ImageCaption.resolve("test/data/test2.jpg")
-        self.assertEqual(caption.get_caption(), "test2")
\ No newline at end of file
+        self.assertEqual(caption.get_caption(), "hello world")
\ No newline at end of file
diff --git a/utils/fs_helpers.py b/utils/fs_helpers.py
new file mode 100644
index 0000000..2115a39
--- /dev/null
+++ b/utils/fs_helpers.py
@@ -0,0 +1,47 @@
+import logging
+import os
+
+def barename(file):
+    (val, _) = os.path.splitext(os.path.basename(file))
+    return val
+
+def ext(file):
+    (_, val) = os.path.splitext(os.path.basename(file))
+    return val.lower()
+
+def same_barename(lhs, rhs):
+    return barename(lhs) == barename(rhs)
+
+def is_image(file):
+    return ext(file) in {'.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif'}
+
+def read_text(file):
+    try:
+        with open(file, encoding='utf-8', mode='r') as stream:
+            return stream.read().strip()
+    except Exception as e:
+        logging.warning(f" *** Error reading text file {file}: {e}")
+
+def read_float(file):
+    data = read_text(file)
+    try:
+        return float(data)
+    except Exception as e:
+        logging.warning(f" *** Could not parse '{data}' to float in file {file}: {e}")
+
+def walk_and_visit(path, visit_fn, context=None):
+    names = [entry.name for entry in os.scandir(path)]
+
+    dirs = []
+    files = []
+    for name in names:
+        fullname = os.path.join(path, name)
+        if os.path.isdir(fullname) and not str(name).startswith('.'):
+            dirs.append(fullname)
+        else:
+            files.append(fullname)
+
+    subcontext = visit_fn(files, context)
+
+    for subdir in dirs:
+        walk_and_visit(subdir, visit_fn, subcontext)
\ No newline at end of file
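
Reviewer note: below is a minimal, illustrative sketch of how the new Dataset API introduced in data/dataset.py is consumed by the slimmed-down resolvers in data/resolver.py. Only Dataset.from_path, Dataset.from_json, and image_train_items come from this change; the folder path, manifest path, and aspect list are hypothetical placeholder values, not part of the patch.

# Sketch only: resolve training images into ImageTrainItems via the new Dataset class.
# data_root, the manifest path, and the aspect buckets below are made-up examples.
from data.dataset import Dataset

data_root = "/training_data"                      # hypothetical folder of images plus yaml/txt sidecars
aspects = [[512, 512], [576, 448], [448, 576]]    # example aspect buckets; real values come from the trainer

# Directory route (DirectoryResolver): global.yaml, local.yaml, multiply.txt, cond_dropout.txt,
# flip_p.txt and per-image sidecars are folded into one ImageConfig per image, then converted.
items = Dataset.from_path(data_root).image_train_items(aspects)

# JSON manifest route (JSONResolver):
# items = Dataset.from_json("/training_data/manifest.json").image_train_items(aspects)

for item in items:
    print(item.pathname, item.multiplier, item.cond_dropout)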