# EveryDream2trainer/test/test_dataset.py

import os
from data.dataset import Dataset, ImageConfig, Tag, DEFAULT_MAX_CAPTION_LENGTH
from textwrap import dedent
from pyfakefs.fake_filesystem_unittest import TestCase
class TestDataset(TestCase):
    """Tests for how ``Dataset`` builds per-image ``ImageConfig`` objects.

    Every test runs on the fake filesystem installed by ``setUpPyfakefs()``,
    so files created with ``self.fs.create_file`` exist only for that test.
    """

    def setUp(self):
        # Print the complete diff when two large ImageConfig dicts differ.
        self.maxDiff = None
        self.setUpPyfakefs()

    def test_a_caption_is_generated_from_image_given_no_other_config(self):
        """Without sidecars, the file name supplies the main prompt and tags."""
        self.fs.create_file("image, tag1, tag2.jpg")

        actual = Dataset.from_path(".").image_configs

        expected = {
            "./image, tag1, tag2.jpg": ImageConfig(
                main_prompts="image",
                tags=frozenset([Tag("tag1"), Tag("tag2")]),
            )
        }
        self.assertEqual(expected, actual)

    def test_several_image_formats_are_supported(self):
        """Every listed image extension is discovered, including upper-case .JPG."""
        self.fs.create_file("image.JPG")
        self.fs.create_file("image.jpeg")
        self.fs.create_file("image.png")
        self.fs.create_file("image.webp")
        self.fs.create_file("image.jfif")
        self.fs.create_file("image.bmp")

        actual = Dataset.from_path(".").image_configs

        common_cfg = ImageConfig(main_prompts="image")
        expected = {
            "./image.JPG": common_cfg,
            "./image.jpeg": common_cfg,
            "./image.png": common_cfg,
            "./image.webp": common_cfg,
            "./image.jfif": common_cfg,
            "./image.bmp": common_cfg,
        }
        self.assertEqual(expected, actual)

    def test_captions_can_be_read_from_txt_or_caption_sidecar(self):
        """A .txt or .caption sidecar is split into a main prompt and tags."""
        self.fs.create_file("image_1.jpg")
        self.fs.create_file("image_1.txt", contents="an image, test, from .txt")
        self.fs.create_file("image_2.jpg")
        self.fs.create_file("image_2.caption", contents="an image, test, from .caption")

        actual = Dataset.from_path(".").image_configs

        expected = {
            "./image_1.jpg": ImageConfig(
                main_prompts="an image",
                tags=frozenset([Tag("test"), Tag("from .txt")]),
            ),
            "./image_2.jpg": ImageConfig(
                main_prompts="an image",
                tags=frozenset([Tag("test"), Tag("from .caption")]),
            ),
        }
        self.assertEqual(expected, actual)

    def test_captions_and_options_can_be_read_from_yaml_sidecar(self):
        """.yaml/.yml sidecars carry per-image options and a plain or structured caption."""
        self.fs.create_file("image_1.jpg")
        # Plain-string caption: split into main prompt and tags like a .txt sidecar.
        self.fs.create_file("image_1.yaml",
            contents=dedent("""
                multiply: 2
                cond_dropout: 0.05
                flip_p: 0.5
                caption: "A simple caption, from .yaml"
                """))
        self.fs.create_file("image_2.jpg")
        # Structured caption: main_prompt, rating, max_caption_length and tags
        # must be nested under `caption:`, and `weight` must sit inside the same
        # list item as its `tag` key.
        self.fs.create_file("image_2.yml",
            contents=dedent("""
                flip_p: 0.0
                caption:
                    main_prompt: A complex caption
                    rating: 1.1
                    max_caption_length: 1024
                    tags:
                    - tag: from .yml
                    - tag: with weight
                      weight: 0.5
                """))

        actual = Dataset.from_path(".").image_configs

        expected = {
            "./image_1.jpg": ImageConfig(
                multiply=2,
                cond_dropout=0.05,
                flip_p=0.5,
                main_prompts="A simple caption",
                tags={Tag("from .yaml")},
            ),
            "./image_2.jpg": ImageConfig(
                flip_p=0.0,
                rating=1.1,
                max_caption_length=1024,
                main_prompts="A complex caption",
                tags={Tag("from .yml"), Tag("with weight", weight=0.5)},
            ),
        }
        self.assertEqual(expected, actual)

    def test_multiple_prompts_and_tags_from_multiple_sidecars_are_supported(self):
        """Prompts and tags from .yaml, .txt and .caption sidecars of one image are merged."""
        self.fs.create_file("image_1.jpg")
        self.fs.create_file("image_1.yaml", contents=dedent("""
            main_prompt:
            - unique prompt
            - dupe prompt
            tags:
            - from .yaml
            - dupe tag
            """))
        self.fs.create_file("image_1.txt", contents="also unique prompt, from .txt, dupe tag")
        self.fs.create_file("image_1.caption", contents="dupe prompt, from .caption")

        actual = Dataset.from_path(".").image_configs

        # "dupe prompt" and "dupe tag" come from two sidecars each but are
        # expected only once in the merged sets.
        expected = {
            "./image_1.jpg": ImageConfig(
                main_prompts={"unique prompt", "also unique prompt", "dupe prompt"},
                tags={Tag("from .yaml"), Tag("from .txt"), Tag("from .caption"), Tag("dupe tag")},
            )
        }
        self.assertEqual(expected, actual)

    def test_sidecars_can_also_be_attached_to_local_and_recursive_folders(self):
        """Folder-level sidecars: global.yaml recurses, local.yaml does not, option files override.

        - ``global.yaml`` applies to its folder and every sub-folder.
        - ``local.yaml`` applies only to images directly in its own folder.
        - A single-option file such as ``flip_p.txt`` overrides the inherited value.
        """
        self.fs.create_file("./global.yaml",
            contents=dedent("""\
                main_prompt: global prompt
                tags:
                - global tag
                flip_p: 0.0
                """))
        self.fs.create_file("./local.yaml",
            contents=dedent("""
                main_prompt: local prompt
                tags:
                - tag: local tag
                """))
        self.fs.create_file("./arbitrary filename.png")
        self.fs.create_file("./sub/sub arbitrary filename.png")
        self.fs.create_file("./sub/sidecar.png")
        self.fs.create_file("./sub/sidecar.txt",
            contents="sidecar prompt, sidecar tag")
        self.fs.create_file("./optfile/optfile.png")
        self.fs.create_file("./optfile/flip_p.txt",
            contents="0.1234")
        self.fs.create_file("./sub/sub2/global.yaml",
            contents=dedent("""
                tags:
                - tag: sub global tag
                """))
        # sub2 itself holds no images (xyz.png lives in sub3), so this local
        # sidecar must not show up in any expected config.
        self.fs.create_file("./sub/sub2/local.yaml",
            contents=dedent("""
                tags:
                - This tag wil not apply to any files
                """))
        self.fs.create_file("./sub/sub2/sub3/xyz.png")

        actual = Dataset.from_path(".").image_configs

        expected = {
            "./arbitrary filename.png": ImageConfig(
                main_prompts={'global prompt', 'local prompt'},
                tags={Tag("global tag"), Tag("local tag")},
                flip_p=0.0,
            ),
            "./sub/sub arbitrary filename.png": ImageConfig(
                main_prompts={'global prompt'},
                tags={Tag("global tag")},
                flip_p=0.0,
            ),
            "./sub/sidecar.png": ImageConfig(
                main_prompts={'global prompt', 'sidecar prompt'},
                tags={Tag("global tag"), Tag("sidecar tag")},
                flip_p=0.0,
            ),
            "./optfile/optfile.png": ImageConfig(
                main_prompts={'global prompt'},
                tags={Tag("global tag")},
                flip_p=0.1234,
            ),
            "./sub/sub2/sub3/xyz.png": ImageConfig(
                main_prompts={'global prompt'},
                tags={Tag("global tag"), Tag("sub global tag")},
                flip_p=0.0,
            ),
        }
        self.assertEqual(expected, actual)

    def test_can_load_dataset_from_json_manifest(self):
        """A JSON manifest entry may give a caption file, a caption string, a structured caption, or top-level options."""
        self.fs.create_file("./stuff/image_1.jpg")
        self.fs.create_file("./stuff/default.caption", contents="default caption")
        self.fs.create_file("./other/image_1.jpg")
        self.fs.create_file("./other/image_2.jpg")
        self.fs.create_file("./other/image_3.jpg")
        self.fs.create_file("./manifest.json", contents=dedent("""
            [
                { "image": "./stuff/image_1.jpg", "caption": "./stuff/default.caption" },
                { "image": "./other/image_1.jpg", "caption": "other caption" },
                {
                    "image": "./other/image_2.jpg",
                    "caption": {
                        "main_prompt": "complex caption",
                        "rating": 0.1,
                        "max_caption_length": 1000,
                        "tags": [
                            {"tag": "including"},
                            {"tag": "weighted tag", "weight": 999.9}
                        ]
                    }
                },
                {
                    "image": "./other/image_3.jpg",
                    "multiply": 2,
                    "flip_p": 0.5,
                    "cond_dropout": 0.01,
                    "main_prompt": [
                        "first caption",
                        "second caption"
                    ]
                }
            ]
            """))

        actual = Dataset.from_json("./manifest.json").image_configs

        expected = {
            # The caption value for this entry is a path; its file contents are expected.
            "./stuff/image_1.jpg": ImageConfig(main_prompts={"default caption"}),
            "./other/image_1.jpg": ImageConfig(main_prompts={"other caption"}),
            "./other/image_2.jpg": ImageConfig(
                main_prompts={"complex caption"},
                rating=0.1,
                max_caption_length=1000,
                tags={
                    Tag("including"),
                    Tag("weighted tag", weight=999.9),
                },
            ),
            "./other/image_3.jpg": ImageConfig(
                main_prompts={"first caption", "second caption"},
                multiply=2,
                flip_p=0.5,
                cond_dropout=0.01,
            ),
        }
        self.assertEqual(expected, actual)

    def test_dataset_can_produce_train_items(self):
        """image_train_items carries configured options through and uses defaults otherwise."""
        dataset = Dataset({
            "1.jpg": ImageConfig(
                multiply=2,
                flip_p=0.1,
                cond_dropout=0.01,
                main_prompts=["first caption", "second caption"],
                rating=1.1,
                max_caption_length=1024,
                tags=frozenset([
                    Tag("tag"),
                    Tag("tag_2", weight=2.0),
                ]),
            ),
            "2.jpg": ImageConfig(main_prompts="single caption"),
        })
        aspects = []

        actual = dataset.image_train_items(aspects)

        self.assertEqual(len(actual), 2)

        # 1.jpg: every option set explicitly.
        self.assertEqual(actual[0].pathname, os.path.abspath('1.jpg'))
        self.assertEqual(actual[0].multiplier, 2.0)
        self.assertEqual(actual[0].flip.p, 0.1)
        self.assertEqual(actual[0].cond_dropout, 0.01)
        self.assertEqual(actual[0].caption.rating(), 1.1)
        self.assertEqual(actual[0].caption.get_caption(), "first caption, tag, tag_2")
        # NOTE(review): the next checks read name-mangled private attributes of
        # ImageCaption; they break if that class or attribute is renamed. Prefer a
        # public accessor if one exists. __use_weights is presumably True here
        # because tag_2 carries an explicit weight — confirm in ImageCaption.
        self.assertTrue(actual[0].caption._ImageCaption__use_weights)
        self.assertEqual(actual[0].caption._ImageCaption__max_target_length, 1024)

        # 2.jpg: only a caption, so every other value is the default.
        self.assertEqual(actual[1].pathname, os.path.abspath('2.jpg'))
        self.assertEqual(actual[1].multiplier, 1.0)
        self.assertEqual(actual[1].flip.p, 0.0)
        self.assertIsNone(actual[1].cond_dropout)
        self.assertEqual(actual[1].caption.rating(), 1.0)
        self.assertEqual(actual[1].caption.get_caption(), "single caption")
        self.assertFalse(actual[1].caption._ImageCaption__use_weights)
        self.assertEqual(actual[1].caption._ImageCaption__max_target_length, DEFAULT_MAX_CAPTION_LENGTH)