Retain original tag order when parsing captions
This commit is contained in:
parent
43ff722e95
commit
fae0b3c535
|
@ -7,7 +7,7 @@ from functools import total_ordering
|
||||||
from attrs import define, field, Factory
|
from attrs import define, field, Factory
|
||||||
from data.image_train_item import ImageCaption, ImageTrainItem
|
from data.image_train_item import ImageCaption, ImageTrainItem
|
||||||
from utils.fs_helpers import *
|
from utils.fs_helpers import *
|
||||||
from typing import TypeVar, Iterable
|
from typing import Iterable
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
@ -19,12 +19,12 @@ def overlay(overlay, base):
|
||||||
|
|
||||||
def safe_set(val):
|
def safe_set(val):
|
||||||
if isinstance(val, str):
|
if isinstance(val, str):
|
||||||
return {val} if val else {}
|
return dict.fromkeys([val]) if val else dict()
|
||||||
|
|
||||||
if isinstance(val, Iterable):
|
if isinstance(val, Iterable):
|
||||||
return {i for i in val if i is not None}
|
return dict.fromkeys((i for i in val if i is not None))
|
||||||
|
|
||||||
return val or {}
|
return val or dict()
|
||||||
|
|
||||||
@define(frozen=True)
|
@define(frozen=True)
|
||||||
@total_ordering
|
@total_ordering
|
||||||
|
@ -51,10 +51,10 @@ class Tag:
|
||||||
@define
|
@define
|
||||||
class ImageConfig:
|
class ImageConfig:
|
||||||
# Captions
|
# Captions
|
||||||
main_prompts: set[str] = field(factory=set, converter=safe_set)
|
main_prompts: dict[str, None] = field(factory=dict, converter=safe_set)
|
||||||
rating: float = None
|
rating: float = None
|
||||||
max_caption_length: int = None
|
max_caption_length: int = None
|
||||||
tags: set[Tag] = field(factory=set, converter=safe_set)
|
tags: dict[Tag, None] = field(factory=dict, converter=safe_set)
|
||||||
|
|
||||||
# Options
|
# Options
|
||||||
multiply: float = None
|
multiply: float = None
|
||||||
|
@ -66,10 +66,10 @@ class ImageConfig:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
return ImageConfig(
|
return ImageConfig(
|
||||||
main_prompts=self.main_prompts.union(other.main_prompts),
|
main_prompts=self.main_prompts | other.main_prompts,
|
||||||
rating=overlay(other.rating, self.rating),
|
rating=overlay(other.rating, self.rating),
|
||||||
max_caption_length=overlay(other.max_caption_length, self.max_caption_length),
|
max_caption_length=overlay(other.max_caption_length, self.max_caption_length),
|
||||||
tags=self.tags.union(other.tags),
|
tags= self.tags | other.tags,
|
||||||
multiply=overlay(other.multiply, self.multiply),
|
multiply=overlay(other.multiply, self.multiply),
|
||||||
cond_dropout=overlay(other.cond_dropout, self.cond_dropout),
|
cond_dropout=overlay(other.cond_dropout, self.cond_dropout),
|
||||||
flip_p=overlay(other.flip_p, self.flip_p),
|
flip_p=overlay(other.flip_p, self.flip_p),
|
||||||
|
|
|
@ -249,6 +249,19 @@ class TestDataset(TestCase):
|
||||||
}
|
}
|
||||||
self.assertEqual(expected, actual)
|
self.assertEqual(expected, actual)
|
||||||
|
|
||||||
|
def test_tag_order_is_retained(self):
|
||||||
|
import uuid
|
||||||
|
tags=[str(uuid.uuid4()) for _ in range(10000)]
|
||||||
|
caption='main_prompt,'+", ".join(tags)
|
||||||
|
self.fs.create_file("image.png")
|
||||||
|
self.fs.create_file("image.txt", contents=caption)
|
||||||
|
|
||||||
|
actual = Dataset.from_path(".").image_configs
|
||||||
|
|
||||||
|
expected = { "./image.png": ImageConfig( main_prompts="main_prompt", tags=map(Tag.parse, tags)) }
|
||||||
|
|
||||||
|
self.assertEqual(actual, expected)
|
||||||
|
|
||||||
def test_dataset_can_produce_train_items(self):
|
def test_dataset_can_produce_train_items(self):
|
||||||
dataset = Dataset({
|
dataset = Dataset({
|
||||||
"1.jpg": ImageConfig(
|
"1.jpg": ImageConfig(
|
||||||
|
|
Loading…
Reference in New Issue