import json
import logging
import os

import yaml

from attrs import define, field
from data.image_train_item import ImageCaption, ImageTrainItem
from utils.fs_helpers import *
from typing import Iterable

from tqdm import tqdm

DEFAULT_MAX_CAPTION_LENGTH = 2048

def overlay(overlay, base):
    # Prefer the overriding value; fall back to the base value when the override is unset.
    return overlay if overlay is not None else base

def safe_set(val):
    # Normalize a string or iterable into a de-duplicated, insertion-ordered dict of keys
    # (used as an ordered set).
    if isinstance(val, str):
        return dict.fromkeys([val]) if val else dict()
    if isinstance(val, Iterable):
        return dict.fromkeys((i for i in val if i is not None))
    return val or dict()

@define(frozen=True)
class Tag:
    value: str
    weight: float = field(default=1.0, converter=lambda x: x if x is not None else 1.0)

    @classmethod
    def parse(cls, data):
        if isinstance(data, str):
            return Tag(data)
        if isinstance(data, dict):
            value = str(data.get("tag"))
            weight = data.get("weight")
            if value:
                return Tag(value, weight)
        return None

@define
class ImageConfig:
    # Captions
    main_prompts: dict[str, None] = field(factory=dict, converter=safe_set)
    rating: float = None
    max_caption_length: int = None
    tags: dict[Tag, None] = field(factory=dict, converter=safe_set)
    batch_id: str = None
    # Options
    multiply: float = None
    cond_dropout: float = None
    flip_p: float = None
    shuffle_tags: bool = False
    loss_scale: float = None

    def merge(self, other):
        # Overlay `other` onto this config; fields set on `other` override fields on `self`.
        if other is None:
            return self

        return ImageConfig(
            main_prompts=other.main_prompts | self.main_prompts,
            rating=overlay(other.rating, self.rating),
            max_caption_length=overlay(other.max_caption_length, self.max_caption_length),
            tags=other.tags | self.tags,
            multiply=overlay(other.multiply, self.multiply),
            cond_dropout=overlay(other.cond_dropout, self.cond_dropout),
            flip_p=overlay(other.flip_p, self.flip_p),
            shuffle_tags=overlay(other.shuffle_tags, self.shuffle_tags),
            batch_id=overlay(other.batch_id, self.batch_id),
            loss_scale=overlay(other.loss_scale, self.loss_scale)
        )

    @classmethod
    def from_dict(cls, data: dict):
        # Parse standard yaml tag file (with options)
        parsed_cfg = ImageConfig(
            main_prompts=safe_set(data.get("main_prompt")),
            rating=data.get("rating"),
            max_caption_length=data.get("max_caption_length"),
            tags=safe_set(map(Tag.parse, data.get("tags", []))),
            multiply=data.get("multiply"),
            cond_dropout=data.get("cond_dropout"),
            flip_p=data.get("flip_p"),
            shuffle_tags=data.get("shuffle_tags"),
            batch_id=data.get("batch_id"),
            loss_scale=data.get("loss_scale")
        )

        # Alternatively parse from dedicated `caption` attribute
        if cap_attr := data.get('caption'):
            parsed_cfg = parsed_cfg.merge(ImageConfig.parse(cap_attr))

        return parsed_cfg

    @classmethod
    def fold(cls, configs):
        # Merge a sequence of configs; later entries override earlier ones.
        acc = ImageConfig()
        for cfg in configs:
            acc = acc.merge(cfg)

        acc.shuffle_tags = any(cfg.shuffle_tags for cfg in configs)
        #print(f"accum shuffle:{acc.shuffle_tags}")
        return acc

    def ensure_caption(self):
        return self

    @classmethod
    def from_caption_text(cls, text: str):
        if not text:
            return ImageConfig()
        if os.path.isfile(text):
            return ImageConfig.from_file(text)

        split_caption = list(map(str.strip, text.split(",")))
        return ImageConfig(
            main_prompts=split_caption[0],
            tags=map(Tag.parse, split_caption[1:])
        )

    @classmethod
    def from_file(cls, file: str):
        match ext(file):
            case '.jpg' | '.jpeg' | '.png' | '.bmp' | '.webp' | '.jfif':
                return ImageConfig(image=file)
            case ".json":
                return ImageConfig.from_dict(json.loads(read_text(file)))
            case ".yaml" | ".yml":
                return ImageConfig.from_dict(yaml.safe_load(read_text(file)))
            case ".txt" | ".caption":
                return ImageConfig.from_caption_text(read_text(file))
            case _:
                return logging.warning(f" *** Unrecognized config extension {ext(file)}")
    @classmethod
    def parse(cls, input):
        if isinstance(input, str):
            if os.path.isfile(input):
                return ImageConfig.from_file(input)
            else:
                return ImageConfig.from_caption_text(input)
        elif isinstance(input, dict):
            return ImageConfig.from_dict(input)


@define()
class Dataset:
    image_configs: dict[str, ImageConfig]

    def __global_cfg(fileset):
        # Fold global.yaml/global.yml, which apply to this folder and its subfolders.
        cfgs = []

        for cfgfile in ['global.yaml', 'global.yml']:
            if cfgfile in fileset:
                cfgs.append(ImageConfig.from_file(fileset[cfgfile]))
        return ImageConfig.fold(cfgs)

    def __local_cfg(fileset):
        # Fold per-folder option files, which apply only to images in this folder.
        cfgs = []

        if 'multiply.txt' in fileset:
            cfgs.append(ImageConfig(multiply=read_float(fileset['multiply.txt'])))
        if 'cond_dropout.txt' in fileset:
            cfgs.append(ImageConfig(cond_dropout=read_float(fileset['cond_dropout.txt'])))
        if 'flip_p.txt' in fileset:
            cfgs.append(ImageConfig(flip_p=read_float(fileset['flip_p.txt'])))
        if 'local.yaml' in fileset:
            cfgs.append(ImageConfig.from_file(fileset['local.yaml']))
        if 'local.yml' in fileset:
            cfgs.append(ImageConfig.from_file(fileset['local.yml']))
        if 'batch_id.txt' in fileset:
            cfgs.append(ImageConfig(batch_id=read_text(fileset['batch_id.txt'])))
        if 'loss_scale.txt' in fileset:
            cfgs.append(ImageConfig(loss_scale=read_float(fileset['loss_scale.txt'])))

        result = ImageConfig.fold(cfgs)
        if 'shuffle_tags.txt' in fileset:
            result.shuffle_tags = True

        return result

    def __sidecar_cfg(imagepath, fileset):
        # Fold caption/config files that share the image's base name.
        cfgs = []
        for cfgext in ['.txt', '.caption', '.yml', '.yaml']:
            cfgfile = barename(imagepath) + cfgext
            if cfgfile in fileset:
                cfgs.append(ImageConfig.from_file(fileset[cfgfile]))
        return ImageConfig.fold(cfgs)

    # Use file name for caption only as a last resort
    @classmethod
    def __ensure_caption(cls, cfg: ImageConfig, file: str):
        if cfg.main_prompts:
            return cfg
        cap_cfg = ImageConfig.from_caption_text(barename(file).split("_")[0])
        return cfg.merge(cap_cfg)

    @classmethod
    def from_path(cls, data_root):
        # Create a visitor that maintains global config stack
        # and accumulates image configs as it traverses dataset
        image_configs = {}

        def process_dir(files, parent_globals):
            fileset = {os.path.basename(f): f for f in files}
            global_cfg = parent_globals.merge(Dataset.__global_cfg(fileset))
            local_cfg = Dataset.__local_cfg(fileset)
            for img in filter(is_image, files):
                img_cfg = Dataset.__sidecar_cfg(img, fileset)
                resolved_cfg = ImageConfig.fold([global_cfg, local_cfg, img_cfg])
                image_configs[img] = Dataset.__ensure_caption(resolved_cfg, img)
            return global_cfg

        walk_and_visit(data_root, process_dir, ImageConfig())
        return Dataset(image_configs)

    @classmethod
    def from_json(cls, json_path):
        """
        Import a dataset definition from a JSON file
        """
        image_configs = {}
        with open(json_path, encoding='utf-8', mode='r') as stream:
            for data in json.load(stream):
                img = data.get("image")
                if not img:
                    logging.warning(f" *** Error parsing json image entry in {json_path}: {data}")
                    continue
                image_configs[img] = Dataset.__ensure_caption(ImageConfig.parse(data), img)
        return Dataset(image_configs)

    def image_train_items(self, aspects):
        items = []
        for image in tqdm(self.image_configs, desc="preloading", dynamic_ncols=True):
            config = self.image_configs[image]
            #print(f" ********* shuffle: {config.shuffle_tags}")
            if len(config.main_prompts) > 1:
                logging.warning(f" *** Found multiple main_prompts for image {image}, but only one will be applied: {config.main_prompts}")

            if len(config.main_prompts) < 1:
                logging.warning(f" *** No main_prompts for image {image}")

            tags = []
            tag_weights = []
            for tag in sorted(config.tags, key=lambda x: x.weight or 1.0, reverse=True):
                tags.append(tag.value)
                tag_weights.append(tag.weight)
            use_weights = len(set(tag_weights)) > 1

            try:
                caption = ImageCaption(
                    main_prompt=next(iter(config.main_prompts)),
                    rating=config.rating or 1.0,
                    tags=tags,
                    tag_weights=tag_weights,
                    max_target_length=config.max_caption_length or DEFAULT_MAX_CAPTION_LENGTH,
                    use_weights=use_weights)

                item = ImageTrainItem(
                    image=None,
                    caption=caption,
                    aspects=aspects,
                    pathname=os.path.abspath(image),
                    flip_p=config.flip_p or 0.0,
                    multiplier=config.multiply or 1.0,
                    cond_dropout=config.cond_dropout,
                    shuffle_tags=config.shuffle_tags,
                    batch_id=config.batch_id,
                    loss_scale=config.loss_scale
                )
                items.append(item)
            except Exception as e:
                logging.error(f" *** Error preloading image or caption for: {image}, error: {e}")
                raise e
        return items
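

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the module's API): resolve a folder tree
# into per-image configs, then preload ImageTrainItems. The dataset path below
# is a hypothetical placeholder, and `example_aspects` merely stands in for
# whatever list of [width, height] aspect buckets the trainer normally passes
# to image_train_items().
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    example_aspects = [[512, 512]]  # placeholder bucket list, assumption only
    dataset = Dataset.from_path("/path/to/your/dataset")  # hypothetical path
    train_items = dataset.image_train_items(example_aspects)
    logging.info(f"preloaded {len(train_items)} train items")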