diff --git a/data/dataset.py b/data/dataset.py index a0c761a..2f42b18 100644 --- a/data/dataset.py +++ b/data/dataset.py @@ -36,7 +36,7 @@ class Tag: return Tag(data) if isinstance(data, dict): - value = data.get("tag") + value = str(data.get("tag")) weight = data.get("weight") if value: return Tag(value, weight) diff --git a/test/test_data_resolver.py b/test/test_data_resolver.py index 85042cd..e0b7502 100644 --- a/test/test_data_resolver.py +++ b/test/test_data_resolver.py @@ -131,7 +131,7 @@ class TestResolve(unittest.TestCase): self.assertEqual(len(items), 6) self.assertEqual(set(image_paths), set([IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH] * 2)) - self.assertEqual(set(captions), {}'caption for test1', 'test2', 'test3', 'caption for test1', 'caption for test2', 'test3'}) + self.assertEqual(set(captions), {'caption for test1', 'test2', 'test3', 'caption for test1', 'caption for test2', 'test3'}) undersized_images = list(filter(lambda i: i.is_undersized, items)) self.assertEqual(len(undersized_images), 2) \ No newline at end of file diff --git a/test/test_dataset.py b/test/test_dataset.py index 5399184..2acd381 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -77,6 +77,7 @@ class TestDataset(TestCase): - tag: from .yml - tag: with weight weight: 0.5 + - tag: 1234.5 """)) actual = Dataset.from_path(".").image_configs @@ -94,7 +95,7 @@ class TestDataset(TestCase): rating=1.1, max_caption_length=1024, main_prompts="A complex caption", - tags= { Tag("from .yml"), Tag("with weight", weight=0.5) } + tags= { Tag("from .yml"), Tag("with weight", weight=0.5), Tag("1234.5") } ) } self.assertEqual(expected, actual)