Use filenames for caption if no main prompt in yaml
This commit is contained in:
parent
8895e8d0d6
commit
2bb35eaa0a
|
@ -174,7 +174,7 @@ class Dataset:
|
|||
# Use file name for caption only as a last resort
|
||||
@classmethod
|
||||
def __ensure_caption(cls, cfg: ImageConfig, file: str):
|
||||
if cfg.main_prompts or cfg.tags:
|
||||
if cfg.main_prompts:
|
||||
return cfg
|
||||
cap_cfg = ImageConfig.from_caption_text(barename(file).split("_")[0])
|
||||
return cfg.merge(cap_cfg)
|
||||
|
@ -217,9 +217,13 @@ class Dataset:
|
|||
items = []
|
||||
for image in tqdm(self.image_configs, desc="preloading", dynamic_ncols=True):
|
||||
config = self.image_configs[image]
|
||||
|
||||
if len(config.main_prompts) > 1:
|
||||
logging.warning(f" *** Found multiple multiple main_prompts for image {image}, but only one will be applied: {config.main_prompts}")
|
||||
|
||||
if len(config.main_prompts) < 1:
|
||||
logging.warning(f" *** No main_prompts for image {image}")
|
||||
|
||||
tags = []
|
||||
tag_weights = []
|
||||
for tag in sorted(config.tags, key=lambda x: x.weight or 1.0, reverse=True):
|
||||
|
|
|
@ -75,7 +75,7 @@ class TestResolve(unittest.TestCase):
|
|||
'path': DATA_PATH,
|
||||
}
|
||||
|
||||
items = resolver.resolve(data_root_spec, ARGS)
|
||||
items = sorted(resolver.resolve(data_root_spec, ARGS), key=lambda i: i.pathname)
|
||||
image_paths = [item.pathname for item in items]
|
||||
image_captions = [item.caption for item in items]
|
||||
captions = [caption.get_caption() for caption in image_captions]
|
||||
|
@ -88,7 +88,7 @@ class TestResolve(unittest.TestCase):
|
|||
self.assertEqual(len(undersized_images), 1)
|
||||
|
||||
def test_json_resolve_with_str(self):
|
||||
items = resolver.resolve(JSON_ROOT_PATH, ARGS)
|
||||
items = sorted(resolver.resolve(JSON_ROOT_PATH, ARGS), key=lambda i: i.pathname)
|
||||
image_paths = [item.pathname for item in items]
|
||||
image_captions = [item.caption for item in items]
|
||||
captions = [caption.get_caption() for caption in image_captions]
|
||||
|
@ -124,14 +124,14 @@ class TestResolve(unittest.TestCase):
|
|||
JSON_ROOT_PATH,
|
||||
]
|
||||
|
||||
items = resolver.resolve(data_root_spec, ARGS)
|
||||
items = sorted(resolver.resolve(data_root_spec, ARGS), key=lambda i: i.pathname)
|
||||
image_paths = [item.pathname for item in items]
|
||||
image_captions = [item.caption for item in items]
|
||||
captions = [caption.get_caption() for caption in image_captions]
|
||||
|
||||
self.assertEqual(len(items), 6)
|
||||
self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH] * 2)
|
||||
self.assertEqual(captions, ['caption for test1', 'test2', 'test3', 'caption for test1', 'caption for test2', 'test3'])
|
||||
self.assertEqual(set(image_paths), set([IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH] * 2))
|
||||
self.assertEqual(set(captions), {}'caption for test1', 'test2', 'test3', 'caption for test1', 'caption for test2', 'test3'})
|
||||
|
||||
undersized_images = list(filter(lambda i: i.is_undersized, items))
|
||||
self.assertEqual(len(undersized_images), 2)
|
|
@ -100,6 +100,24 @@ class TestDataset(TestCase):
|
|||
self.assertEqual(expected, actual)
|
||||
|
||||
|
||||
def test_captions_are_read_from_filename_if_no_main_prompt(self):
|
||||
self.fs.create_file("filename main prompt, filename tag.jpg")
|
||||
self.fs.create_file("filename main prompt, filename tag.yaml",
|
||||
contents=dedent("""
|
||||
caption:
|
||||
tags:
|
||||
- tag: standalone yaml tag
|
||||
"""))
|
||||
actual = Dataset.from_path(".").image_configs
|
||||
|
||||
expected = {
|
||||
"./filename main prompt, filename tag.jpg": ImageConfig(
|
||||
main_prompts="filename main prompt",
|
||||
tags= [ Tag("filename tag"), Tag("standalone yaml tag") ]
|
||||
)
|
||||
}
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_multiple_prompts_and_tags_from_multiple_sidecars_are_supported(self):
|
||||
self.fs.create_file("image_1.jpg")
|
||||
self.fs.create_file("image_1.yaml", contents=dedent("""
|
||||
|
@ -358,4 +376,4 @@ class TestDataset(TestCase):
|
|||
self.assertEqual(actual[2].caption.rating(), 1.0)
|
||||
self.assertEqual(actual[2].caption.get_caption(), "nested.jpg prompt, high prio global tag, local tag, low prio global tag, nested.jpg tag")
|
||||
self.assertTrue(actual[2].caption._ImageCaption__use_weights)
|
||||
self.assertEqual(actual[2].caption._ImageCaption__max_target_length, DEFAULT_MAX_CAPTION_LENGTH)
|
||||
self.assertEqual(actual[2].caption._ImageCaption__max_target_length, DEFAULT_MAX_CAPTION_LENGTH)
|
Loading…
Reference in New Issue