Retain original tag order when parsing captions

This commit is contained in:
Augusto de la Torre 2023-03-19 23:30:42 +01:00
parent 43ff722e95
commit fae0b3c535
2 changed files with 21 additions and 8 deletions

View File

@ -7,7 +7,7 @@ from functools import total_ordering
from attrs import define, field, Factory
from data.image_train_item import ImageCaption, ImageTrainItem
from utils.fs_helpers import *
from typing import TypeVar, Iterable
from typing import Iterable
from tqdm import tqdm
@ -19,12 +19,12 @@ def overlay(overlay, base):
def safe_set(val):
    """Normalize ``val`` into a de-duplicated collection that keeps insertion order.

    A plain ``dict`` whose values are all ``None`` stands in for an ordered
    set: keys are unique and iterate in the order they were first seen.

    Args:
        val: A single string, an iterable of items, or any other value.

    Returns:
        For a ``str``: a one-key dict holding the string, or an empty dict
        when the string is empty.
        For any other iterable: a dict keyed by the items that are not
        ``None``, in first-seen order.
        For anything else: ``val`` itself when truthy, otherwise an empty
        dict.
    """
    # Strings are iterable too; handle them first so a caption is kept whole
    # instead of being split into single characters by the branch below.
    if isinstance(val, str):
        return dict.fromkeys([val]) if val else {}
    if isinstance(val, Iterable):
        # dict.fromkeys keeps the first occurrence of each key, which both
        # removes duplicates and retains the original ordering.
        return dict.fromkeys(i for i in val if i is not None)
    return val or {}
@define(frozen=True)
@total_ordering
@ -51,10 +51,10 @@ class Tag:
@define
class ImageConfig:
# Captions
main_prompts: set[str] = field(factory=set, converter=safe_set)
main_prompts: dict[str, None] = field(factory=dict, converter=safe_set)
rating: float = None
max_caption_length: int = None
tags: set[Tag] = field(factory=set, converter=safe_set)
tags: dict[Tag, None] = field(factory=dict, converter=safe_set)
# Options
multiply: float = None
@ -66,10 +66,10 @@ class ImageConfig:
return self
return ImageConfig(
main_prompts=self.main_prompts.union(other.main_prompts),
main_prompts=self.main_prompts | other.main_prompts,
rating=overlay(other.rating, self.rating),
max_caption_length=overlay(other.max_caption_length, self.max_caption_length),
tags=self.tags.union(other.tags),
tags= self.tags | other.tags,
multiply=overlay(other.multiply, self.multiply),
cond_dropout=overlay(other.cond_dropout, self.cond_dropout),
flip_p=overlay(other.flip_p, self.flip_p),

View File

@ -249,6 +249,19 @@ class TestDataset(TestCase):
}
self.assertEqual(expected, actual)
def test_tag_order_is_retained(self):
    """Tags parsed from a caption file keep the order they were written in."""
    import uuid

    # Many random, unique tag names: an unordered container is very unlikely
    # to reproduce the written order of this many entries by chance.
    tags = [str(uuid.uuid4()) for _ in range(10000)]
    caption = 'main_prompt,' + ", ".join(tags)
    self.fs.create_file("image.png")
    self.fs.create_file("image.txt", contents=caption)

    actual = Dataset.from_path(".").image_configs

    expected_tags = [Tag.parse(tag) for tag in tags]
    expected = {"./image.png": ImageConfig(main_prompts="main_prompt", tags=expected_tags)}
    self.assertEqual(expected, actual)

    # dict equality ignores insertion order, so the comparison above cannot
    # detect reordered tags; compare the tag sequence explicitly as well.
    # NOTE(review): assumes image_configs supports lookup by path, as the
    # dict-shaped expected value above implies - confirm.
    self.assertEqual(expected_tags, list(actual["./image.png"].tags))
def test_dataset_can_produce_train_items(self):
dataset = Dataset({
"1.jpg": ImageConfig(