# EveryDream2trainer/test/test_dataset.py
#
# (Copied listing metadata: 379 lines, 14 KiB, Python.)

import os
import uuid
from textwrap import dedent

from pyfakefs.fake_filesystem_unittest import TestCase

from data.dataset import Dataset, ImageConfig, Tag, DEFAULT_MAX_CAPTION_LENGTH


class TestDataset(TestCase):
    """Tests for building ``Dataset`` image configs and train items.

    Every test runs on a pyfakefs fake filesystem (set up in ``setUp``), so
    the files created with ``self.fs.create_file`` never touch the real disk.
    Indentation inside the embedded YAML literals is significant: it decides
    which keys belong to ``caption:`` and which ``weight:`` belongs to which
    tag entry.
    """

    def setUp(self):
        # Show the full diff when the large expected/actual dicts differ.
        self.maxDiff = None
        self.setUpPyfakefs()

    def test_a_caption_is_generated_from_image_given_no_other_config(self):
        """Without sidecars, the filename supplies the main prompt and tags."""
        self.fs.create_file("image, tag1, tag2.jpg")

        actual = Dataset.from_path(".").image_configs

        expected = {
            "./image, tag1, tag2.jpg": ImageConfig(
                main_prompts="image",
                tags=frozenset([Tag("tag1"), Tag("tag2")]),
            )
        }
        self.assertEqual(expected, actual)

    def test_several_image_formats_are_supported(self):
        """JPG (upper-case), jpeg, png, webp, jfif and bmp images are all picked up."""
        self.fs.create_file("image.JPG")
        self.fs.create_file("image.jpeg")
        self.fs.create_file("image.png")
        self.fs.create_file("image.webp")
        self.fs.create_file("image.jfif")
        self.fs.create_file("image.bmp")

        actual = Dataset.from_path(".").image_configs

        common_cfg = ImageConfig(main_prompts="image")
        expected = {
            "./image.JPG": common_cfg,
            "./image.jpeg": common_cfg,
            "./image.png": common_cfg,
            "./image.webp": common_cfg,
            "./image.jfif": common_cfg,
            "./image.bmp": common_cfg,
        }
        self.assertEqual(expected, actual)

    def test_captions_can_be_read_from_txt_or_caption_sidecar(self):
        """A ``.txt`` or ``.caption`` file next to an image provides its caption."""
        self.fs.create_file("image_1.jpg")
        self.fs.create_file("image_1.txt", contents="an image, test, from .txt")
        self.fs.create_file("image_2.jpg")
        self.fs.create_file("image_2.caption", contents="an image, test, from .caption")

        actual = Dataset.from_path(".").image_configs

        expected = {
            "./image_1.jpg": ImageConfig(
                main_prompts="an image",
                tags=frozenset([Tag("test"), Tag("from .txt")]),
            ),
            "./image_2.jpg": ImageConfig(
                main_prompts="an image",
                tags=frozenset([Tag("test"), Tag("from .caption")]),
            ),
        }
        self.assertEqual(expected, actual)

    def test_captions_and_options_can_be_read_from_yaml_sidecar(self):
        """``.yaml``/``.yml`` sidecars supply captions and training options.

        The caption may be a plain string (split like a ``.txt`` caption) or a
        mapping with ``main_prompt``, ``rating``, ``max_caption_length`` and a
        list of optionally weighted ``tags``.
        """
        self.fs.create_file("image_1.jpg")
        self.fs.create_file("image_1.yaml",
            contents=dedent("""
                multiply: 2
                cond_dropout: 0.05
                flip_p: 0.5
                caption: "A simple caption, from .yaml"
                """))
        self.fs.create_file("image_2.jpg")
        # ``weight`` must be indented under its list item so it attaches to
        # the "with weight" tag rather than becoming a sibling key.
        self.fs.create_file("image_2.yml",
            contents=dedent("""
                flip_p: 0.0
                caption:
                    main_prompt: A complex caption
                    rating: 1.1
                    max_caption_length: 1024
                    tags:
                        - tag: from .yml
                        - tag: with weight
                          weight: 0.5
                """))

        actual = Dataset.from_path(".").image_configs

        expected = {
            "./image_1.jpg": ImageConfig(
                multiply=2,
                cond_dropout=0.05,
                flip_p=0.5,
                main_prompts="A simple caption",
                tags={Tag("from .yaml")},
            ),
            "./image_2.jpg": ImageConfig(
                flip_p=0.0,
                rating=1.1,
                max_caption_length=1024,
                main_prompts="A complex caption",
                tags={Tag("from .yml"), Tag("with weight", weight=0.5)},
            ),
        }
        self.assertEqual(expected, actual)

    def test_captions_are_read_from_filename_if_no_main_prompt(self):
        """A yaml caption with tags but no main_prompt falls back to the filename."""
        self.fs.create_file("filename main prompt, filename tag.jpg")
        self.fs.create_file("filename main prompt, filename tag.yaml",
            contents=dedent("""
                caption:
                    tags:
                        - tag: standalone yaml tag
                """))

        actual = Dataset.from_path(".").image_configs

        expected = {
            "./filename main prompt, filename tag.jpg": ImageConfig(
                main_prompts="filename main prompt",
                tags=[Tag("filename tag"), Tag("standalone yaml tag")],
            )
        }
        self.assertEqual(expected, actual)

    def test_multiple_prompts_and_tags_from_multiple_sidecars_are_supported(self):
        """Prompts and tags from .yaml, .txt and .caption sidecars are merged.

        The duplicated entries ("dupe prompt", "dupe tag") appear only once in
        the expected config.
        """
        self.fs.create_file("image_1.jpg")
        self.fs.create_file("image_1.yaml", contents=dedent("""
            main_prompt:
                - unique prompt
                - dupe prompt
            tags:
                - from .yaml
                - dupe tag
            """))
        self.fs.create_file("image_1.txt", contents="also unique prompt, from .txt, dupe tag")
        self.fs.create_file("image_1.caption", contents="dupe prompt, from .caption")

        actual = Dataset.from_path(".").image_configs

        expected = {
            "./image_1.jpg": ImageConfig(
                main_prompts={"unique prompt", "also unique prompt", "dupe prompt"},
                tags={Tag("from .yaml"), Tag("from .txt"), Tag("from .caption"), Tag("dupe tag")},
            )
        }
        self.assertEqual(expected, actual)

    def test_sidecars_can_also_be_attached_to_local_and_recursive_folders(self):
        """Folder-level sidecars apply to images in that folder and/or below.

        Per the expected configs: ``global.yaml`` reaches every image in its
        folder and all sub-folders; ``local.yaml`` reaches only images directly
        in its own folder; a ``flip_p.txt`` file sets ``flip_p`` for its folder.
        """
        self.fs.create_file("./global.yaml",
            contents=dedent("""\
                main_prompt: global prompt
                tags:
                    - global tag
                flip_p: 0.0
                """))
        self.fs.create_file("./local.yaml",
            contents=dedent("""
                main_prompt: local prompt
                tags:
                    - tag: local tag
                """))
        self.fs.create_file("./arbitrary filename.png")
        self.fs.create_file("./sub/sub arbitrary filename.png")
        self.fs.create_file("./sub/sidecar.png")
        self.fs.create_file("./sub/sidecar.txt",
            contents="sidecar prompt, sidecar tag")
        self.fs.create_file("./optfile/optfile.png")
        self.fs.create_file("./optfile/flip_p.txt",
            contents="0.1234")
        self.fs.create_file("./sub/sub2/global.yaml",
            contents=dedent("""
                tags:
                    - tag: sub global tag
                """))
        # sub2 holds no image directly (xyz.png lives in sub3), so this
        # local.yaml is expected to affect nothing.
        self.fs.create_file("./sub/sub2/local.yaml",
            contents=dedent("""
                tags:
                    - This tag wil not apply to any files
                """))
        self.fs.create_file("./sub/sub2/sub3/xyz.png")

        actual = Dataset.from_path(".").image_configs

        expected = {
            "./arbitrary filename.png": ImageConfig(
                main_prompts={'global prompt', 'local prompt'},
                tags=[Tag("global tag"), Tag("local tag")],
                flip_p=0.0,
            ),
            "./sub/sub arbitrary filename.png": ImageConfig(
                main_prompts={'global prompt'},
                tags=[Tag("global tag")],
                flip_p=0.0,
            ),
            "./sub/sidecar.png": ImageConfig(
                main_prompts={'global prompt', 'sidecar prompt'},
                tags=[Tag("global tag"), Tag("sidecar tag")],
                flip_p=0.0,
            ),
            "./optfile/optfile.png": ImageConfig(
                main_prompts={'global prompt'},
                tags=[Tag("global tag")],
                flip_p=0.1234,
            ),
            "./sub/sub2/sub3/xyz.png": ImageConfig(
                main_prompts={'global prompt'},
                tags=[Tag("global tag"), Tag("sub global tag")],
                flip_p=0.0,
            ),
        }
        self.assertEqual(expected, actual)

    def test_can_load_dataset_from_json_manifest(self):
        """``Dataset.from_json`` reads a list of per-image manifest entries.

        An entry's caption may be a path to a caption file, a literal string,
        or a structured mapping. Options and a ``main_prompt`` list may also
        sit at the entry's top level.
        """
        self.fs.create_file("./stuff/image_1.jpg")
        self.fs.create_file("./stuff/default.caption", contents="default caption")
        self.fs.create_file("./other/image_1.jpg")
        self.fs.create_file("./other/image_2.jpg")
        self.fs.create_file("./other/image_3.jpg")
        self.fs.create_file("./manifest.json", contents=dedent("""
            [
                { "image": "./stuff/image_1.jpg", "caption": "./stuff/default.caption" },
                { "image": "./other/image_1.jpg", "caption": "other caption" },
                {
                    "image": "./other/image_2.jpg",
                    "caption": {
                        "main_prompt": "complex caption",
                        "rating": 0.1,
                        "max_caption_length": 1000,
                        "tags": [
                            {"tag": "including"},
                            {"tag": "weighted tag", "weight": 999.9}
                        ]
                    }
                },
                {
                    "image": "./other/image_3.jpg",
                    "multiply": 2,
                    "flip_p": 0.5,
                    "cond_dropout": 0.01,
                    "main_prompt": [
                        "first caption",
                        "second caption"
                    ]
                }
            ]
            """))

        actual = Dataset.from_json("./manifest.json").image_configs

        expected = {
            "./stuff/image_1.jpg": ImageConfig(main_prompts={"default caption"}),
            "./other/image_1.jpg": ImageConfig(main_prompts={"other caption"}),
            "./other/image_2.jpg": ImageConfig(
                main_prompts={"complex caption"},
                rating=0.1,
                max_caption_length=1000,
                tags={
                    Tag("including"),
                    Tag("weighted tag", 999.9),
                },
            ),
            "./other/image_3.jpg": ImageConfig(
                main_prompts={"first caption", "second caption"},
                multiply=2,
                flip_p=0.5,
                cond_dropout=0.01,
            ),
        }
        self.assertEqual(expected, actual)

    def test_original_tag_order_is_retained_in_dataset(self):
        """A long ``.txt`` caption yields the same tags, built in the same order."""
        # 10,000 random UUID tags make an accidental match with a reordered
        # sequence practically impossible.
        # NOTE(review): this only verifies order if ImageConfig equality is
        # order-sensitive for tags -- confirm in data.dataset.
        tags = [str(uuid.uuid4()) for _ in range(10000)]
        caption = 'main_prompt,' + ", ".join(tags)
        self.fs.create_file("image.png")
        self.fs.create_file("image.txt", contents=caption)

        actual = Dataset.from_path(".").image_configs

        expected = {
            "./image.png": ImageConfig(main_prompts="main_prompt", tags=map(Tag.parse, tags))
        }
        self.assertEqual(expected, actual)

    def test_tag_order_is_retained_in_train_item(self):
        """Tags keep their declared order (xyz, abc, ijk) in the built caption."""
        dataset = Dataset({
            "1.jpg": ImageConfig(
                main_prompts="main_prompt",
                tags=[
                    Tag("xyz"),
                    Tag("abc"),
                    Tag("ijk"),
                ],
            )
        })
        aspects = []

        actual = dataset.image_train_items(aspects)

        self.assertEqual(len(actual), 1)
        self.assertEqual(actual[0].pathname, os.path.abspath('1.jpg'))
        self.assertEqual(actual[0].caption.get_caption(), "main_prompt, xyz, abc, ijk")

    def test_dataset_can_produce_train_items(self):
        """Merged sidecar configs are turned into train items with the right options.

        The assertions index ``actual[0..2]``, so they also pin the order in
        which train items are produced (root.jpg, sub/sub.jpg,
        sub/nested/nested.jpg). Caption internals are read through
        ImageCaption's name-mangled private attributes.
        """
        # ``weight`` is indented under its list item so it applies to the
        # "high prio global tag" entry.
        self.fs.create_file("./sub/global.yaml",
            contents=dedent("""\
                main_prompt: global prompt
                tags:
                    - low prio global tag
                    - tag: high prio global tag
                      weight: 10.0
                """))
        self.fs.create_file("./sub/nested/local.yaml",
            contents=dedent("""
                tags:
                    - tag: local tag
                """))
        self.fs.create_file("./sub/sub.jpg")
        self.fs.create_file("./sub/sub.yaml",
            contents=dedent("""\
                main_prompt: sub.jpg prompt
                tags:
                    - sub.jpg tag
                    - another tag
                    - last tag
                rating: 1.1
                max_caption_length: 1024
                multiply: 2
                flip_p: 0.1
                cond_dropout: 0.01
                """))
        self.fs.create_file("./sub/nested/nested.jpg")
        self.fs.create_file("./sub/nested/nested.yaml",
            contents=dedent("""\
                main_prompt: nested.jpg prompt
                tags:
                    - tag: nested.jpg tag
                      weight: 0.1
                """))
        self.fs.create_file("./root.jpg")
        self.fs.create_file("./root.txt", contents="root.jpg prompt, root.jpg tag")
        aspects = []
        dataset = Dataset.from_path(".")

        actual = dataset.image_train_items(aspects)

        self.assertEqual(len(actual), 3)

        # root.jpg: only a .txt sidecar, and it sits outside ./sub, so none of
        # the yaml files reach it.
        self.assertEqual(actual[0].pathname, os.path.abspath('root.jpg'))
        self.assertEqual(actual[0].multiplier, 1.0)
        self.assertEqual(actual[0].flip.p, 0.0)
        self.assertIsNone(actual[0].cond_dropout)
        self.assertEqual(actual[0].caption.rating(), 1.0)
        self.assertEqual(actual[0].caption.get_caption(), "root.jpg prompt, root.jpg tag")
        self.assertFalse(actual[0].caption._ImageCaption__use_weights)
        self.assertEqual(actual[0].caption._ImageCaption__max_target_length, DEFAULT_MAX_CAPTION_LENGTH)

        # sub/sub.jpg: sub.yaml sets every option; ./sub/global.yaml adds its
        # weighted and unweighted tags.
        self.assertEqual(actual[1].pathname, os.path.abspath('sub/sub.jpg'))
        self.assertEqual(actual[1].multiplier, 2.0)
        self.assertEqual(actual[1].flip.p, 0.1)
        self.assertEqual(actual[1].cond_dropout, 0.01)
        self.assertEqual(actual[1].caption.rating(), 1.1)
        self.assertEqual(actual[1].caption.get_caption(), "sub.jpg prompt, high prio global tag, sub.jpg tag, another tag, last tag, low prio global tag")
        self.assertTrue(actual[1].caption._ImageCaption__use_weights)
        self.assertEqual(actual[1].caption._ImageCaption__max_target_length, 1024)

        # sub/nested/nested.jpg: tags from ./sub/global.yaml and
        # ./sub/nested/local.yaml; nested.yaml sets no training options.
        self.assertEqual(actual[2].pathname, os.path.abspath('sub/nested/nested.jpg'))
        self.assertEqual(actual[2].multiplier, 1.0)
        self.assertEqual(actual[2].flip.p, 0.0)
        self.assertIsNone(actual[2].cond_dropout)
        self.assertEqual(actual[2].caption.rating(), 1.0)
        self.assertEqual(actual[2].caption.get_caption(), "nested.jpg prompt, high prio global tag, local tag, low prio global tag, nested.jpg tag")
        self.assertTrue(actual[2].caption._ImageCaption__use_weights)
        self.assertEqual(actual[2].caption._ImageCaption__max_target_length, DEFAULT_MAX_CAPTION_LENGTH)