From fae0b3c5356fb6fa61216ae05ab5ad54bc9cb4c6 Mon Sep 17 00:00:00 2001 From: Augusto de la Torre Date: Sun, 19 Mar 2023 23:30:42 +0100 Subject: [PATCH] Retain original tag order when parsing captions --- data/dataset.py | 16 ++++++++-------- test/test_dataset.py | 13 +++++++++++++ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/data/dataset.py b/data/dataset.py index 9105f85..1d184cf 100644 --- a/data/dataset.py +++ b/data/dataset.py @@ -7,7 +7,7 @@ from functools import total_ordering from attrs import define, field, Factory from data.image_train_item import ImageCaption, ImageTrainItem from utils.fs_helpers import * -from typing import TypeVar, Iterable +from typing import Iterable from tqdm import tqdm @@ -19,12 +19,12 @@ def overlay(overlay, base): def safe_set(val): if isinstance(val, str): - return {val} if val else {} + return dict.fromkeys([val]) if val else dict() if isinstance(val, Iterable): - return {i for i in val if i is not None} + return dict.fromkeys((i for i in val if i is not None)) - return val or {} + return val or dict() @define(frozen=True) @total_ordering @@ -51,10 +51,10 @@ class Tag: @define class ImageConfig: # Captions - main_prompts: set[str] = field(factory=set, converter=safe_set) + main_prompts: dict[str, None] = field(factory=dict, converter=safe_set) rating: float = None max_caption_length: int = None - tags: set[Tag] = field(factory=set, converter=safe_set) + tags: dict[Tag, None] = field(factory=dict, converter=safe_set) # Options multiply: float = None @@ -66,10 +66,10 @@ class ImageConfig: return self return ImageConfig( - main_prompts=self.main_prompts.union(other.main_prompts), + main_prompts=self.main_prompts | other.main_prompts, rating=overlay(other.rating, self.rating), max_caption_length=overlay(other.max_caption_length, self.max_caption_length), - tags=self.tags.union(other.tags), + tags= self.tags | other.tags, multiply=overlay(other.multiply, self.multiply), cond_dropout=overlay(other.cond_dropout, self.cond_dropout), flip_p=overlay(other.flip_p, self.flip_p), diff --git a/test/test_dataset.py b/test/test_dataset.py index 6179234..ee0b932 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -249,6 +249,19 @@ class TestDataset(TestCase): } self.assertEqual(expected, actual) + def test_tag_order_is_retained(self): + import uuid + tags=[str(uuid.uuid4()) for _ in range(10000)] + caption='main_prompt,'+", ".join(tags) + self.fs.create_file("image.png") + self.fs.create_file("image.txt", contents=caption) + + actual = Dataset.from_path(".").image_configs + + expected = { "./image.png": ImageConfig( main_prompts="main_prompt", tags=map(Tag.parse, tags)) } + + self.assertEqual(actual, expected) + def test_dataset_can_produce_train_items(self): dataset = Dataset({ "1.jpg": ImageConfig(