2023-03-08 07:02:14 -07:00
|
|
|
import os
|
|
|
|
import logging
|
|
|
|
import yaml
|
|
|
|
import json
|
|
|
|
|
|
|
|
from functools import total_ordering
|
2023-03-12 12:48:15 -06:00
|
|
|
from attrs import define, field, Factory
|
2023-03-08 07:02:14 -07:00
|
|
|
from data.image_train_item import ImageCaption, ImageTrainItem
|
|
|
|
from utils.fs_helpers import *
|
2023-03-12 12:48:15 -06:00
|
|
|
from typing import TypeVar, Iterable
|
2023-03-08 07:02:14 -07:00
|
|
|
|
2023-03-15 10:06:29 -06:00
|
|
|
from tqdm import tqdm
|
|
|
|
|
2023-03-08 07:02:14 -07:00
|
|
|
|
2023-03-12 12:48:15 -06:00
|
|
|
def overlay(overlay, base):
    """Return ``overlay`` unless it is ``None``, otherwise fall back to ``base``.

    Only ``None`` means "unset": falsy values such as ``0``, ``0.0`` or ``""``
    still take precedence over ``base``.
    """
    if overlay is None:
        return base
    return overlay
|
|
|
|
|
|
|
|
def safe_set(val):
    """Coerce ``val`` into a ``set``, dropping ``None`` entries.

    - ``None``, an empty string or any other falsy scalar -> empty set
    - a non-empty string -> ``{val}`` (a string is one item, never iterated
      character by character)
    - any other iterable -> the set of its non-``None`` items
    - any other truthy scalar -> ``{val}``

    The result is always a real ``set``. Note that ``{}`` would be an empty
    *dict*, which lacks ``set.union`` and breaks callers that merge sets.
    """
    if isinstance(val, str) or not isinstance(val, Iterable):
        return {val} if val else set()
    return {item for item in val if item is not None}
|
|
|
|
|
2023-03-08 07:02:14 -07:00
|
|
|
@define(frozen=True)
@total_ordering
class Tag:
    """A single caption tag with a weight.

    Instances are frozen (immutable), so they can be stored in sets.
    Tags order by ``weight`` first and then by ``value``.
    """

    # The tag text.
    value: str
    # Tag weight; an explicit ``None`` is normalised to the default 1.0.
    weight: float = field(default=1.0, converter=lambda x: x if x is not None else 1.0)

    @classmethod
    def parse(cls, data):
        """Build a ``Tag`` from a string or a ``{"tag": ..., "weight": ...}`` dict.

        Returns ``None`` for an empty string, for a dict without a non-empty
        ``"tag"`` entry, and for any other input type, so callers can filter
        unusable entries out.
        """
        if isinstance(data, str):
            # Empty pieces (e.g. from a trailing comma in a caption) are not tags.
            return cls(data) if data else None
        if isinstance(data, dict):
            value = data.get("tag")
            if value:
                return cls(value, data.get("weight"))
        return None

    def __lt__(self, other):
        """Compare lexicographically by ``(weight, value)``.

        This must be a strict total order, or ``sorted()`` over a set of tags
        gives run-dependent results. Comparing the fields independently and
        joining the results with ``and`` is not a total order.
        """
        if not isinstance(other, Tag):
            return NotImplemented
        return (self.weight, self.value) < (other.weight, other.value)
