Prioritize tags in the order `image > local > global`, while still respecting tag weights

This commit is contained in:
Augusto de la Torre 2023-03-21 00:15:53 +01:00
parent fae0b3c535
commit 161e0a563c
2 changed files with 98 additions and 43 deletions

View File

@ -27,7 +27,6 @@ def safe_set(val):
return val or dict()
@define(frozen=True)
@total_ordering
class Tag:
value: str
weight: float = field(default=1.0, converter=lambda x: x if x is not None else 1.0)
@ -45,9 +44,6 @@ class Tag:
return None
def __lt__(self, other):
return self.weight < other.weight and self.value < other.value
@define
class ImageConfig:
# Captions
@ -66,10 +62,10 @@ class ImageConfig:
return self
return ImageConfig(
main_prompts=self.main_prompts | other.main_prompts,
main_prompts=other.main_prompts | self.main_prompts,
rating=overlay(other.rating, self.rating),
max_caption_length=overlay(other.max_caption_length, self.max_caption_length),
tags= self.tags | other.tags,
tags= other.tags | self.tags,
multiply=overlay(other.multiply, self.multiply),
cond_dropout=overlay(other.cond_dropout, self.cond_dropout),
flip_p=overlay(other.flip_p, self.flip_p),
@ -227,14 +223,14 @@ class Dataset:
tags = []
tag_weights = []
for tag in sorted(config.tags):
for tag in sorted(config.tags, key=lambda x: x.weight or 1.0, reverse=True):
tags.append(tag.value)
tag_weights.append(tag.weight)
use_weights = len(set(tag_weights)) > 1
try:
caption = ImageCaption(
main_prompt=next(iter(sorted(config.main_prompts))),
main_prompt=next(iter(config.main_prompts)),
rating=config.rating or 1.0,
tags=tags,
tag_weights=tag_weights,

View File

@ -166,27 +166,27 @@ class TestDataset(TestCase):
expected = {
"./arbitrary filename.png": ImageConfig(
main_prompts={ 'global prompt', 'local prompt' },
tags={ Tag("global tag"), Tag("local tag") },
tags=[ Tag("global tag"), Tag("local tag") ],
flip_p=0.0
),
"./sub/sub arbitrary filename.png": ImageConfig(
main_prompts={ 'global prompt' },
tags={ Tag("global tag") },
tags=[ Tag("global tag") ],
flip_p=0.0
),
"./sub/sidecar.png": ImageConfig(
main_prompts={ 'global prompt', 'sidecar prompt' },
tags={ Tag("global tag"), Tag("sidecar tag") },
tags=[ Tag("global tag"), Tag("sidecar tag") ],
flip_p=0.0
),
"./optfile/optfile.png": ImageConfig(
main_prompts={ 'global prompt' },
tags={ Tag("global tag") },
tags=[ Tag("global tag") ],
flip_p=0.1234
),
"./sub/sub2/sub3/xyz.png": ImageConfig(
main_prompts={ 'global prompt' },
tags={ Tag("global tag"), Tag("sub global tag") },
tags=[ Tag("global tag"), Tag("sub global tag") ],
flip_p=0.0
)
}
@ -249,7 +249,11 @@ class TestDataset(TestCase):
}
self.assertEqual(expected, actual)
def test_tag_order_is_retained(self):
def test_original_tag_order_is_retained_in_dataset(self):
def get_random_string(length):
letters = string.ascii_lowercase
return ''.join(random.choice(letters) for _ in range(length))
import uuid
tags=[str(uuid.uuid4()) for _ in range(10000)]
caption='main_prompt,'+", ".join(tags)
@ -262,41 +266,96 @@ class TestDataset(TestCase):
self.assertEqual(actual, expected)
def test_dataset_can_produce_train_items(self):
def test_tag_order_is_retained_in_train_item(self):
dataset = Dataset({
"1.jpg": ImageConfig(
multiply=2,
flip_p=0.1,
cond_dropout=0.01,
main_prompts=["first caption","second caption"],
rating = 1.1,
max_caption_length=1024,
tags=frozenset([
Tag("tag"),
Tag("tag_2", 2.0)
])),
"2.jpg": ImageConfig( main_prompts="single caption")
main_prompts="main_prompt",
tags=[
Tag("xyz"),
Tag("abc"),
Tag("ijk")
])
})
aspects = []
actual = dataset.image_train_items(aspects)
self.assertEqual(len(actual), 2)
self.assertEqual(len(actual), 1)
self.assertEqual(actual[0].pathname, os.path.abspath('1.jpg'))
self.assertEqual(actual[0].multiplier, 2.0)
self.assertEqual(actual[0].flip.p, 0.1)
self.assertEqual(actual[0].cond_dropout, 0.01)
self.assertEqual(actual[0].caption.rating(), 1.1)
self.assertEqual(actual[0].caption.get_caption(), "first caption, tag, tag_2")
self.assertTrue(actual[0].caption._ImageCaption__use_weights)
self.assertEqual(actual[0].caption._ImageCaption__max_target_length, 1024)
self.assertEqual(actual[0].caption.get_caption(), "main_prompt, xyz, abc, ijk")
self.assertEqual(actual[1].pathname, os.path.abspath('2.jpg'))
self.assertEqual(actual[1].multiplier, 1.0)
self.assertEqual(actual[1].flip.p, 0.0)
self.assertIsNone(actual[1].cond_dropout)
self.assertEqual(actual[1].caption.rating(), 1.0)
self.assertEqual(actual[1].caption.get_caption(), "single caption")
self.assertFalse(actual[1].caption._ImageCaption__use_weights)
self.assertEqual(actual[1].caption._ImageCaption__max_target_length, DEFAULT_MAX_CAPTION_LENGTH)
def test_dataset_can_produce_train_items(self):
# End-to-end check: lay out a small directory tree of images and config
# files, load it with Dataset.from_path, and verify the per-image settings
# and the final caption (including tag order) of every train item.
# NOTE(review): self.fs looks like a pyfakefs fake filesystem -- confirm
# against the test class setup.

