EveryDream2trainer/test/test_dataset.py

329 lines
12 KiB
Python

import os
from data.dataset import Dataset, ImageConfig, Caption, Tag
from textwrap import dedent
from pyfakefs.fake_filesystem_unittest import TestCase
class TestResolve(TestCase):
    """Tests for building a ``Dataset`` from files and converting it to train items.

    Every test runs against a pyfakefs in-memory filesystem (see ``setUp``), so
    the files created here never touch the real disk.
    """

    def setUp(self):
        """Replace the real filesystem with a fresh pyfakefs fake for each test."""
        self.setUpPyfakefs()

    def test_simple_image(self):
        """A lone image's filename is split on commas into main prompt and tags."""
        self.fs.create_file("image, tag1, tag2.jpg")

        actual = Dataset.from_path(".").image_configs

        expected = {
            ImageConfig(
                image="./image, tag1, tag2.jpg",
                captions=frozenset([
                    Caption(main_prompt="image", tags=frozenset([Tag("tag1"), Tag("tag2")])),
                ]),
            ),
        }
        self.assertEqual(expected, actual)

    def test_image_types(self):
        """Each listed extension (including an upper-case ``.JPG``) is picked up.

        The filename-derived caption drops the trailing ``_<n>`` index, so every
        image here gets the main prompt ``"image"``.
        """
        self.fs.create_file("image_1.JPG")
        self.fs.create_file("image_2.jpeg")
        self.fs.create_file("image_3.png")
        self.fs.create_file("image_4.webp")
        self.fs.create_file("image_5.jfif")
        self.fs.create_file("image_6.bmp")

        actual = Dataset.from_path(".").image_configs

        captions = frozenset([Caption(main_prompt="image")])
        expected = {
            ImageConfig(image="./image_1.JPG", captions=captions),
            ImageConfig(image="./image_2.jpeg", captions=captions),
            ImageConfig(image="./image_3.png", captions=captions),
            ImageConfig(image="./image_4.webp", captions=captions),
            ImageConfig(image="./image_5.jfif", captions=captions),
            ImageConfig(image="./image_6.bmp", captions=captions),
        }
        self.assertEqual(expected, actual)

    def test_caption_file(self):
        """A ``.txt`` or ``.caption`` sidecar replaces the filename-derived caption."""
        self.fs.create_file("image_1.jpg")
        self.fs.create_file("image_1.txt", contents="an image, test, from .txt")
        self.fs.create_file("image_2.jpg")
        self.fs.create_file("image_2.caption", contents="an image, test, from .caption")

        actual = Dataset.from_path(".").image_configs

        expected = {
            ImageConfig(
                image="./image_1.jpg",
                captions=frozenset([
                    Caption(main_prompt="an image", tags=frozenset([Tag("test"), Tag("from .txt")])),
                ]),
            ),
            ImageConfig(
                image="./image_2.jpg",
                captions=frozenset([
                    Caption(main_prompt="an image", tags=frozenset([Tag("test"), Tag("from .caption")])),
                ]),
            ),
        }
        self.assertEqual(expected, actual)

    def test_image_yaml(self):
        """Per-image ``.yaml``/``.yml`` sidecars set options and captions.

        The caption may be a comma-separated string or a structured mapping with
        rating, max caption length and (optionally weighted) tags.
        """
        self.fs.create_file("image_1.jpg")
        self.fs.create_file("image_1.yaml",
            contents=dedent("""
                multiply: 2
                cond_dropout: 0.05
                flip_p: 0.5
                caption: "A simple caption, from .yaml"
                """))
        self.fs.create_file("image_2.jpg")
        # The nesting below is significant: everything under `caption:` is one
        # structured caption, and `weight` belongs to the second tag entry.
        self.fs.create_file("image_2.yml",
            contents=dedent("""
                flip_p: 0.0
                caption:
                    main_prompt: A complex caption
                    rating: 1.1
                    max_caption_length: 1024
                    tags:
                        - tag: from .yml
                        - tag: with weight
                          weight: 0.5
                """))

        actual = Dataset.from_path(".").image_configs

        expected = {
            ImageConfig(
                image="./image_1.jpg",
                multiply=2,
                cond_dropout=0.05,
                flip_p=0.5,
                captions=frozenset([
                    Caption(main_prompt="A simple caption", tags=frozenset([Tag("from .yaml")])),
                ]),
            ),
            ImageConfig(
                image="./image_2.jpg",
                flip_p=0.0,
                captions=frozenset([
                    Caption(
                        main_prompt="A complex caption",
                        rating=1.1,
                        max_caption_length=1024,
                        tags=frozenset([
                            Tag("from .yml"),
                            Tag("with weight", weight=0.5),
                        ]),
                    ),
                ]),
            ),
        }
        self.assertEqual(expected, actual)

    def test_multi_caption(self):
        """Captions from yaml (``caption`` and ``captions``), ``.txt`` and ``.caption`` are all kept."""
        self.fs.create_file("image_1.jpg")
        self.fs.create_file("image_1.yaml", contents=dedent("""
            caption: "A simple caption, from .yaml"
            captions:
                - "Another simple caption"
                - main_prompt: A complex caption
            """))
        self.fs.create_file("image_1.txt", contents="A .txt caption")
        self.fs.create_file("image_1.caption", contents="A .caption caption")

        actual = Dataset.from_path(".").image_configs

        expected = {
            ImageConfig(
                image="./image_1.jpg",
                captions=frozenset([
                    Caption(main_prompt="A simple caption", tags=frozenset([Tag("from .yaml")])),
                    Caption(main_prompt="Another simple caption", tags=frozenset()),
                    Caption(main_prompt="A complex caption", tags=frozenset()),
                    Caption(main_prompt="A .txt caption", tags=frozenset()),
                    Caption(main_prompt="A .caption caption", tags=frozenset()),
                ]),
            ),
        }
        self.assertEqual(expected, actual)

    def test_globals_and_locals(self):
        """Directory-level config files cascade onto the images beneath them.

        * ``global.yaml`` applies to images in its directory and subdirectories.
        * ``local.yaml`` applies to images in its directory; its ``caption``
          replaces the filename-derived one.
        * A per-image yaml overrides both.
        * ``<option>.txt`` files (e.g. ``multiply.txt``) set a single option.
        * Images with no applicable config get ``None`` for these options.
        """
        self.fs.create_file("./people/global.yaml", contents=dedent("""\
            multiply: 1.0
            cond_dropout: 0.0
            flip_p: 0.0
            """))
        self.fs.create_file("./people/alice/local.yaml", contents="multiply: 1.5")
        self.fs.create_file("./people/alice/alice_1.png")
        self.fs.create_file("./people/alice/alice_1.yaml", contents="multiply: 2")
        self.fs.create_file("./people/alice/alice_2.png")
        self.fs.create_file("./people/bob/multiply.txt", contents="3")
        self.fs.create_file("./people/bob/cond_dropout.txt", contents="0.05")
        self.fs.create_file("./people/bob/flip_p.txt", contents="0.05")
        self.fs.create_file("./people/bob/bob.png")
        self.fs.create_file("./people/cleo/cleo.png")
        self.fs.create_file("./people/dan.png")
        self.fs.create_file("./other/dog/local.yaml", contents="caption: spike")
        self.fs.create_file("./other/dog/xyz.png")

        actual = Dataset.from_path(".").image_configs

        expected = {
            ImageConfig(
                image="./people/alice/alice_1.png",
                captions=frozenset([Caption(main_prompt="alice")]),
                multiply=2,
                cond_dropout=0.0,
                flip_p=0.0,
            ),
            ImageConfig(
                image="./people/alice/alice_2.png",
                captions=frozenset([Caption(main_prompt="alice")]),
                multiply=1.5,
                cond_dropout=0.0,
                flip_p=0.0,
            ),
            ImageConfig(
                image="./people/bob/bob.png",
                captions=frozenset([Caption(main_prompt="bob")]),
                multiply=3,
                cond_dropout=0.05,
                flip_p=0.05,
            ),
            ImageConfig(
                image="./people/cleo/cleo.png",
                captions=frozenset([Caption(main_prompt="cleo")]),
                multiply=1.0,
                cond_dropout=0.0,
                flip_p=0.0,
            ),
            ImageConfig(
                image="./people/dan.png",
                captions=frozenset([Caption(main_prompt="dan")]),
                multiply=1.0,
                cond_dropout=0.0,
                flip_p=0.0,
            ),
            ImageConfig(
                image="./other/dog/xyz.png",
                captions=frozenset([Caption(main_prompt="spike")]),
                multiply=None,
                cond_dropout=None,
                flip_p=None,
            ),
        }
        self.assertEqual(expected, actual)

    def test_json_manifest(self):
        """A JSON manifest lists images with several caption forms.

        A caption may be a path to a caption file (whose contents are used), a
        plain string, a structured mapping, or a ``captions`` list mixing both.
        """
        self.fs.create_file("./stuff/image_1.jpg")
        self.fs.create_file("./stuff/default.caption", contents="default caption")
        self.fs.create_file("./other/image_1.jpg")
        self.fs.create_file("./other/image_2.jpg")
        self.fs.create_file("./other/image_3.jpg")
        self.fs.create_file("./manifest.json", contents=dedent("""
            [
                { "image": "./stuff/image_1.jpg", "caption": "./stuff/default.caption" },
                { "image": "./other/image_1.jpg", "caption": "other caption" },
                {
                    "image": "./other/image_2.jpg",
                    "caption": {
                        "main_prompt": "complex caption",
                        "rating": 0.1,
                        "max_caption_length": 1000,
                        "tags": [
                            {"tag": "including"},
                            {"tag": "weighted tag", "weight": 999.9}
                        ]
                    }
                },
                {
                    "image": "./other/image_3.jpg",
                    "multiply": 2,
                    "flip_p": 0.5,
                    "cond_dropout": 0.01,
                    "captions": [
                        "first caption",
                        { "main_prompt": "second caption" }
                    ]
                }
            ]
            """))

        actual = Dataset.from_json("./manifest.json").image_configs

        expected = {
            ImageConfig(
                image="./stuff/image_1.jpg",
                captions=frozenset([Caption(main_prompt="default caption")]),
            ),
            ImageConfig(
                image="./other/image_1.jpg",
                captions=frozenset([Caption(main_prompt="other caption")]),
            ),
            ImageConfig(
                image="./other/image_2.jpg",
                captions=frozenset([
                    Caption(
                        main_prompt="complex caption",
                        rating=0.1,
                        max_caption_length=1000,
                        tags=frozenset([
                            Tag("including"),
                            Tag("weighted tag", 999.9),
                        ]),
                    ),
                ]),
            ),
            ImageConfig(
                image="./other/image_3.jpg",
                multiply=2,
                flip_p=0.5,
                cond_dropout=0.01,
                captions=frozenset([
                    Caption("first caption"),
                    Caption("second caption"),
                ]),
            ),
        }
        self.assertEqual(expected, actual)

    def test_train_items(self):
        """``image_train_items`` maps each ImageConfig to a train item with its options."""
        dataset = Dataset([
            ImageConfig(
                image="1.jpg",
                multiply=2,
                flip_p=0.1,
                cond_dropout=0.01,
                captions=frozenset([
                    Caption(
                        main_prompt="first caption",
                        rating=1.1,
                        max_caption_length=1024,
                        tags=frozenset([
                            Tag("tag"),
                            Tag("tag_2", 2.0),
                        ]),
                    ),
                    Caption(main_prompt="second_caption"),
                ]),
            ),
            ImageConfig(
                image="2.jpg",
                captions=frozenset([Caption(main_prompt="single caption")]),
            ),
        ])

        aspects = []
        actual = dataset.image_train_items(aspects)
        self.assertEqual(len(actual), 2)

        # The other tests compare `image_configs` against set literals, so the
        # dataset may not keep the input order. Look items up by path rather
        # than relying on their position in the returned sequence.
        first_path = os.path.abspath("1.jpg")
        second_path = os.path.abspath("2.jpg")
        by_path = {item.pathname: item for item in actual}
        self.assertEqual({first_path, second_path}, set(by_path))
        first = by_path[first_path]
        second = by_path[second_path]

        self.assertEqual(first.multiplier, 2.0)
        self.assertEqual(first.flip.p, 0.1)
        self.assertEqual(first.cond_dropout, 0.01)
        self.assertEqual(first.caption.get_caption(), "first caption, tag, tag_2")

        self.assertEqual(second.multiplier, 1.0)
        self.assertEqual(second.flip.p, 0.0)
        self.assertIsNone(second.cond_dropout)
        self.assertEqual(second.caption.get_caption(), "single caption")

        # NOTE: The weight-usage flag on the caption is a private
        # `__use_weights` attribute. Written inside this class,
        # `caption.__use_weights` would be name-mangled to
        # `caption._TestResolve__use_weights`, so it is not asserted here.
        # The intended checks are: enabled for 1.jpg, disabled for 2.jpg.