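# EveryDream2trainer/test/test_data_resolver.py
#
# Unit tests for data.resolver: resolving data roots given as a directory path,
# a JSON manifest path, a {'resolver': ..., 'path': ...} dict, or a list of these.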

import json
import glob
import os
import unittest
import argparse
import PIL.Image as Image
import data.aspects as aspects
import data.resolver as resolver
# Fixture directory (test/data next to this file). Anchored to this file's
# location instead of the current working directory, so the fixtures are
# written to -- and cleaned up from -- the same place no matter where the test
# runner is launched from.
DATA_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')

# Fixture files created in TestResolve.setUpClass.
JSON_ROOT_PATH = os.path.join(DATA_PATH, 'test_root.json')
IMAGE_1_PATH = os.path.join(DATA_PATH, 'test1.jpg')
CAPTION_1_PATH = os.path.join(DATA_PATH, 'test1.txt')
IMAGE_2_PATH = os.path.join(DATA_PATH, 'test2.jpg')
IMAGE_3_PATH = os.path.join(DATA_PATH, 'test3.jpg')

# Stand-in for the trainer's parsed command-line arguments, passed to
# resolver.resolve. Presumably only the fields the resolver reads are needed
# here -- confirm against data.resolver if it starts reading more.
ARGS = argparse.Namespace(
    aspects=aspects.get_aspect_buckets(512),
    flip_p=0.5,
    seed=42,
)


class TestResolve(unittest.TestCase):
    """Tests for ``resolver.resolve`` across each supported data-root spec form.

    Fixtures written once per class into DATA_PATH:
      * test1.jpg (512x512) with a sidecar caption file test1.txt
      * test2.jpg (512x512) with no sidecar caption file
      * test3.jpg (256x256), expected to be reported as undersized
      * test_root.json listing all three images: a caption-file path for
        test1, an inline caption for test2, and no caption for test3
    """

    @classmethod
    def setUpClass(cls):
        # Create the fixture directory if it is missing so the writes below
        # cannot fail with FileNotFoundError on a fresh checkout.
        os.makedirs(DATA_PATH, exist_ok=True)

        Image.new('RGB', (512, 512)).save(IMAGE_1_PATH)
        with open(CAPTION_1_PATH, 'w') as f:
            f.write('caption for test1')

        Image.new('RGB', (512, 512)).save(IMAGE_2_PATH)

        # Undersized image
        Image.new('RGB', (256, 256)).save(IMAGE_3_PATH)

        json_data = [
            {
                'image': IMAGE_1_PATH,
                'caption': CAPTION_1_PATH,
            },
            {
                'image': IMAGE_2_PATH,
                'caption': 'caption for test2',
            },
            {
                'image': IMAGE_3_PATH,
            },
        ]
        with open(JSON_ROOT_PATH, 'w') as f:
            json.dump(json_data, f, indent=4)

    @classmethod
    def tearDownClass(cls):
        # Removes every file in DATA_PATH whose name starts with 'test',
        # which covers all fixtures written in setUpClass.
        for file in glob.glob(os.path.join(DATA_PATH, 'test*')):
            os.remove(file)

    @staticmethod
    def _summarize(items):
        """Return ``(pathnames, captions, undersized_count)`` for resolved items.

        Items are sorted by pathname first so that assertions do not depend on
        the order in which the resolver yields them. Python's sort is stable,
        so items sharing a pathname keep their resolver order.
        """
        ordered = sorted(items, key=lambda i: i.pathname)
        pathnames = [item.pathname for item in ordered]
        captions = [item.caption.get_caption() for item in ordered]
        undersized_count = sum(1 for item in ordered if item.is_undersized)
        return pathnames, captions, undersized_count

    def test_directory_resolve_with_str(self):
        paths, captions, undersized = self._summarize(resolver.resolve(DATA_PATH, ARGS))

        self.assertEqual(paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
        # test1 uses its sidecar caption file; the expected captions for the
        # other two match their file names.
        self.assertEqual(captions, ['caption for test1', 'test2', 'test3'])
        self.assertEqual(undersized, 1)

    def test_directory_resolve_with_dict(self):
        data_root_spec = {
            'resolver': 'directory',
            'path': DATA_PATH,
        }
        paths, captions, undersized = self._summarize(resolver.resolve(data_root_spec, ARGS))

        self.assertEqual(paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
        self.assertEqual(captions, ['caption for test1', 'test2', 'test3'])
        self.assertEqual(undersized, 1)

    def test_json_resolve_with_str(self):
        paths, captions, undersized = self._summarize(resolver.resolve(JSON_ROOT_PATH, ARGS))

        self.assertEqual(paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
        # Caption file (test1), inline caption (test2), and no caption (test3).
        self.assertEqual(captions, ['caption for test1', 'caption for test2', 'test3'])
        self.assertEqual(undersized, 1)

    def test_json_resolve_with_dict(self):
        data_root_spec = {
            'resolver': 'json',
            'path': JSON_ROOT_PATH,
        }
        paths, captions, undersized = self._summarize(resolver.resolve(data_root_spec, ARGS))

        self.assertEqual(paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
        self.assertEqual(captions, ['caption for test1', 'caption for test2', 'test3'])
        self.assertEqual(undersized, 1)

    def test_resolve_with_list(self):
        data_root_spec = [
            DATA_PATH,
            JSON_ROOT_PATH,
        ]
        paths, captions, undersized = self._summarize(resolver.resolve(data_root_spec, ARGS))

        # Each image is resolved once via the directory and once via the JSON
        # manifest, so every path must appear exactly twice. Comparing full
        # lists (not sets) keeps that multiplicity in the assertion.
        self.assertEqual(
            paths,
            [IMAGE_1_PATH, IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_2_PATH, IMAGE_3_PATH, IMAGE_3_PATH],
        )
        directory_captions = ['caption for test1', 'test2', 'test3']
        json_captions = ['caption for test1', 'caption for test2', 'test3']
        # The relative order of the two test2 entries depends on the resolver,
        # so compare captions order-independently but with multiplicity.
        self.assertEqual(sorted(captions), sorted(directory_captions + json_captions))
        self.assertEqual(undersized, 2)