"""Unit tests for ``data.dataset``: sidecar/manifest parsing and train-item creation.

Every test runs against the fake filesystem installed by pyfakefs in ``setUp``,
so the files created here never touch the real disk.
"""
import os
import uuid
from textwrap import dedent

from pyfakefs.fake_filesystem_unittest import TestCase

from data.dataset import Dataset, ImageConfig, Tag, DEFAULT_MAX_CAPTION_LENGTH


class TestDataset(TestCase):
    """Behavioral tests for ``Dataset`` loading and the resulting ``ImageConfig``s."""

    def setUp(self):
        # The expected dicts are large; show the full diff on failure.
        self.maxDiff = None
        # Route filesystem calls to pyfakefs for the duration of each test.
        self.setUpPyfakefs()

    def test_a_caption_is_generated_from_image_given_no_other_config(self):
        """Without sidecars, the filename is split into a main prompt and tags."""
        self.fs.create_file("image, tag1, tag2.jpg")

        actual = Dataset.from_path(".").image_configs

        expected = {
            "./image, tag1, tag2.jpg": ImageConfig(
                main_prompts="image",
                tags=frozenset([Tag("tag1"), Tag("tag2")]),
            )
        }
        self.assertEqual(expected, actual)

    def test_several_image_formats_are_supported(self):
        """Images with each listed extension, including upper-case ``.JPG``, are found."""
        extensions = ("JPG", "jpeg", "png", "webp", "jfif", "bmp")
        for extension in extensions:
            self.fs.create_file(f"image.{extension}")

        actual = Dataset.from_path(".").image_configs

        common_cfg = ImageConfig(main_prompts="image")
        expected = {f"./image.{extension}": common_cfg for extension in extensions}
        self.assertEqual(expected, actual)

    def test_captions_can_be_read_from_txt_or_caption_sidecar(self):
        """A ``.txt`` or ``.caption`` file next to an image supplies its caption."""
        self.fs.create_file("image_1.jpg")
        self.fs.create_file("image_1.txt", contents="an image, test, from .txt")
        self.fs.create_file("image_2.jpg")
        self.fs.create_file("image_2.caption", contents="an image, test, from .caption")

        actual = Dataset.from_path(".").image_configs

        expected = {
            "./image_1.jpg": ImageConfig(
                main_prompts="an image",
                tags=frozenset([Tag("test"), Tag("from .txt")]),
            ),
            "./image_2.jpg": ImageConfig(
                main_prompts="an image",
                tags=frozenset([Tag("test"), Tag("from .caption")]),
            ),
        }
        self.assertEqual(expected, actual)

    def test_captions_and_options_can_be_read_from_yaml_sidecar(self):
        """``.yaml``/``.yml`` sidecars set options and a plain or structured caption."""
        self.fs.create_file("image_1.jpg")
        self.fs.create_file("image_1.yaml", contents=dedent("""
            multiply: 2
            cond_dropout: 0.05
            flip_p: 0.5
            caption: "A simple caption, from .yaml"
            """))
        self.fs.create_file("image_2.jpg")
        self.fs.create_file("image_2.yml", contents=dedent("""
            flip_p: 0.0
            caption:
                main_prompt: A complex caption
                rating: 1.1
                max_caption_length: 1024
                tags:
                    - tag: from .yml
                    - tag: with weight
                      weight: 0.5
                    - tag: 1234.5
            """))

        actual = Dataset.from_path(".").image_configs

        expected = {
            "./image_1.jpg": ImageConfig(
                multiply=2,
                cond_dropout=0.05,
                flip_p=0.5,
                main_prompts="A simple caption",
                tags={Tag("from .yaml")},
            ),
            "./image_2.jpg": ImageConfig(
                flip_p=0.0,
                rating=1.1,
                max_caption_length=1024,
                main_prompts="A complex caption",
                tags={
                    Tag("from .yml"),
                    Tag("with weight", weight=0.5),
                    # A numeric YAML tag is expected back as its string form.
                    Tag("1234.5"),
                },
            ),
        }
        self.assertEqual(expected, actual)

    def test_captions_are_read_from_filename_if_no_main_prompt(self):
        """A YAML caption with tags but no main prompt falls back to the filename."""
        self.fs.create_file("filename main prompt, filename tag.jpg")
        self.fs.create_file("filename main prompt, filename tag.yaml", contents=dedent("""
            caption:
                tags:
                    - tag: standalone yaml tag
            """))

        actual = Dataset.from_path(".").image_configs

        expected = {
            "./filename main prompt, filename tag.jpg": ImageConfig(
                main_prompts="filename main prompt",
                tags=[Tag("filename tag"), Tag("standalone yaml tag")],
            )
        }
        self.assertEqual(expected, actual)

    def test_multiple_prompts_and_tags_from_multiple_sidecars_are_supported(self):
        """Prompts and tags from .yaml, .txt and .caption sidecars merge without duplicates."""
        self.fs.create_file("image_1.jpg")
        self.fs.create_file("image_1.yaml", contents=dedent("""
            main_prompt:
                - unique prompt
                - dupe prompt
            tags:
                - from .yaml
                - dupe tag
            """))
        self.fs.create_file("image_1.txt", contents="also unique prompt, from .txt, dupe tag")
        self.fs.create_file("image_1.caption", contents="dupe prompt, from .caption")

        actual = Dataset.from_path(".").image_configs

        expected = {
            "./image_1.jpg": ImageConfig(
                main_prompts={"unique prompt", "also unique prompt", "dupe prompt"},
                tags={
                    Tag("from .yaml"),
                    Tag("from .txt"),
                    Tag("from .caption"),
                    Tag("dupe tag"),
                },
            )
        }
        self.assertEqual(expected, actual)

    def test_sidecars_can_also_be_attached_to_local_and_recursive_folders(self):
        """``global.yaml`` applies recursively, ``local.yaml`` only to its own folder."""
        self.fs.create_file("./global.yaml", contents=dedent("""\
            main_prompt: global prompt
            tags:
                - global tag
            flip_p: 0.0
            """))
        self.fs.create_file("./local.yaml", contents=dedent("""
            main_prompt: local prompt
            tags:
                - tag: local tag
            """))

        self.fs.create_file("./arbitrary filename.png")
        self.fs.create_file("./sub/sub arbitrary filename.png")
        self.fs.create_file("./sub/sidecar.png")
        self.fs.create_file("./sub/sidecar.txt", contents="sidecar prompt, sidecar tag")

        # A folder-level "<option>.txt" file supplies a single option value.
        self.fs.create_file("./optfile/optfile.png")
        self.fs.create_file("./optfile/flip_p.txt", contents="0.1234")

        self.fs.create_file("./sub/sub2/global.yaml", contents=dedent("""
            tags:
                - tag: sub global tag
            """))
        # No image lives directly in ./sub/sub2, so these tags must not appear anywhere.
        self.fs.create_file("./sub/sub2/local.yaml", contents=dedent("""
            tags:
                - This tag wil not apply to any files
            """))
        self.fs.create_file("./sub/sub2/sub3/xyz.png")

        actual = Dataset.from_path(".").image_configs

        expected = {
            "./arbitrary filename.png": ImageConfig(
                main_prompts={"global prompt", "local prompt"},
                tags=[Tag("global tag"), Tag("local tag")],
                flip_p=0.0,
            ),
            "./sub/sub arbitrary filename.png": ImageConfig(
                main_prompts={"global prompt"},
                tags=[Tag("global tag")],
                flip_p=0.0,
            ),
            "./sub/sidecar.png": ImageConfig(
                main_prompts={"global prompt", "sidecar prompt"},
                tags=[Tag("global tag"), Tag("sidecar tag")],
                flip_p=0.0,
            ),
            "./optfile/optfile.png": ImageConfig(
                main_prompts={"global prompt"},
                tags=[Tag("global tag")],
                flip_p=0.1234,
            ),
            "./sub/sub2/sub3/xyz.png": ImageConfig(
                main_prompts={"global prompt"},
                tags=[Tag("global tag"), Tag("sub global tag")],
                flip_p=0.0,
            ),
        }
        self.assertEqual(expected, actual)

    def test_can_load_dataset_from_json_manifest(self):
        """A JSON manifest can list images with caption paths, strings, or objects."""
        self.fs.create_file("./stuff/image_1.jpg")
        self.fs.create_file("./stuff/default.caption", contents="default caption")
        self.fs.create_file("./other/image_1.jpg")
        self.fs.create_file("./other/image_2.jpg")
        self.fs.create_file("./other/image_3.jpg")
        self.fs.create_file("./manifest.json", contents=dedent("""
            [
                {
                    "image": "./stuff/image_1.jpg",
                    "caption": "./stuff/default.caption"
                },
                {
                    "image": "./other/image_1.jpg",
                    "caption": "other caption"
                },
                {
                    "image": "./other/image_2.jpg",
                    "caption": {
                        "main_prompt": "complex caption",
                        "rating": 0.1,
                        "max_caption_length": 1000,
                        "tags": [
                            {"tag": "including"},
                            {"tag": "weighted tag", "weight": 999.9}
                        ]
                    }
                },
                {
                    "image": "./other/image_3.jpg",
                    "multiply": 2,
                    "flip_p": 0.5,
                    "cond_dropout": 0.01,
                    "main_prompt": [
                        "first caption",
                        "second caption"
                    ]
                }
            ]
            """))

        actual = Dataset.from_json("./manifest.json").image_configs

        expected = {
            "./stuff/image_1.jpg": ImageConfig(main_prompts={"default caption"}),
            "./other/image_1.jpg": ImageConfig(main_prompts={"other caption"}),
            "./other/image_2.jpg": ImageConfig(
                main_prompts={"complex caption"},
                rating=0.1,
                max_caption_length=1000,
                tags={Tag("including"), Tag("weighted tag", 999.9)},
            ),
            "./other/image_3.jpg": ImageConfig(
                main_prompts={"first caption", "second caption"},
                multiply=2,
                flip_p=0.5,
                cond_dropout=0.01,
            ),
        }
        self.assertEqual(expected, actual)

    def test_original_tag_order_is_retained_in_dataset(self):
        """Tag order from a ``.txt`` sidecar survives loading, even for many tags."""
        # Many unique random tags make an accidental order-preserving match unlikely.
        tags = [str(uuid.uuid4()) for _ in range(10000)]
        caption = 'main_prompt,' + ", ".join(tags)
        self.fs.create_file("image.png")
        self.fs.create_file("image.txt", contents=caption)

        actual = Dataset.from_path(".").image_configs

        expected = {
            "./image.png": ImageConfig(
                main_prompts="main_prompt",
                # Materialize the tags eagerly instead of passing a one-shot iterator.
                tags=[Tag.parse(tag) for tag in tags],
            )
        }
        self.assertEqual(expected, actual)

    def test_tag_order_is_retained_in_train_item(self):
        """Tags appear in the train item's caption in the order they were given."""
        dataset = Dataset({
            "1.jpg": ImageConfig(
                main_prompts="main_prompt",
                tags=[Tag("xyz"), Tag("abc"), Tag("ijk")],
            )
        })

        aspects = []
        actual = dataset.image_train_items(aspects)

        self.assertEqual(len(actual), 1)
        self.assertEqual(actual[0].pathname, os.path.abspath('1.jpg'))
        self.assertEqual(actual[0].caption.get_caption(), "main_prompt, xyz, abc, ijk")

    def test_dataset_can_produce_train_items(self):
        """Merged sidecar settings are carried into the produced train items."""
        self.fs.create_file("./sub/global.yaml", contents=dedent("""\
            main_prompt: global prompt
            tags:
                - low prio global tag
                - tag: high prio global tag
                  weight: 10.0
            """))
        self.fs.create_file("./sub/nested/local.yaml", contents=dedent("""
            tags:
                - tag: local tag
            """))

        self.fs.create_file("./sub/sub.jpg")
        self.fs.create_file("./sub/sub.yaml", contents=dedent("""\
            main_prompt: sub.jpg prompt
            tags:
                - sub.jpg tag
                - another tag
                - last tag
            rating: 1.1
            max_caption_length: 1024
            multiply: 2
            flip_p: 0.1
            cond_dropout: 0.01
            """))
        self.fs.create_file("./sub/nested/nested.jpg")
        self.fs.create_file("./sub/nested/nested.yaml", contents=dedent("""\
            main_prompt: nested.jpg prompt
            tags:
                - tag: nested.jpg tag
                  weight: 0.1
            """))
        self.fs.create_file("./root.jpg")
        self.fs.create_file("./root.txt", contents="root.jpg prompt, root.jpg tag")

        aspects = []
        dataset = Dataset.from_path(".")
        actual = dataset.image_train_items(aspects)

        self.assertEqual(len(actual), 3)

        # The name-mangled private ImageCaption attributes are inspected directly below.
        # root.jpg: no YAML anywhere on its path, so defaults apply.
        self.assertEqual(actual[0].pathname, os.path.abspath('root.jpg'))
        self.assertEqual(actual[0].multiplier, 1.0)
        self.assertEqual(actual[0].flip.p, 0.0)
        self.assertIsNone(actual[0].cond_dropout)
        self.assertEqual(actual[0].caption.rating(), 1.0)
        self.assertEqual(actual[0].caption.get_caption(), "root.jpg prompt, root.jpg tag")
        self.assertFalse(actual[0].caption._ImageCaption__use_weights)
        self.assertEqual(actual[0].caption._ImageCaption__max_target_length, DEFAULT_MAX_CAPTION_LENGTH)

        # sub/sub.jpg: its own sidecar options plus ./sub/global.yaml tags.
        self.assertEqual(actual[1].pathname, os.path.abspath('sub/sub.jpg'))
        self.assertEqual(actual[1].multiplier, 2.0)
        self.assertEqual(actual[1].flip.p, 0.1)
        self.assertEqual(actual[1].cond_dropout, 0.01)
        self.assertEqual(actual[1].caption.rating(), 1.1)
        self.assertEqual(
            actual[1].caption.get_caption(),
            "sub.jpg prompt, high prio global tag, sub.jpg tag, another tag, last tag, low prio global tag",
        )
        self.assertTrue(actual[1].caption._ImageCaption__use_weights)
        self.assertEqual(actual[1].caption._ImageCaption__max_target_length, 1024)

        # sub/nested/nested.jpg: inherits global.yaml and nested/local.yaml tags.
        self.assertEqual(actual[2].pathname, os.path.abspath('sub/nested/nested.jpg'))
        self.assertEqual(actual[2].multiplier, 1.0)
        self.assertEqual(actual[2].flip.p, 0.0)
        self.assertIsNone(actual[2].cond_dropout)
        self.assertEqual(actual[2].caption.rating(), 1.0)
        self.assertEqual(
            actual[2].caption.get_caption(),
            "nested.jpg prompt, high prio global tag, local tag, low prio global tag, nested.jpg tag",
        )
        self.assertTrue(actual[2].caption._ImageCaption__use_weights)
        self.assertEqual(actual[2].caption._ImageCaption__max_target_length, DEFAULT_MAX_CAPTION_LENGTH)