Merge pull request #116 from qslug/retain-tag-order
Retain original tag order when parsing captions
This commit is contained in:
commit
467f3e5aae
|
@ -7,7 +7,7 @@ from functools import total_ordering
|
|||
from attrs import define, field, Factory
|
||||
from data.image_train_item import ImageCaption, ImageTrainItem
|
||||
from utils.fs_helpers import *
|
||||
from typing import TypeVar, Iterable
|
||||
from typing import Iterable
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
|
@ -19,12 +19,12 @@ def overlay(overlay, base):
|
|||
|
||||
def safe_set(val):
|
||||
if isinstance(val, str):
|
||||
return {val} if val else {}
|
||||
return dict.fromkeys([val]) if val else dict()
|
||||
|
||||
if isinstance(val, Iterable):
|
||||
return {i for i in val if i is not None}
|
||||
return dict.fromkeys((i for i in val if i is not None))
|
||||
|
||||
return val or {}
|
||||
return val or dict()
|
||||
|
||||
@define(frozen=True)
|
||||
@total_ordering
|
||||
|
@ -51,10 +51,10 @@ class Tag:
|
|||
@define
|
||||
class ImageConfig:
|
||||
# Captions
|
||||
main_prompts: set[str] = field(factory=set, converter=safe_set)
|
||||
main_prompts: dict[str, None] = field(factory=dict, converter=safe_set)
|
||||
rating: float = None
|
||||
max_caption_length: int = None
|
||||
tags: set[Tag] = field(factory=set, converter=safe_set)
|
||||
tags: dict[Tag, None] = field(factory=dict, converter=safe_set)
|
||||
|
||||
# Options
|
||||
multiply: float = None
|
||||
|
@ -66,10 +66,10 @@ class ImageConfig:
|
|||
return self
|
||||
|
||||
return ImageConfig(
|
||||
main_prompts=self.main_prompts.union(other.main_prompts),
|
||||
main_prompts=self.main_prompts | other.main_prompts,
|
||||
rating=overlay(other.rating, self.rating),
|
||||
max_caption_length=overlay(other.max_caption_length, self.max_caption_length),
|
||||
tags=self.tags.union(other.tags),
|
||||
tags= self.tags | other.tags,
|
||||
multiply=overlay(other.multiply, self.multiply),
|
||||
cond_dropout=overlay(other.cond_dropout, self.cond_dropout),
|
||||
flip_p=overlay(other.flip_p, self.flip_p),
|
||||
|
|
|
@ -249,6 +249,19 @@ class TestDataset(TestCase):
|
|||
}
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_tag_order_is_retained(self):
|
||||
import uuid
|
||||
tags=[str(uuid.uuid4()) for _ in range(10000)]
|
||||
caption='main_prompt,'+", ".join(tags)
|
||||
self.fs.create_file("image.png")
|
||||
self.fs.create_file("image.txt", contents=caption)
|
||||
|
||||
actual = Dataset.from_path(".").image_configs
|
||||
|
||||
expected = { "./image.png": ImageConfig( main_prompts="main_prompt", tags=map(Tag.parse, tags)) }
|
||||
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_dataset_can_produce_train_items(self):
|
||||
dataset = Dataset({
|
||||
"1.jpg": ImageConfig(
|
||||
|
|
Loading…
Reference in New Issue