# EveryDream2trainer/data/dataset.py

import os
import logging
import yaml
import json
from functools import total_ordering
from attrs import define, field, Factory
from data.image_train_item import ImageCaption, ImageTrainItem
from utils.fs_helpers import *
from typing import TypeVar, Iterable
def overlay(overlay, base):
    """Pick ``overlay`` unless it is ``None``, otherwise fall back to ``base``.

    Only ``None`` counts as "not set": falsy values such as ``0`` or ``""``
    passed as ``overlay`` are returned as-is.
    """
    if overlay is None:
        return base
    return overlay
def safe_set(val):
    """Coerce ``val`` into a set, treating "nothing" as the empty set.

    * A non-empty string becomes a one-element set; an empty string becomes
      an empty set (a string is never split into characters).
    * Any other iterable becomes a set of its items, with ``None`` items
      dropped.
    * Any other falsy value (``None``, ``0``, ...) becomes an empty set.
    * Any other truthy, non-iterable value is returned unchanged.
    """
    if isinstance(val, str):
        # NOTE: the empty result must be ``set()``; the literal ``{}`` is an
        # empty *dict* and has no ``union`` method.
        return {val} if val else set()
    if isinstance(val, Iterable):
        return {item for item in val if item is not None}
    return val or set()
@define(frozen=True)
@total_ordering
class Tag:
    """An immutable caption tag with an optional weight.

    A ``weight`` of ``None`` is converted to the default of ``1.0``.
    """
    value: str
    weight: float = field(default=1.0, converter=lambda x: x if x is not None else 1.0)

    @classmethod
    def parse(cls, data):
        """Build a tag from a plain string or a ``{"tag": ..., "weight": ...}`` dict.

        Returns ``None`` when ``data`` is neither a string nor a dict, or when
        a dict has no truthy ``"tag"`` entry.
        """
        if isinstance(data, str):
            return cls(data)
        if isinstance(data, dict):
            value = data.get("tag")
            if value:
                return cls(value, data.get("weight"))
        return None

    def __lt__(self, other):
        """Order tags by ``value`` first, then by ``weight``.

        Comparing the two fields as a tuple gives a proper total order that
        agrees with the generated ``__eq__`` (which also looks at both
        fields), so ``sorted()`` yields a deterministic result. Requiring
        *both* fields to be smaller would leave many distinct tags mutually
        unordered.
        """
        if not isinstance(other, Tag):
            return NotImplemented
        return (self.value, self.weight) < (other.value, other.weight)
