From 60e10867bcbb49b3a65d27cf5a76880bb2466ea4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Rold=C3=A1n?= Date: Thu, 5 Oct 2023 01:27:47 -0300 Subject: [PATCH] Fix undersize warning --- data/image_train_item.py | 4 +- test/test_image_train_item.py | 69 ++++++++++++++++++++++++++++++++++- 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/data/image_train_item.py b/data/image_train_item.py index 31815c3..a8979c3 100644 --- a/data/image_train_item.py +++ b/data/image_train_item.py @@ -314,8 +314,10 @@ class ImageTrainItem: image_aspect = width / height target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect)) - self.is_undersized = (width * height) < (target_wh[0]*1.02 * target_wh[1]*1.02) + self.is_undersized = (width != target_wh[0] and height != target_wh[1]) and (width * height) < (target_wh[0]*1.02 * target_wh[1]*1.02) + self.target_wh = target_wh + self.image_size = image.size except Exception as e: self.error = e diff --git a/test/test_image_train_item.py b/test/test_image_train_item.py index 75ffb41..69487f7 100644 --- a/test/test_image_train_item.py +++ b/test/test_image_train_item.py @@ -4,6 +4,7 @@ import pathlib import PIL.Image as Image from data.image_train_item import ImageCaption, ImageTrainItem +import data.aspects as aspects DATA_PATH = pathlib.Path('./test/data') @@ -32,4 +33,70 @@ class TestImageCaption(unittest.TestCase): self.assertEqual(caption.get_caption(), "hello world, one, two, three") caption = ImageCaption("hello world", 1.0, [], [], 2048, False) - self.assertEqual(caption.get_caption(), "hello world") \ No newline at end of file + self.assertEqual(caption.get_caption(), "hello world") + +class TestImageTrainItemConstructor(unittest.TestCase): + + def tearDown(self) -> None: + for file in DATA_PATH.glob("img_*"): + file.unlink() + + return super().tearDown() + + @staticmethod + def image_with_size(width, height): + filename = DATA_PATH / "img_{}x{}.jpg".format(width, height) + Image.new("RGB", (width, height)).save(filename) + caption = ImageCaption("hello world", 1.0, [], [], 2048, False) + return ImageTrainItem(None, caption, aspects.ASPECTS_512, filename, 0.0, 1.0, False, False, 0) + + def test_target_size_computation(self): + # Square images + image = self.image_with_size(30, 30) + self.assertEqual(image.target_wh, [512,512]) + self.assertTrue(image.is_undersized) + self.assertEqual(image.image_size, (30,30)) + + image = self.image_with_size(512, 512) + self.assertEqual(image.target_wh, [512,512]) + self.assertFalse(image.is_undersized) + self.assertEqual(image.image_size, (512,512)) + + image = self.image_with_size(580, 580) + self.assertEqual(image.target_wh, [512,512]) + self.assertFalse(image.is_undersized) + self.assertEqual(image.image_size, (580,580)) + + # Landscape images + image = self.image_with_size(64, 38) + self.assertEqual(image.target_wh, [640,384]) + self.assertTrue(image.is_undersized) + self.assertEqual(image.image_size, (64,38)) + + image = self.image_with_size(640, 384) + self.assertEqual(image.target_wh, [640,384]) + self.assertFalse(image.is_undersized) + self.assertEqual(image.image_size, (640,384)) + + image = self.image_with_size(704, 422) + self.assertEqual(image.target_wh, [640,384]) + self.assertFalse(image.is_undersized) + self.assertEqual(image.image_size, (704,422)) + + # Portrait images + image = self.image_with_size(38, 64) + self.assertEqual(image.target_wh, [384,640]) + self.assertTrue(image.is_undersized) + self.assertEqual(image.image_size, (38,64)) + + image = self.image_with_size(384, 640) + self.assertEqual(image.target_wh, [384,640]) + self.assertFalse(image.is_undersized) + self.assertEqual(image.image_size, (384,640)) + + image = self.image_with_size(422, 704) + self.assertEqual(image.target_wh, [384,640]) + self.assertFalse(image.is_undersized) + self.assertEqual(image.image_size, (422,704)) + +