Prioritize tags, `image > local > global`, but respect weights
This commit is contained in:
parent
fae0b3c535
commit
161e0a563c
|
@ -27,7 +27,6 @@ def safe_set(val):
|
|||
return val or dict()
|
||||
|
||||
@define(frozen=True)
|
||||
@total_ordering
|
||||
class Tag:
|
||||
value: str
|
||||
weight: float = field(default=1.0, converter=lambda x: x if x is not None else 1.0)
|
||||
|
@ -45,9 +44,6 @@ class Tag:
|
|||
|
||||
return None
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.weight < other.weight and self.value < other.value
|
||||
|
||||
@define
|
||||
class ImageConfig:
|
||||
# Captions
|
||||
|
@ -66,10 +62,10 @@ class ImageConfig:
|
|||
return self
|
||||
|
||||
return ImageConfig(
|
||||
main_prompts=self.main_prompts | other.main_prompts,
|
||||
main_prompts=other.main_prompts | self.main_prompts,
|
||||
rating=overlay(other.rating, self.rating),
|
||||
max_caption_length=overlay(other.max_caption_length, self.max_caption_length),
|
||||
tags= self.tags | other.tags,
|
||||
tags= other.tags | self.tags,
|
||||
multiply=overlay(other.multiply, self.multiply),
|
||||
cond_dropout=overlay(other.cond_dropout, self.cond_dropout),
|
||||
flip_p=overlay(other.flip_p, self.flip_p),
|
||||
|
@ -227,14 +223,14 @@ class Dataset:
|
|||
|
||||
tags = []
|
||||
tag_weights = []
|
||||
for tag in sorted(config.tags):
|
||||
for tag in sorted(config.tags, key=lambda x: x.weight or 1.0, reverse=True):
|
||||
tags.append(tag.value)
|
||||
tag_weights.append(tag.weight)
|
||||
use_weights = len(set(tag_weights)) > 1
|
||||
|
||||
try:
|
||||
caption = ImageCaption(
|
||||
main_prompt=next(iter(sorted(config.main_prompts))),
|
||||
main_prompt=next(iter(config.main_prompts)),
|
||||
rating=config.rating or 1.0,
|
||||
tags=tags,
|
||||
tag_weights=tag_weights,
|
||||
|
|
|
@ -166,27 +166,27 @@ class TestDataset(TestCase):
|
|||
expected = {
|
||||
"./arbitrary filename.png": ImageConfig(
|
||||
main_prompts={ 'global prompt', 'local prompt' },
|
||||
tags={ Tag("global tag"), Tag("local tag") },
|
||||
tags=[ Tag("global tag"), Tag("local tag") ],
|
||||
flip_p=0.0
|
||||
),
|
||||
"./sub/sub arbitrary filename.png": ImageConfig(
|
||||
main_prompts={ 'global prompt' },
|
||||
tags={ Tag("global tag") },
|
||||
tags=[ Tag("global tag") ],
|
||||
flip_p=0.0
|
||||
),
|
||||
"./sub/sidecar.png": ImageConfig(
|
||||
main_prompts={ 'global prompt', 'sidecar prompt' },
|
||||
tags={ Tag("global tag"), Tag("sidecar tag") },
|
||||
tags=[ Tag("global tag"), Tag("sidecar tag") ],
|
||||
flip_p=0.0
|
||||
),
|
||||
"./optfile/optfile.png": ImageConfig(
|
||||
main_prompts={ 'global prompt' },
|
||||
tags={ Tag("global tag") },
|
||||
tags=[ Tag("global tag") ],
|
||||
flip_p=0.1234
|
||||
),
|
||||
"./sub/sub2/sub3/xyz.png": ImageConfig(
|
||||
main_prompts={ 'global prompt' },
|
||||
tags={ Tag("global tag"), Tag("sub global tag") },
|
||||
tags=[ Tag("global tag"), Tag("sub global tag") ],
|
||||
flip_p=0.0
|
||||
)
|
||||
}
|
||||
|
@ -249,7 +249,11 @@ class TestDataset(TestCase):
|
|||
}
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_tag_order_is_retained(self):
|
||||
def test_original_tag_order_is_retained_in_dataset(self):
|
||||
def get_random_string(length):
|
||||
letters = string.ascii_lowercase
|
||||
return ''.join(random.choice(letters) for _ in range(length))
|
||||
|
||||
import uuid
|
||||
tags=[str(uuid.uuid4()) for _ in range(10000)]
|
||||
caption='main_prompt,'+", ".join(tags)
|
||||
|
@ -262,41 +266,96 @@ class TestDataset(TestCase):
|
|||
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_dataset_can_produce_train_items(self):
|
||||
|
||||
def test_tag_order_is_retained_in_train_item(self):
|
||||
dataset = Dataset({
|
||||
"1.jpg": ImageConfig(
|
||||
multiply=2,
|
||||
flip_p=0.1,
|
||||
cond_dropout=0.01,
|
||||
main_prompts=["first caption","second caption"],
|
||||
rating = 1.1,
|
||||
max_caption_length=1024,
|
||||
tags=frozenset([
|
||||
Tag("tag"),
|
||||
Tag("tag_2", 2.0)
|
||||
])),
|
||||
"2.jpg": ImageConfig( main_prompts="single caption")
|
||||
main_prompts="main_prompt",
|
||||
tags=[
|
||||
Tag("xyz"),
|
||||
Tag("abc"),
|
||||
Tag("ijk")
|
||||
])
|
||||
})
|
||||
|
||||
aspects = []
|
||||
actual = dataset.image_train_items(aspects)
|
||||
|
||||
self.assertEqual(len(actual), 2)
|
||||
|
||||
self.assertEqual(len(actual), 1)
|
||||
self.assertEqual(actual[0].pathname, os.path.abspath('1.jpg'))
|
||||
self.assertEqual(actual[0].multiplier, 2.0)
|
||||
self.assertEqual(actual[0].flip.p, 0.1)
|
||||
self.assertEqual(actual[0].cond_dropout, 0.01)
|
||||
self.assertEqual(actual[0].caption.rating(), 1.1)
|
||||
self.assertEqual(actual[0].caption.get_caption(), "first caption, tag, tag_2")
|
||||
self.assertTrue(actual[0].caption._ImageCaption__use_weights)
|
||||
self.assertEqual(actual[0].caption._ImageCaption__max_target_length, 1024)
|
||||
self.assertEqual(actual[0].caption.get_caption(), "main_prompt, xyz, abc, ijk")
|
||||
|
||||
self.assertEqual(actual[1].pathname, os.path.abspath('2.jpg'))
|
||||
self.assertEqual(actual[1].multiplier, 1.0)
|
||||
self.assertEqual(actual[1].flip.p, 0.0)
|
||||
self.assertIsNone(actual[1].cond_dropout)
|
||||
self.assertEqual(actual[1].caption.rating(), 1.0)
|
||||
self.assertEqual(actual[1].caption.get_caption(), "single caption")
|
||||
self.assertFalse(actual[1].caption._ImageCaption__use_weights)
|
||||
self.assertEqual(actual[1].caption._ImageCaption__max_target_length, DEFAULT_MAX_CAPTION_LENGTH)
|
||||
def test_dataset_can_produce_train_items(self):
|
||||
self.fs.create_file("./sub/global.yaml",
|
||||
contents=dedent("""\
|
||||
main_prompt: global prompt
|
||||
tags:
|
||||
- low prio global tag
|
||||
- tag: high prio global tag
|
||||
weight: 10.0
|
||||
"""))
|
||||
|
||||
self.fs.create_file("./sub/nested/local.yaml",
|
||||
contents=dedent("""
|
||||
tags:
|
||||
- tag: local tag
|
||||
"""))
|
||||
|
||||
self.fs.create_file("./sub/sub.jpg")
|
||||
self.fs.create_file("./sub/sub.yaml",
|
||||
contents=dedent("""\
|
||||
main_prompt: sub.jpg prompt
|
||||
tags:
|
||||
- sub.jpg tag
|
||||
- another tag
|
||||
- last tag
|
||||
rating: 1.1
|
||||
max_caption_length: 1024
|
||||
multiply: 2
|
||||
flip_p: 0.1
|
||||
cond_dropout: 0.01
|
||||
"""))
|
||||
self.fs.create_file("./sub/nested/nested.jpg")
|
||||
self.fs.create_file("./sub/nested/nested.yaml",
|
||||
contents=dedent("""\
|
||||
main_prompt: nested.jpg prompt
|
||||
tags:
|
||||
- tag: nested.jpg tag
|
||||
weight: 0.1
|
||||
"""))
|
||||
self.fs.create_file("./root.jpg")
|
||||
self.fs.create_file("./root.txt", contents="root.jpg prompt, root.jpg tag")
|
||||
|
||||
aspects = []
|
||||
dataset = Dataset.from_path(".")
|
||||
actual = dataset.image_train_items(aspects)
|
||||
|
||||
self.assertEqual(len(actual), 3)
|
||||
|
||||
|
||||
self.assertEqual(actual[0].pathname, os.path.abspath('root.jpg'))
|
||||
self.assertEqual(actual[0].multiplier, 1.0)
|
||||
self.assertEqual(actual[0].flip.p, 0.0)
|
||||
self.assertIsNone(actual[0].cond_dropout)
|
||||
self.assertEqual(actual[0].caption.rating(), 1.0)
|
||||
self.assertEqual(actual[0].caption.get_caption(), "root.jpg prompt, root.jpg tag")
|
||||
self.assertFalse(actual[0].caption._ImageCaption__use_weights)
|
||||
self.assertEqual(actual[0].caption._ImageCaption__max_target_length, DEFAULT_MAX_CAPTION_LENGTH)
|
||||
|
||||
self.assertEqual(actual[1].pathname, os.path.abspath('sub/sub.jpg'))
|
||||
self.assertEqual(actual[1].multiplier, 2.0)
|
||||
self.assertEqual(actual[1].flip.p, 0.1)
|
||||
self.assertEqual(actual[1].cond_dropout, 0.01)
|
||||
self.assertEqual(actual[1].caption.rating(), 1.1)
|
||||
self.assertEqual(actual[1].caption.get_caption(), "sub.jpg prompt, high prio global tag, sub.jpg tag, another tag, last tag, low prio global tag")
|
||||
self.assertTrue(actual[1].caption._ImageCaption__use_weights)
|
||||
self.assertEqual(actual[1].caption._ImageCaption__max_target_length, 1024)
|
||||
|
||||
self.assertEqual(actual[2].pathname, os.path.abspath('sub/nested/nested.jpg'))
|
||||
self.assertEqual(actual[2].multiplier, 1.0)
|
||||
self.assertEqual(actual[2].flip.p, 0.0)
|
||||
self.assertIsNone(actual[2].cond_dropout)
|
||||
self.assertEqual(actual[2].caption.rating(), 1.0)
|
||||
self.assertEqual(actual[2].caption.get_caption(), "nested.jpg prompt, high prio global tag, local tag, low prio global tag, nested.jpg tag")
|
||||
self.assertTrue(actual[2].caption._ImageCaption__use_weights)
|
||||
self.assertEqual(actual[2].caption._ImageCaption__max_target_length, DEFAULT_MAX_CAPTION_LENGTH)
|
||||
|
|
Loading…
Reference in New Issue