import os import logging import yaml import json from functools import total_ordering from attrs import define, field, Factory from data.image_train_item import ImageCaption, ImageTrainItem from utils.fs_helpers import * from typing import TypeVar, Iterable def overlay(overlay, base): return overlay if overlay is not None else base def safe_set(val): if isinstance(val, str): return {val} if val else {} if isinstance(val, Iterable): return {i for i in val if i is not None} return val or {} @define(frozen=True) @total_ordering class Tag: value: str weight: float = field(default=1.0, converter=lambda x: x if x is not None else 1.0) @classmethod def parse(cls, data): if isinstance(data, str): return Tag(data) if isinstance(data, dict): value = data.get("tag") weight = data.get("weight") if value: return Tag(value, weight) return None def __lt__(self, other): return self.weight < other.weight and self.value < other.value @define class ImageConfig: # Captions main_prompts: set[str] = field(factory=set, converter=safe_set) rating: float = None max_caption_length: int = None tags: set[Tag] = field(factory=set, converter=safe_set) # Options multiply: float = None cond_dropout: float = None flip_p: float = None def merge(self, other): if other is None: return self return ImageConfig( main_prompts=self.main_prompts.union(other.main_prompts), rating=overlay(other.rating, self.rating), max_caption_length=overlay(other.max_caption_length, self.max_caption_length), tags=self.tags.union(other.tags), multiply=overlay(other.multiply, self.multiply), cond_dropout=overlay(other.cond_dropout, self.cond_dropout), flip_p=overlay(other.flip_p, self.flip_p), ) @classmethod def from_dict(cls, data: dict): # Parse standard yaml tag file (with options) parsed_cfg = ImageConfig( main_prompts=safe_set(data.get("main_prompt")), rating=data.get("rating"), max_caption_length=data.get("max_caption_length"), tags=safe_set(map(Tag.parse, data.get("tags", []))), multiply=data.get("multiply"), cond_dropout=data.get("cond_dropout"), flip_p=data.get("flip_p"), ) # Alternatively parse from dedicated `caption` attribute if cap_attr := data.get('caption'): parsed_cfg = parsed_cfg.merge(ImageConfig.parse(cap_attr)) return parsed_cfg @classmethod def fold(cls, configs): acc = ImageConfig() for cfg in configs: acc = acc.merge(cfg) return acc def ensure_caption(self): return self @classmethod def from_caption_text(cls, text: str): if not text: return ImageConfig() if os.path.isfile(text): return ImageConfig.from_file(text) split_caption = list(map(str.strip, text.split(","))) return ImageConfig( main_prompts=split_caption[0], tags=map(Tag.parse, split_caption[1:]) ) @classmethod def from_file(cls, file: str): match ext(file): case '.jpg' | '.jpeg' | '.png' | '.bmp' | '.webp' | '.jfif': return ImageConfig(image=file) case ".json": return ImageConfig.from_dict(json.load(read_text(file))) case ".yaml" | ".yml": return ImageConfig.from_dict(yaml.safe_load(read_text(file))) case ".txt" | ".caption": return ImageConfig.from_caption_text(read_text(file)) case _: return logging.warning(" *** Unrecognized config extension {ext}") @classmethod def parse(cls, input): if isinstance(input, str): if os.path.isfile(input): return ImageConfig.from_file(input) else: return ImageConfig.from_caption_text(input) elif isinstance(input, dict): return ImageConfig.from_dict(input) @define() class Dataset: image_configs: dict[str, ImageConfig] def __global_cfg(files): cfgs = [] for file in files: match os.path.basename(file): case 'global.yaml' | 'global.yml': cfgs.append(ImageConfig.from_file(file)) return ImageConfig.fold(cfgs) def __local_cfg(files): cfgs = [] for file in files: match os.path.basename(file): case 'multiply.txt': cfgs.append(ImageConfig(multiply=read_float(file))) case 'cond_dropout.txt': cfgs.append(ImageConfig(cond_dropout=read_float(file))) case 'flip_p.txt': cfgs.append(ImageConfig(flip_p=read_float(file))) case 'local.yaml' | 'local.yml': cfgs.append(ImageConfig.from_file(file)) return ImageConfig.fold(cfgs) def __sidecar_cfg(imagepath, files): cfgs = [] for file in files: if same_barename(imagepath, file): match ext(file): case '.txt' | '.caption' | '.yml' | '.yaml': cfgs.append(ImageConfig.from_file(file)) return ImageConfig.fold(cfgs) # Use file name for caption only as a last resort @classmethod def __ensure_caption(cls, cfg: ImageConfig, file: str): if cfg.main_prompts or cfg.tags: return cfg cap_cfg = ImageConfig.from_caption_text(barename(file).split("_")[0]) return cfg.merge(cap_cfg) @classmethod def from_path(cls, data_root): # Create a visitor that maintains global config stack # and accumulates image configs as it traverses dataset image_configs = {} def process_dir(files, parent_globals): global_cfg = parent_globals.merge(Dataset.__global_cfg(files)) local_cfg = Dataset.__local_cfg(files) for img in filter(is_image, files): img_cfg = Dataset.__sidecar_cfg(img, files) resolved_cfg = ImageConfig.fold([global_cfg, local_cfg, img_cfg]) image_configs[img] = Dataset.__ensure_caption(resolved_cfg, img) return global_cfg walk_and_visit(data_root, process_dir, ImageConfig()) return Dataset(image_configs) @classmethod def from_json(cls, json_path): """ Import a dataset definition from a JSON file """ image_configs = {} with open(json_path, encoding='utf-8', mode='r') as stream: for data in json.load(stream): img = data.get("image") cfg = Dataset.__ensure_caption(ImageConfig.parse(data), img) if not img: logging.warning(f" *** Error parsing json image entry in {json_path}: {data}") continue image_configs[img] = cfg return Dataset(image_configs) def image_train_items(self, aspects): items = [] for image in self.image_configs: config = self.image_configs[image] if len(config.main_prompts) > 1: logging.warning(f" *** Found multiple multiple main_prompts for image {image}, but only one will be applied: {config.main_prompts}") tags = [] tag_weights = [] for tag in sorted(config.tags): tags.append(tag.value) tag_weights.append(tag.weight) use_weights = len(set(tag_weights)) > 1 caption = ImageCaption( main_prompt=next(iter(sorted(config.main_prompts))), rating=config.rating or 1.0, tags=tags, tag_weights=tag_weights, max_target_length=config.max_caption_length, use_weights=use_weights) item = ImageTrainItem( image=None, caption=caption, aspects=aspects, pathname=os.path.abspath(image), flip_p=config.flip_p or 0.0, multiplier=config.multiply or 1.0, cond_dropout=config.cond_dropout ) items.append(item) return list(sorted(items, key=lambda ti: ti.pathname))