# Config in ./sub: a main prompt, one tag without an explicit weight and
# one tag with weight 10.0.
self.fs.create_file("./sub/global.yaml",
contents=dedent("""\
main_prompt: global prompt
tags:
- low prio global tag
- tag: high prio global tag
weight: 10.0
"""))
# Config in ./sub/nested: a single tag without an explicit weight.
self.fs.create_file("./sub/nested/local.yaml",
contents=dedent("""
tags:
- tag: local tag
"""))
# Image in ./sub with a sidecar YAML that sets its own prompt, three
# unweighted tags and every per-image option asserted further down.
self.fs.create_file("./sub/sub.jpg")
self.fs.create_file("./sub/sub.yaml",
contents=dedent("""\
main_prompt: sub.jpg prompt
tags:
- sub.jpg tag
- another tag
- last tag
rating: 1.1
max_caption_length: 1024
multiply: 2
flip_p: 0.1
cond_dropout: 0.01
"""))
# Image in ./sub/nested whose only own tag carries weight 0.1.
self.fs.create_file("./sub/nested/nested.jpg")
self.fs.create_file("./sub/nested/nested.yaml",
contents=dedent("""\
main_prompt: nested.jpg prompt
tags:
- tag: nested.jpg tag
weight: 0.1
"""))
# Image at the root with a plain-text sidecar instead of YAML.
# presumably the first comma-separated item is the main prompt and the
# rest are tags -- the expected caption below is consistent with that.
self.fs.create_file("./root.jpg")
self.fs.create_file("./root.txt", contents="root.jpg prompt, root.jpg tag")
aspects = []
dataset = Dataset.from_path(".")
actual = dataset.image_train_items(aspects)
# One train item per image file created above; the assertions below
# expect them in the order root.jpg, sub/sub.jpg, sub/nested/nested.jpg.
self.assertEqual(len(actual), 3)
# root.jpg sits outside ./sub: only default values are expected, and the
# caption is just the contents of root.txt.
self.assertEqual(actual[0].pathname, os.path.abspath('root.jpg'))
self.assertEqual(actual[0].multiplier, 1.0)
self.assertEqual(actual[0].flip.p, 0.0)
self.assertIsNone(actual[0].cond_dropout)
self.assertEqual(actual[0].caption.rating(), 1.0)
self.assertEqual(actual[0].caption.get_caption(), "root.jpg prompt, root.jpg tag")
# _ImageCaption__use_weights / _ImageCaption__max_target_length are the
# name-mangled forms of ImageCaption's double-underscore attributes.
self.assertFalse(actual[0].caption._ImageCaption__use_weights)
self.assertEqual(actual[0].caption._ImageCaption__max_target_length, DEFAULT_MAX_CAPTION_LENGTH)
# sub.jpg: every option comes from sub.yaml. Expected tag order is the
# weight-10 tag from global.yaml first, then the image's own tags in file
# order, then the unweighted tag from global.yaml.
self.assertEqual(actual[1].pathname, os.path.abspath('sub/sub.jpg'))
self.assertEqual(actual[1].multiplier, 2.0)
self.assertEqual(actual[1].flip.p, 0.1)
self.assertEqual(actual[1].cond_dropout, 0.01)
self.assertEqual(actual[1].caption.rating(), 1.1)
self.assertEqual(actual[1].caption.get_caption(), "sub.jpg prompt, high prio global tag, sub.jpg tag, another tag, last tag, low prio global tag")
self.assertTrue(actual[1].caption._ImageCaption__use_weights)
self.assertEqual(actual[1].caption._ImageCaption__max_target_length, 1024)
# nested.jpg: default options, but tags from ./sub/global.yaml are still
# expected in its caption. Expected tag order is the weight-10 tag, the
# local.yaml tag, the unweighted global tag, and last the image's own
# 0.1-weight tag.
self.assertEqual(actual[2].pathname, os.path.abspath('sub/nested/nested.jpg'))
self.assertEqual(actual[2].multiplier, 1.0)
self.assertEqual(actual[2].flip.p, 0.0)
self.assertIsNone(actual[2].cond_dropout)
self.assertEqual(actual[2].caption.rating(), 1.0)
self.assertEqual(actual[2].caption.get_caption(), "nested.jpg prompt, high prio global tag, local tag, low prio global tag, nested.jpg tag")
self.assertTrue(actual[2].caption._ImageCaption__use_weights)
self.assertEqual(actual[2].caption._ImageCaption__max_target_length, DEFAULT_MAX_CAPTION_LENGTH)