@define
class ImageConfig:
# Captions
main_prompts: set[str] = field(factory=set, converter=safe_set)
rating: float = None
max_caption_length: int = None
tags: set[Tag] = field(factory=set, converter=safe_set)
# Options
multiply: float = None
cond_dropout: float = None
flip_p: float = None
def merge(self, other):
if other is None:
return self
return ImageConfig(
main_prompts=self.main_prompts.union(other.main_prompts),
rating=overlay(other.rating, self.rating),
max_caption_length=overlay(other.max_caption_length, self.max_caption_length),
tags=self.tags.union(other.tags),
multiply=overlay(other.multiply, self.multiply),
cond_dropout=overlay(other.cond_dropout, self.cond_dropout),
flip_p=overlay(other.flip_p, self.flip_p),
)
@classmethod
def from_dict(cls, data: dict):
# Parse standard yaml tag file (with options)
parsed_cfg = ImageConfig(
main_prompts=safe_set(data.get("main_prompt")),
rating=data.get("rating"),
max_caption_length=data.get("max_caption_length"),
tags=safe_set(map(Tag.parse, data.get("tags", []))),
multiply=data.get("multiply"),
cond_dropout=data.get("cond_dropout"),
flip_p=data.get("flip_p"),
)
# Alternatively parse from dedicated `caption` attribute
if cap_attr := data.get('caption'):
parsed_cfg = parsed_cfg.merge(ImageConfig.parse(cap_attr))
return parsed_cfg
@classmethod
def fold(cls, configs):
acc = ImageConfig()
for cfg in configs:
acc = acc.merge(cfg)
return acc
def ensure_caption(self):
return self
@classmethod
def from_caption_text(cls, text: str):
if not text:
return ImageConfig()
if os.path.isfile(text):
return ImageConfig.from_file(text)
split_caption = list(map(str.strip, text.split(",")))
return ImageConfig(
main_prompts=split_caption[0],
tags=map(Tag.parse, split_caption[1:])
)
@classmethod
def from_file(cls, file: str):
match ext(file):
case '.jpg' | '.jpeg' | '.png' | '.bmp' | '.webp' | '.jfif':
return ImageConfig(image=file)
case ".json":
return ImageConfig.from_dict(json.load(read_text(file)))
case ".yaml" | ".yml":
return ImageConfig.from_dict(yaml.safe_load(read_text(file)))
case ".txt" | ".caption":
return ImageConfig.from_caption_text(read_text(file))
case _:
return logging.warning(" *** Unrecognized config extension {ext}")
@classmethod
def parse(cls, input):
if isinstance(input, str):
if os.path.isfile(input):
return ImageConfig.from_file(input)
else:
return ImageConfig.from_caption_text(input)
elif isinstance(input, dict):
return ImageConfig.from_dict(input)
@define()
class Dataset:
image_configs: dict[str, ImageConfig]
def __global_cfg(files):
cfgs = []
for file in files:
match os.path.basename(file):
case 'global.yaml' | 'global.yml':
cfgs.append(ImageConfig.from_file(file))
return ImageConfig.fold(cfgs)
def __local_cfg(files):
cfgs = []
for file in files:
match os.path.basename(file):
case 'multiply.txt':
cfgs.append(ImageConfig(multiply=read_float(file)))
case 'cond_dropout.txt':
cfgs.append(ImageConfig(cond_dropout=read_float(file)))
case 'flip_p.txt':
cfgs.append(ImageConfig(flip_p=read_float(file)))
case 'local.yaml' | 'local.yml':
cfgs.append(ImageConfig.from_file(file))
return ImageConfig.fold(cfgs)
def __sidecar_cfg(imagepath, files):
cfgs = []
for file in files:
if same_barename(imagepath, file):
match ext(file):
case '.txt' | '.caption' | '.yml' | '.yaml':
cfgs.append(ImageConfig.from_file(file))
return ImageConfig.fold(cfgs)
# Use file name for caption only as a last resort
@classmethod
def __ensure_caption(cls, cfg: ImageConfig, file: str):
if cfg.main_prompts or cfg.tags:
return cfg
cap_cfg = ImageConfig.from_caption_text(barename(file).split("_")[0])
return cfg.merge(cap_cfg)
@classmethod
def from_path(cls, data_root):
# Create a visitor that maintains global config stack
# and accumulates image configs as it traverses dataset
image_configs = {}
def process_dir(files, parent_globals):
global_cfg = parent_globals.merge(Dataset.__global_cfg(files))
local_cfg = Dataset.__local_cfg(files)
for img in filter(is_image, files):
img_cfg = Dataset.__sidecar_cfg(img, files)
resolved_cfg = ImageConfig.fold([global_cfg, local_cfg, img_cfg])
image_configs[img] = Dataset.__ensure_caption(resolved_cfg, img)
return global_cfg
walk_and_visit(data_root, process_dir, ImageConfig())
return Dataset(image_configs)
@classmethod
def from_json(cls, json_path):
"""
Import a dataset definition from a JSON file
"""
image_configs = {}
with open(json_path, encoding='utf-8', mode='r') as stream:
for data in json.load(stream):
img = data.get("image")
cfg = Dataset.__ensure_caption(ImageConfig.parse(data), img)
if not img:
logging.warning(f" *** Error parsing json image entry in {json_path}: {data}")
continue
image_configs[img] = cfg
return Dataset(image_configs)
def image_train_items(self, aspects):
items = []
for image in self.image_configs:
config = self.image_configs[image]
if len(config.main_prompts) > 1:
logging.warning(f" *** Found multiple multiple main_prompts for image {image}, but only one will be applied: {config.main_prompts}")
tags = []
tag_weights = []
for tag in sorted(config.tags):
tags.append(tag.value)
tag_weights.append(tag.weight)
use_weights = len(set(tag_weights)) > 1
caption = ImageCaption(
main_prompt=next(iter(sorted(config.main_prompts))),
2023-03-12 17:36:59 -06:00
rating=config.rating or 1.0,
tags=tags,
tag_weights=tag_weights,
max_target_length=config.max_caption_length,
use_weights=use_weights)
item = ImageTrainItem(
image=None,
caption=caption,
aspects=aspects,
pathname=os.path.abspath(image),
flip_p=config.flip_p or 0.0,
multiplier=config.multiply or 1.0,
cond_dropout=config.cond_dropout
)
items.append(item)
return list(sorted(items, key=lambda ti: ti.pathname))