Use filenames for caption if no main prompt in yaml

This commit is contained in:
Augusto de la Torre 2023-04-13 10:36:05 +02:00
parent 8895e8d0d6
commit 2bb35eaa0a
3 changed files with 29 additions and 7 deletions

View File

@ -174,7 +174,7 @@ class Dataset:
# Use file name for caption only as a last resort
@classmethod
def __ensure_caption(cls, cfg: ImageConfig, file: str):
if cfg.main_prompts or cfg.tags:
if cfg.main_prompts:
return cfg
cap_cfg = ImageConfig.from_caption_text(barename(file).split("_")[0])
return cfg.merge(cap_cfg)
@ -217,9 +217,13 @@ class Dataset:
items = []
for image in tqdm(self.image_configs, desc="preloading", dynamic_ncols=True):
config = self.image_configs[image]
if len(config.main_prompts) > 1:
logging.warning(f" *** Found multiple multiple main_prompts for image {image}, but only one will be applied: {config.main_prompts}")
if len(config.main_prompts) < 1:
logging.warning(f" *** No main_prompts for image {image}")
tags = []
tag_weights = []
for tag in sorted(config.tags, key=lambda x: x.weight or 1.0, reverse=True):

View File

@ -75,7 +75,7 @@ class TestResolve(unittest.TestCase):
'path': DATA_PATH,
}
items = resolver.resolve(data_root_spec, ARGS)
items = sorted(resolver.resolve(data_root_spec, ARGS), key=lambda i: i.pathname)
image_paths = [item.pathname for item in items]
image_captions = [item.caption for item in items]
captions = [caption.get_caption() for caption in image_captions]
@ -88,7 +88,7 @@ class TestResolve(unittest.TestCase):
self.assertEqual(len(undersized_images), 1)
def test_json_resolve_with_str(self):
items = resolver.resolve(JSON_ROOT_PATH, ARGS)
items = sorted(resolver.resolve(JSON_ROOT_PATH, ARGS), key=lambda i: i.pathname)
image_paths = [item.pathname for item in items]
image_captions = [item.caption for item in items]
captions = [caption.get_caption() for caption in image_captions]
@ -124,14 +124,14 @@ class TestResolve(unittest.TestCase):
JSON_ROOT_PATH,
]
items = resolver.resolve(data_root_spec, ARGS)
items = sorted(resolver.resolve(data_root_spec, ARGS), key=lambda i: i.pathname)
image_paths = [item.pathname for item in items]
image_captions = [item.caption for item in items]
captions = [caption.get_caption() for caption in image_captions]
self.assertEqual(len(items), 6)
self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH] * 2)
self.assertEqual(captions, ['caption for test1', 'test2', 'test3', 'caption for test1', 'caption for test2', 'test3'])
self.assertEqual(set(image_paths), set([IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH] * 2))
self.assertEqual(set(captions), {}'caption for test1', 'test2', 'test3', 'caption for test1', 'caption for test2', 'test3'})
undersized_images = list(filter(lambda i: i.is_undersized, items))
self.assertEqual(len(undersized_images), 2)

View File

@ -100,6 +100,24 @@ class TestDataset(TestCase):
self.assertEqual(expected, actual)
def test_captions_are_read_from_filename_if_no_main_prompt(self):
self.fs.create_file("filename main prompt, filename tag.jpg")
self.fs.create_file("filename main prompt, filename tag.yaml",
contents=dedent("""
caption:
tags:
- tag: standalone yaml tag
"""))
actual = Dataset.from_path(".").image_configs
expected = {
"./filename main prompt, filename tag.jpg": ImageConfig(
main_prompts="filename main prompt",
tags= [ Tag("filename tag"), Tag("standalone yaml tag") ]
)
}
self.assertEqual(expected, actual)
def test_multiple_prompts_and_tags_from_multiple_sidecars_are_supported(self):
self.fs.create_file("image_1.jpg")
self.fs.create_file("image_1.yaml", contents=dedent("""
@ -358,4 +376,4 @@ class TestDataset(TestCase):
self.assertEqual(actual[2].caption.rating(), 1.0)
self.assertEqual(actual[2].caption.get_caption(), "nested.jpg prompt, high prio global tag, local tag, low prio global tag, nested.jpg tag")
self.assertTrue(actual[2].caption._ImageCaption__use_weights)
self.assertEqual(actual[2].caption._ImageCaption__max_target_length, DEFAULT_MAX_CAPTION_LENGTH)
self.assertEqual(actual[2].caption._ImageCaption__max_target_length, DEFAULT_MAX_CAPTION_LENGTH)