diff --git a/data/dataset.py b/data/dataset.py index 1c231ba..69f04fa 100644 --- a/data/dataset.py +++ b/data/dataset.py @@ -12,6 +12,8 @@ from typing import TypeVar, Iterable from tqdm import tqdm +DEFAULT_MAX_CAPTION_LENGTH = 2048 + def overlay(overlay, base): return overlay if overlay is not None else base @@ -235,7 +237,7 @@ class Dataset: rating=config.rating or 1.0, tags=tags, tag_weights=tag_weights, - max_target_length=config.max_caption_length, + max_target_length=config.max_caption_length or DEFAULT_MAX_CAPTION_LENGTH, use_weights=use_weights) item = ImageTrainItem( diff --git a/data/image_train_item.py b/data/image_train_item.py index bab9c13..623c664 100644 --- a/data/image_train_item.py +++ b/data/image_train_item.py @@ -29,7 +29,6 @@ from torchvision import transforms _RANDOM_TRIM = 0.04 -DEFAULT_MAX_CAPTION_LENGTH = 2048 OptionalImageCaption = typing.Optional['ImageCaption'] diff --git a/test/test_dataset.py b/test/test_dataset.py index 981ca3d..6179234 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -1,5 +1,5 @@ import os -from data.dataset import Dataset, ImageConfig, Tag +from data.dataset import Dataset, ImageConfig, Tag, DEFAULT_MAX_CAPTION_LENGTH from textwrap import dedent from pyfakefs.fake_filesystem_unittest import TestCase @@ -276,8 +276,8 @@ class TestDataset(TestCase): self.assertEqual(actual[0].cond_dropout, 0.01) self.assertEqual(actual[0].caption.rating(), 1.1) self.assertEqual(actual[0].caption.get_caption(), "first caption, tag, tag_2") - # Can't test this - # self.assertTrue(actual[0].caption.__use_weights) + self.assertTrue(actual[0].caption._ImageCaption__use_weights) + self.assertEqual(actual[0].caption._ImageCaption__max_target_length, 1024) self.assertEqual(actual[1].pathname, os.path.abspath('2.jpg')) self.assertEqual(actual[1].multiplier, 1.0) @@ -285,5 +285,5 @@ class TestDataset(TestCase): self.assertIsNone(actual[1].cond_dropout) self.assertEqual(actual[1].caption.rating(), 1.0) self.assertEqual(actual[1].caption.get_caption(), "single caption") - # Can't test this - # self.assertFalse(actual[1].caption.__use_weights) \ No newline at end of file + self.assertFalse(actual[1].caption._ImageCaption__use_weights) + self.assertEqual(actual[1].caption._ImageCaption__max_target_length, DEFAULT_MAX_CAPTION_LENGTH) \ No newline at end of file