|
|
|
|
|
|
|
|
@define
|
2023-03-08 07:02:14 -07:00
|
|
|
class ImageConfig:
|
2023-03-12 12:48:15 -06:00
|
|
|
# Captions
|
|
|
|
main_prompts: set[str] = field(factory=set, converter=safe_set)
|
|
|
|
rating: float = None
|
|
|
|
max_caption_length: int = None
|
|
|
|
tags: set[Tag] = field(factory=set, converter=safe_set)
|
|
|
|
|
|
|
|
# Options
|
2023-03-08 07:02:14 -07:00
|
|
|
multiply: float = None
|
|
|
|
cond_dropout: float = None
|
|
|
|
flip_p: float = None
|
|
|
|
|
|
|
|
def merge(self, other):
|
|
|
|
if other is None:
|
|
|
|
return self
|
|
|
|
|
|
|
|
return ImageConfig(
|
2023-03-12 12:48:15 -06:00
|
|
|
main_prompts=self.main_prompts.union(other.main_prompts),
|
|
|
|
rating=overlay(other.rating, self.rating),
|
|
|
|
max_caption_length=overlay(other.max_caption_length, self.max_caption_length),
|
|
|
|
tags=self.tags.union(other.tags),
|
|
|
|
multiply=overlay(other.multiply, self.multiply),
|
|
|
|
cond_dropout=overlay(other.cond_dropout, self.cond_dropout),
|
|
|
|
flip_p=overlay(other.flip_p, self.flip_p),
|
|
|
|
)
|
2023-03-08 07:02:14 -07:00
|
|
|
|
|
|
|
@classmethod
|
|
|
|
def from_dict(cls, data: dict):
|
2023-03-12 12:48:15 -06:00
|
|
|
# Parse standard yaml tag file (with options)
|
|
|
|
parsed_cfg = ImageConfig(
|
|
|
|
main_prompts=safe_set(data.get("main_prompt")),
|
|
|
|
rating=data.get("rating"),
|
|
|
|
max_caption_length=data.get("max_caption_length"),
|
|
|
|
tags=safe_set(map(Tag.parse, data.get("tags", []))),
|
2023-03-08 07:02:14 -07:00
|
|
|
multiply=data.get("multiply"),
|
|
|
|
cond_dropout=data.get("cond_dropout"),
|
2023-03-12 12:48:15 -06:00
|
|
|
flip_p=data.get("flip_p"),
|
|
|
|
)
|
|
|
|
|
|
|
|
# Alternatively parse from dedicated `caption` attribute
|
|
|
|
if cap_attr := data.get('caption'):
|
|
|
|
parsed_cfg = parsed_cfg.merge(ImageConfig.parse(cap_attr))
|
|
|
|
|
|
|
|
return parsed_cfg
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
def fold(cls, configs):
|
|
|
|
acc = ImageConfig()
|
|
|
|
for cfg in configs:
|
|
|
|
acc = acc.merge(cfg)
|
|
|
|
return acc
|
|
|
|
|
|
|
|
def ensure_caption(self):
|
|
|
|
return self
|
2023-03-08 07:02:14 -07:00
|
|
|
|
|
|
|
@classmethod
|
2023-03-12 12:48:15 -06:00
|
|
|
def from_caption_text(cls, text: str):
|
|
|
|
if not text:
|
|
|
|
return ImageConfig()
|
|
|
|
if os.path.isfile(text):
|
|
|
|
return ImageConfig.from_file(text)
|
|
|
|
|
|
|
|
split_caption = list(map(str.strip, text.split(",")))
|
|
|
|
return ImageConfig(
|
|
|
|
main_prompts=split_caption[0],
|
|
|
|
tags=map(Tag.parse, split_caption[1:])
|
|
|
|
)
|
2023-03-08 07:02:14 -07:00
|
|
|
|
|
|
|
@classmethod
|
|
|
|
def from_file(cls, file: str):
|
2023-03-12 12:48:15 -06:00
|
|
|
match ext(file):
|
|
|
|
case '.jpg' | '.jpeg' | '.png' | '.bmp' | '.webp' | '.jfif':
|
|
|
|
return ImageConfig(image=file)
|
|
|
|
case ".json":
|
|
|
|
return ImageConfig.from_dict(json.load(read_text(file)))
|
|
|
|
case ".yaml" | ".yml":
|
|
|
|
return ImageConfig.from_dict(yaml.safe_load(read_text(file)))
|
|
|
|
case ".txt" | ".caption":
|
|
|
|
return ImageConfig.from_caption_text(read_text(file))
|
|
|
|
case _:
|
|
|
|
return logging.warning(" *** Unrecognized config extension {ext}")
|
2023-03-08 07:02:14 -07:00
|
|
|
|
|
|
|
@classmethod
|
2023-03-12 12:48:15 -06:00
|
|
|
def parse(cls, input):
|
2023-03-08 07:02:14 -07:00
|
|
|
if isinstance(input, str):
|
2023-03-12 12:48:15 -06:00
|
|
|
if os.path.isfile(input):
|
|
|
|
return ImageConfig.from_file(input)
|
|
|
|
else:
|
|
|
|
return ImageConfig.from_caption_text(input)
|
2023-03-08 07:02:14 -07:00
|
|
|
elif isinstance(input, dict):
|
|
|
|
return ImageConfig.from_dict(input)
|
2023-03-12 12:48:15 -06:00
|
|
|
|
2023-03-08 07:02:14 -07:00
|
|
|
|
|
|
|
@define()
|
|
|
|
class Dataset:
|
2023-03-12 12:48:15 -06:00
|
|
|
image_configs: dict[str, ImageConfig]
|
2023-03-08 07:02:14 -07:00
|
|
|
|
|
|
|
def __global_cfg(files):
|
|
|
|
cfgs = []
|
|
|
|
for file in files:
|
|
|
|
match os.path.basename(file):
|
|
|
|
case 'global.yaml' | 'global.yml':
|
|
|
|
cfgs.append(ImageConfig.from_file(file))
|
|
|
|
return ImageConfig.fold(cfgs)
|
|
|
|
|
|
|
|
def __local_cfg(files):
|
|
|
|
cfgs = []
|
|
|
|
for file in files:
|
|
|
|
match os.path.basename(file):
|
|
|
|
case 'multiply.txt':
|
|
|
|
cfgs.append(ImageConfig(multiply=read_float(file)))
|
|
|
|
case 'cond_dropout.txt':
|
|
|
|
cfgs.append(ImageConfig(cond_dropout=read_float(file)))
|
|
|
|
case 'flip_p.txt':
|
|
|
|
cfgs.append(ImageConfig(flip_p=read_float(file)))
|
|
|
|
case 'local.yaml' | 'local.yml':
|
|
|
|
cfgs.append(ImageConfig.from_file(file))
|
|
|
|
return ImageConfig.fold(cfgs)
|
|
|
|
|
2023-03-12 12:48:15 -06:00
|
|
|
def __sidecar_cfg(imagepath, files):
|
|
|
|
cfgs = []
|
2023-03-08 07:02:14 -07:00
|
|
|
for file in files:
|
|
|
|
if same_barename(imagepath, file):
|
|
|
|
match ext(file):
|
|
|
|
case '.txt' | '.caption' | '.yml' | '.yaml':
|
|
|
|
cfgs.append(ImageConfig.from_file(file))
|
|
|
|
return ImageConfig.fold(cfgs)
|
|
|
|
|
2023-03-12 12:48:15 -06:00
|
|
|
# Use file name for caption only as a last resort
|
|
|
|
@classmethod
|
|
|
|
def __ensure_caption(cls, cfg: ImageConfig, file: str):
|
|
|
|
if cfg.main_prompts or cfg.tags:
|
|
|
|
return cfg
|
|
|
|
cap_cfg = ImageConfig.from_caption_text(barename(file).split("_")[0])
|
|
|
|
return cfg.merge(cap_cfg)
|
|
|
|
|
2023-03-08 07:02:14 -07:00
|
|
|
@classmethod
|
|
|
|
def from_path(cls, data_root):
|
|
|
|
# Create a visitor that maintains global config stack
|
|
|
|
# and accumulates image configs as it traverses dataset
|
2023-03-12 12:48:15 -06:00
|
|
|
image_configs = {}
|
2023-03-08 07:02:14 -07:00
|
|
|
def process_dir(files, parent_globals):
|
2023-03-12 12:48:15 -06:00
|
|
|
global_cfg = parent_globals.merge(Dataset.__global_cfg(files))
|
|
|
|
local_cfg = Dataset.__local_cfg(files)
|
2023-03-08 07:02:14 -07:00
|
|
|
for img in filter(is_image, files):
|
2023-03-12 12:48:15 -06:00
|
|
|
img_cfg = Dataset.__sidecar_cfg(img, files)
|
|
|
|
resolved_cfg = ImageConfig.fold([global_cfg, local_cfg, img_cfg])
|
|
|
|
image_configs[img] = Dataset.__ensure_caption(resolved_cfg, img)
|
|
|
|
return global_cfg
|
2023-03-08 07:02:14 -07:00
|
|
|
|
|
|
|
walk_and_visit(data_root, process_dir, ImageConfig())
|
|
|
|
return Dataset(image_configs)
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
def from_json(cls, json_path):
|
|
|
|
"""
|
|
|
|
Import a dataset definition from a JSON file
|
|
|
|
"""
|
2023-03-12 12:48:15 -06:00
|
|
|
image_configs = {}
|
2023-03-08 07:02:14 -07:00
|
|
|
with open(json_path, encoding='utf-8', mode='r') as stream:
|
|
|
|
for data in json.load(stream):
|
2023-03-12 12:48:15 -06:00
|
|
|
img = data.get("image")
|
|
|
|
cfg = Dataset.__ensure_caption(ImageConfig.parse(data), img)
|
|
|
|
if not img:
|
2023-03-08 07:02:14 -07:00
|
|
|
logging.warning(f" *** Error parsing json image entry in {json_path}: {data}")
|
|
|
|
continue
|
2023-03-12 12:48:15 -06:00
|
|
|
image_configs[img] = cfg
|
|
|
|
return Dataset(image_configs)
|
2023-03-08 07:02:14 -07:00
|
|
|
|
|
|
|
def image_train_items(self, aspects):
|
|
|
|
items = []
|
2023-03-15 10:06:29 -06:00
|
|
|
for image in tqdm(self.image_configs, desc="preloading", dynamic_ncols=True):
|
2023-03-12 12:48:15 -06:00
|
|
|
config = self.image_configs[image]
|
|
|
|
if len(config.main_prompts) > 1:
|
|
|
|
logging.warning(f" *** Found multiple multiple main_prompts for image {image}, but only one will be applied: {config.main_prompts}")
|
2023-03-08 07:02:14 -07:00
|
|
|
|
|
|
|
tags = []
|
|
|
|
tag_weights = []
|
2023-03-12 12:48:15 -06:00
|
|
|
for tag in sorted(config.tags):
|
2023-03-08 07:02:14 -07:00
|
|
|
tags.append(tag.value)
|
2023-03-12 12:48:15 -06:00
|
|
|
tag_weights.append(tag.weight)
|
2023-03-08 07:02:14 -07:00
|
|
|
use_weights = len(set(tag_weights)) > 1
|
|
|
|
|
|
|
|
caption = ImageCaption(
|
2023-03-12 12:48:15 -06:00
|
|
|
main_prompt=next(iter(sorted(config.main_prompts))),
|
2023-03-12 17:36:59 -06:00
|
|
|
rating=config.rating or 1.0,
|
2023-03-08 07:02:14 -07:00
|
|
|
tags=tags,
|
|
|
|
tag_weights=tag_weights,
|
2023-03-12 12:48:15 -06:00
|
|
|
max_target_length=config.max_caption_length,
|
2023-03-08 07:02:14 -07:00
|
|
|
use_weights=use_weights)
|
|
|
|
|
|
|
|
item = ImageTrainItem(
|
|
|
|
image=None,
|
|
|
|
caption=caption,
|
|
|
|
aspects=aspects,
|
2023-03-12 12:48:15 -06:00
|
|
|
pathname=os.path.abspath(image),
|
2023-03-08 07:02:14 -07:00
|
|
|
flip_p=config.flip_p or 0.0,
|
|
|
|
multiplier=config.multiply or 1.0,
|
|
|
|
cond_dropout=config.cond_dropout
|
|
|
|
)
|
|
|
|
items.append(item)
|
2023-03-12 12:48:15 -06:00
|
|
|
return list(sorted(items, key=lambda ti: ti.pathname))
|