Add unit tests for data.image_train_item
This commit is contained in:
parent
a6cabe8d7d
commit
f4f684a915
|
@ -0,0 +1,71 @@
|
||||||
|
import unittest
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import PIL.Image as Image
|
||||||
|
|
||||||
|
from data.image_train_item import ImageCaption, ImageTrainItem
|
||||||
|
|
||||||
|
DATA_PATH = pathlib.Path('./test/data')
|
||||||
|
|
||||||
|
class TestImageCaption(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
with open(DATA_PATH / "test1.txt", encoding='utf-8', mode='w') as f:
|
||||||
|
f.write("caption for test1")
|
||||||
|
|
||||||
|
Image.new("RGB", (512,512)).save(DATA_PATH / "test1.jpg")
|
||||||
|
Image.new("RGB", (512,512)).save(DATA_PATH / "test2.jpg")
|
||||||
|
|
||||||
|
with open(DATA_PATH / "test_caption.caption", encoding='utf-8', mode='w') as f:
|
||||||
|
f.write("caption for test2")
|
||||||
|
|
||||||
|
return super().setUp()
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
for file in DATA_PATH.glob("test*"):
|
||||||
|
file.unlink()
|
||||||
|
|
||||||
|
return super().tearDown()
|
||||||
|
|
||||||
|
def test_constructor(self):
|
||||||
|
caption = ImageCaption("hello world", 1.0, ["one", "two", "three"], [1.0]*3, 2048, False)
|
||||||
|
self.assertEqual(caption.get_caption(), "hello world, one, two, three")
|
||||||
|
|
||||||
|
caption = ImageCaption("hello world", 1.0, [], [], 2048, False)
|
||||||
|
self.assertEqual(caption.get_caption(), "hello world")
|
||||||
|
|
||||||
|
def test_parse(self):
|
||||||
|
caption = ImageCaption.parse("hello world, one, two, three")
|
||||||
|
|
||||||
|
self.assertEqual(caption.get_caption(), "hello world, one, two, three")
|
||||||
|
|
||||||
|
def test_from_file_name(self):
|
||||||
|
caption = ImageCaption.from_file_name("foo bar_1_2_3.jpg")
|
||||||
|
self.assertEqual(caption.get_caption(), "foo bar")
|
||||||
|
|
||||||
|
def test_from_text_file(self):
|
||||||
|
caption = ImageCaption.from_text_file("test/data/test1.txt")
|
||||||
|
self.assertEqual(caption.get_caption(), "caption for test1")
|
||||||
|
|
||||||
|
def test_from_file(self):
|
||||||
|
caption = ImageCaption.from_file("test/data/test1.txt")
|
||||||
|
self.assertEqual(caption.get_caption(), "caption for test1")
|
||||||
|
|
||||||
|
caption = ImageCaption.from_file("test/data/test_caption.caption")
|
||||||
|
self.assertEqual(caption.get_caption(), "caption for test2")
|
||||||
|
|
||||||
|
def test_resolve(self):
|
||||||
|
caption = ImageCaption.resolve("test/data/test1.txt")
|
||||||
|
self.assertEqual(caption.get_caption(), "caption for test1")
|
||||||
|
|
||||||
|
caption = ImageCaption.resolve("test/data/test_caption.caption")
|
||||||
|
self.assertEqual(caption.get_caption(), "caption for test2")
|
||||||
|
|
||||||
|
caption = ImageCaption.resolve("hello world")
|
||||||
|
self.assertEqual(caption.get_caption(), "hello world")
|
||||||
|
|
||||||
|
caption = ImageCaption.resolve("test/data/test1.jpg")
|
||||||
|
self.assertEqual(caption.get_caption(), "caption for test1")
|
||||||
|
|
||||||
|
caption = ImageCaption.resolve("test/data/test2.jpg")
|
||||||
|
self.assertEqual(caption.get_caption(), "test2")
|
Loading…
Reference in New Issue