Add support for enhanced dataset configuration
Add support for:
* flip_p.txt
* cond_dropout.txt
* local.yaml consolidated config (including default captions)
* global.yaml consolidated config, which applies recursively to subfolders
* flip_p and cond_dropout config per image
* manifest.json with full image-level configuration
parent 2e3d044ba3
commit 0716c40ab6
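For orientation, a rough usage sketch of the configuration files listed above. The directory layout, file names, and the empty aspects list below are illustrative assumptions, not part of this commit; the API calls are the ones introduced in data/dataset.py.

# Hypothetical layout (illustrative only):
#   data_root/global.yaml             defaults applied recursively to every subfolder
#   data_root/people/local.yaml       per-folder defaults, may include a default caption
#   data_root/people/flip_p.txt       per-folder flip probability
#   data_root/people/cond_dropout.txt per-folder conditional dropout
#   data_root/people/alice_1.jpg      an image
#   data_root/people/alice_1.yaml     per-image caption, flip_p, cond_dropout, multiply
from data.dataset import Dataset

# Folder-based configuration, resolved recursively from data_root
folder_items = Dataset.from_path("data_root").image_train_items(aspects=[])
# Single manifest.json with full image-level configuration
manifest_items = Dataset.from_json("manifest.json").image_train_items(aspects=[])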
@@ -0,0 +1,261 @@
import os
import logging
import yaml
import json

from functools import total_ordering
from attrs import define, Factory
from data.image_train_item import ImageCaption, ImageTrainItem
from utils.fs_helpers import *


@define(frozen=True)
@total_ordering
class Tag:
    value: str
    weight: float = None

    def __lt__(self, other):
        return self.value < other.value


@define(frozen=True)
@total_ordering
class Caption:
    main_prompt: str = None
    rating: float = None
    max_caption_length: int = None
    tags: frozenset[Tag] = Factory(frozenset)

    @classmethod
    def from_dict(cls, data: dict):
        main_prompt = data.get("main_prompt")
        rating = data.get("rating")
        max_caption_length = data.get("max_caption_length")

        tags = frozenset([Tag(value=t.get("tag"), weight=t.get("weight"))
                          for t in data.get("tags", [])
                          if "tag" in t and len(t.get("tag")) > 0])

        if not main_prompt and not tags:
            return None

        return Caption(main_prompt=main_prompt, rating=rating, max_caption_length=max_caption_length, tags=tags)

    @classmethod
    def from_text(cls, text: str):
        if text is None:
            return Caption(main_prompt="")
        split_caption = list(map(str.strip, text.split(",")))
        main_prompt = split_caption[0]
        tags = frozenset(Tag(value=t) for t in split_caption[1:])
        return Caption(main_prompt=main_prompt, tags=tags)

    @classmethod
    def load(cls, input):
        if isinstance(input, str):
            if os.path.isfile(input):
                return Caption.from_text(read_text(input))
            else:
                return Caption.from_text(input)
        elif isinstance(input, dict):
            return Caption.from_dict(input)

    def __lt__(self, other):
        self_str = ",".join([self.main_prompt] + sorted(repr(t) for t in self.tags))
        other_str = ",".join([other.main_prompt] + sorted(repr(t) for t in other.tags))
        return self_str < other_str


@define(frozen=True)
class ImageConfig:
    image: str = None
    captions: frozenset[Caption] = Factory(frozenset)
    multiply: float = None
    cond_dropout: float = None
    flip_p: float = None

    @classmethod
    def fold(cls, configs):
        acc = ImageConfig()
        [acc := acc.merge(cfg) for cfg in configs]
        return acc

    def merge(self, other):
        if other is None:
            return self

        if other.image and self.image:
            logging.warning(f"Found two images with different extensions and the same barename: {self.image} and {other.image}")

        return ImageConfig(
            image = other.image or self.image,
            captions = other.captions.union(self.captions),
            multiply = other.multiply if other.multiply is not None else self.multiply,
            cond_dropout = other.cond_dropout if other.cond_dropout is not None else self.cond_dropout,
            flip_p = other.flip_p if other.flip_p is not None else self.flip_p
            )

    def ensure_caption(self):
        if not self.captions:
            filename_caption = Caption.from_text(barename(self.image).split("_")[0])
            return self.merge(ImageConfig(captions=frozenset([filename_caption])))
        return self

    @classmethod
    def from_dict(cls, data: dict):
        captions = set()
        if "captions" in data:
            captions.update(Caption.load(cap) for cap in data.get("captions"))

        if "caption" in data:
            captions.add(Caption.load(data.get("caption")))

        if not captions:
            # For backward compatibility with existing caption yaml
            caption = Caption.load(data)
            if caption:
                captions.add(caption)

        return ImageConfig(
            image = data.get("image"),
            captions=frozenset(captions),
            multiply=data.get("multiply"),
            cond_dropout=data.get("cond_dropout"),
            flip_p=data.get("flip_p"))

    @classmethod
    def from_text(cls, text: str):
        try:
            if os.path.isfile(text):
                return ImageConfig.from_file(text)
            return ImageConfig(captions=frozenset({Caption.from_text(text)}))
        except Exception as e:
            logging.warning(f" *** Error parsing config from text {text}: \n{e}")

    @classmethod
    def from_file(cls, file: str):
        try:
            match ext(file):
                case '.jpg' | '.jpeg' | '.png' | '.bmp' | '.webp' | '.jfif':
                    return ImageConfig(image=file)
                case ".json":
                    return ImageConfig.from_dict(json.loads(read_text(file)))
                case ".yaml" | ".yml":
                    return ImageConfig.from_dict(yaml.safe_load(read_text(file)))
                case ".txt" | ".caption":
                    return ImageConfig.from_text(read_text(file))
                case _:
                    return logging.warning(f" *** Unrecognized config extension {ext(file)}")
        except Exception as e:
            logging.warning(f" *** Error parsing config from {file}: {e}")

    @classmethod
    def load(cls, input):
        if isinstance(input, str):
            return ImageConfig.from_text(input)
        elif isinstance(input, dict):
            return ImageConfig.from_dict(input)


@define()
class Dataset:
    image_configs: set[ImageConfig]

    def __global_cfg(files):
        cfgs = []
        for file in files:
            match os.path.basename(file):
                case 'global.yaml' | 'global.yml':
                    cfgs.append(ImageConfig.from_file(file))
        return ImageConfig.fold(cfgs)

    def __local_cfg(files):
        cfgs = []
        for file in files:
            match os.path.basename(file):
                case 'multiply.txt':
                    cfgs.append(ImageConfig(multiply=read_float(file)))
                case 'cond_dropout.txt':
                    cfgs.append(ImageConfig(cond_dropout=read_float(file)))
                case 'flip_p.txt':
                    cfgs.append(ImageConfig(flip_p=read_float(file)))
                case 'local.yaml' | 'local.yml':
                    cfgs.append(ImageConfig.from_file(file))
        return ImageConfig.fold(cfgs)

    def __image_cfg(imagepath, files):
        cfgs = [ImageConfig.from_file(imagepath)]
        for file in files:
            if same_barename(imagepath, file):
                match ext(file):
                    case '.txt' | '.caption' | '.yml' | '.yaml':
                        cfgs.append(ImageConfig.from_file(file))
        return ImageConfig.fold(cfgs)

    @classmethod
    def from_path(cls, data_root):
        # Create a visitor that maintains global config stack
        # and accumulates image configs as it traverses dataset
        image_configs = set()
        def process_dir(files, parent_globals):
            globals = parent_globals.merge(Dataset.__global_cfg(files))
            locals = Dataset.__local_cfg(files)
            for img in filter(is_image, files):
                img_cfg = Dataset.__image_cfg(img, files)
                collapsed_cfg = ImageConfig.fold([globals, locals, img_cfg])
                resolved_cfg = collapsed_cfg.ensure_caption()
                image_configs.add(resolved_cfg)
            return globals

        walk_and_visit(data_root, process_dir, ImageConfig())
        return Dataset(image_configs)

    @classmethod
    def from_json(cls, json_path):
        """
        Import a dataset definition from a JSON file
        """
        configs = set()
        with open(json_path, encoding='utf-8', mode='r') as stream:
            for data in json.load(stream):
                cfg = ImageConfig.load(data).ensure_caption()
                if not cfg or not cfg.image:
                    logging.warning(f" *** Error parsing json image entry in {json_path}: {data}")
                    continue
                configs.add(cfg)
        return Dataset(configs)

    def image_train_items(self, aspects):
        items = []
        for config in self.image_configs:
            caption = next(iter(sorted(config.captions)))
            if len(config.captions) > 1:
                logging.warning(f" *** Found multiple captions for image {config.image}, but only one will be applied: {config.captions}")

            use_weights = len(set(t.weight or 1.0 for t in caption.tags)) > 1
            tags = []
            tag_weights = []
            for tag in sorted(caption.tags):
                tags.append(tag.value)
                tag_weights.append(tag.weight or 1.0)
            use_weights = len(set(tag_weights)) > 1

            caption = ImageCaption(
                main_prompt=caption.main_prompt,
                rating=caption.rating,
                tags=tags,
                tag_weights=tag_weights,
                max_target_length=caption.max_caption_length,
                use_weights=use_weights)

            item = ImageTrainItem(
                image=None,
                caption=caption,
                aspects=aspects,
                pathname=os.path.abspath(config.image),
                flip_p=config.flip_p or 0.0,
                multiplier=config.multiply or 1.0,
                cond_dropout=config.cond_dropout
            )
            items.append(item)
        return list(sorted(items, key=lambda ti: ti.pathname))

@@ -104,7 +104,7 @@ class EveryDreamBatch(Dataset):
         example["image"] = image_transforms(train_item["image"])
 
-        if random.random() > self.conditional_dropout:
+        if random.random() > (train_item.cond_dropout or self.conditional_dropout):
             example["tokens"] = self.tokenizer(example["caption"],
                                                truncation=True,
                                                padding="max_length",

@@ -113,136 +113,6 @@ class ImageCaption:
             random.Random(seed).shuffle(tags)
         return ", ".join(tags)
 
-    @staticmethod
-    def parse(string: str) -> 'ImageCaption':
-        """
-        Parses a string to get the caption.
-
-        :param string: String to parse.
-        :return: `ImageCaption` object.
-        """
-        split_caption = list(map(str.strip, string.split(",")))
-        main_prompt = split_caption[0]
-        tags = split_caption[1:]
-        tag_weights = [1.0] * len(tags)
-
-        return ImageCaption(main_prompt, 1.0, tags, tag_weights, DEFAULT_MAX_CAPTION_LENGTH, False)
-
-    @staticmethod
-    def from_file_name(file_path: str) -> 'ImageCaption':
-        """
-        Parses the file name to get the caption.
-
-        :param file_path: Path to the image file.
-        :return: `ImageCaption` object.
-        """
-        (file_name, _) = os.path.splitext(os.path.basename(file_path))
-        caption = file_name.split("_")[0]
-        return ImageCaption.parse(caption)
-
-    @staticmethod
-    def from_text_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
-        """
-        Parses a text file to get the caption. Returns the default caption if
-        the file does not exist or is invalid.
-
-        :param file_path: Path to the text file.
-        :param default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid.
-        :return: `ImageCaption` object or `None`.
-        """
-        try:
-            with open(file_path, encoding='utf-8', mode='r') as caption_file:
-                caption_text = caption_file.read()
-                return ImageCaption.parse(caption_text)
-        except:
-            logging.error(f" *** Error reading {file_path} to get caption")
-            return default_caption
-
-    @staticmethod
-    def from_yaml_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
-        """
-        Parses a yaml file to get the caption. Returns the default caption if
-        the file does not exist or is invalid.
-
-        :param file_path: path to the yaml file
-        :param default_caption: caption to return if the file does not exist or is invalid
-        :return: `ImageCaption` object or `None`.
-        """
-        try:
-            with open(file_path, "r") as stream:
-                file_content = yaml.safe_load(stream)
-                main_prompt = file_content.get("main_prompt", "")
-                rating = file_content.get("rating", 1.0)
-                unparsed_tags = file_content.get("tags", [])
-
-                max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH)
-
-                tags = []
-                tag_weights = []
-                last_weight = None
-                weights_differ = False
-                for unparsed_tag in unparsed_tags:
-                    tag = unparsed_tag.get("tag", "").strip()
-                    if len(tag) == 0:
-                        continue
-
-                    tags.append(tag)
-                    tag_weight = unparsed_tag.get("weight", 1.0)
-                    tag_weights.append(tag_weight)
-
-                    if last_weight is not None and weights_differ is False:
-                        weights_differ = last_weight != tag_weight
-
-                    last_weight = tag_weight
-
-                return ImageCaption(main_prompt, rating, tags, tag_weights, max_caption_length, weights_differ)
-        except:
-            logging.error(f" *** Error reading {file_path} to get caption")
-            return default_caption
-
-    @staticmethod
-    def from_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
-        """
-        Try to resolve a caption from a file path or return `default_caption`.
-
-        :string: The path to the file to parse.
-        :default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid.
-        :return: `ImageCaption` object or `None`.
-        """
-        if os.path.exists(file_path):
-            (file_path_without_ext, ext) = os.path.splitext(file_path)
-            match ext:
-                case ".yaml" | ".yml":
-                    return ImageCaption.from_yaml_file(file_path, default_caption)
-
-                case ".txt" | ".caption":
-                    return ImageCaption.from_text_file(file_path, default_caption)
-
-                case '.jpg' | '.jpeg' | '.png' | '.bmp' | '.webp' | '.jfif':
-                    for ext in [".yaml", ".yml", ".txt", ".caption"]:
-                        file_path = file_path_without_ext + ext
-                        image_caption = ImageCaption.from_file(file_path)
-                        if image_caption is not None:
-                            return image_caption
-                    return ImageCaption.from_file_name(file_path)
-
-                case _:
-                    return default_caption
-        else:
-            return default_caption
-
-    @staticmethod
-    def resolve(string: str) -> 'ImageCaption':
-        """
-        Try to resolve a caption from a string. If the string is a file path,
-        the caption will be read from the file, otherwise the string will be
-        parsed as a caption.
-
-        :string: The string to resolve.
-        :return: `ImageCaption` object.
-        """
-        return ImageCaption.from_file(string, None) or ImageCaption.parse(string)
-
 
 class ImageTrainItem:
     """
@@ -253,7 +123,7 @@ class ImageTrainItem:
     flip_p: probability of flipping image (0.0 to 1.0)
     rating: the relative rating of the images. The rating is measured in comparison to the other images.
     """
-    def __init__(self, image: PIL.Image, caption: ImageCaption, aspects: list[float], pathname: str, flip_p=0.0, multiplier: float=1.0):
+    def __init__(self, image: PIL.Image, caption: ImageCaption, aspects: list[float], pathname: str, flip_p=0.0, multiplier: float=1.0, cond_dropout=None):
         self.caption = caption
         self.aspects = aspects
         self.pathname = pathname
@@ -261,6 +131,7 @@ class ImageTrainItem:
         self.cropped_img = None
         self.runt_size = 0
         self.multiplier = multiplier
+        self.cond_dropout = cond_dropout
 
         self.image_size = None
         if image is None or len(image) == 0:

data/resolver.py
@@ -4,6 +4,7 @@ import os
 import typing
 import zipfile
 import argparse
+from data.dataset import Dataset
 
 import tqdm
 from colorama import Fore, Style
@@ -27,16 +28,6 @@ class DataResolver:
         """
         raise NotImplementedError()
 
-    def image_train_item(self, image_path: str, caption: ImageCaption, multiplier: float=1) -> ImageTrainItem:
-        return ImageTrainItem(
-            image=None,
-            caption=caption,
-            aspects=self.aspects,
-            pathname=image_path,
-            flip_p=self.flip_p,
-            multiplier=multiplier
-        )
-
 class JSONResolver(DataResolver):
     def image_train_items(self, json_path: str) -> list[ImageTrainItem]:
         """
@@ -45,62 +36,8 @@ class JSONResolver(DataResolver):
 
         :param json_path: The path to the JSON file.
         """
-        items = []
-        with open(json_path, encoding='utf-8', mode='r') as f:
-            json_data = json.load(f)
-
-        for data in tqdm.tqdm(json_data):
-            caption = JSONResolver.image_caption(data)
-            if caption:
-                image_value = JSONResolver.get_image_value(data)
-                item = self.image_train_item(image_value, caption)
-                if item:
-                    items.append(item)
-
-        return items
-
-    @staticmethod
-    def get_image_value(json_data: dict) -> typing.Optional[str]:
-        """
-        Get the image from the json data if possible.
-
-        :param json_data: The json data, a dict.
-        :return: The image, or None if not found.
-        """
-        image_value = json_data.get("image", None)
-        if isinstance(image_value, str):
-            image_value = image_value.strip()
-            if os.path.exists(image_value):
-                return image_value
-
-    @staticmethod
-    def get_caption_value(json_data: dict) -> typing.Optional[str]:
-        """
-        Get the caption from the json data if possible.
-
-        :param json_data: The json data, a dict.
-        :return: The caption, or None if not found.
-        """
-        caption_value = json_data.get("caption", None)
-        if isinstance(caption_value, str):
-            return caption_value.strip()
-
-    @staticmethod
-    def image_caption(json_data: dict) -> typing.Optional[ImageCaption]:
-        """
-        Get the caption from the json data if possible.
-
-        :param json_data: The json data, a dict.
-        :return: The `ImageCaption`, or None if not found.
-        """
-        image_value = JSONResolver.get_image_value(json_data)
-        caption_value = JSONResolver.get_caption_value(json_data)
-        if image_value:
-            if caption_value:
-                return ImageCaption.resolve(caption_value)
-            return ImageCaption.from_file(image_value)
-
+        return Dataset.from_json(json_path).image_train_items(self.aspects)
 
 class DirectoryResolver(DataResolver):
     def image_train_items(self, data_root: str) -> list[ImageTrainItem]:
         """
@@ -111,32 +48,7 @@ class DirectoryResolver(DataResolver):
         :param data_root: The root directory to recurse through
         """
         DirectoryResolver.unzip_all(data_root)
-        image_paths = list(DirectoryResolver.recurse_data_root(data_root))
-        items = []
-        multipliers = {}
-
-        for pathname in tqdm.tqdm(image_paths):
-            current_dir = os.path.dirname(pathname)
-
-            if current_dir not in multipliers:
-                multiply_txt_path = os.path.join(current_dir, "multiply.txt")
-                if os.path.exists(multiply_txt_path):
-                    try:
-                        with open(multiply_txt_path, 'r') as f:
-                            val = float(f.read().strip())
-                            multipliers[current_dir] = val
-                            logging.info(f" - multiply.txt in '{current_dir}' set to {val}")
-                    except Exception as e:
-                        logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
-                        multipliers[current_dir] = 1.0
-                else:
-                    multipliers[current_dir] = 1.0
-
-            caption = ImageCaption.resolve(pathname)
-            item = self.image_train_item(pathname, caption, multiplier=multipliers[current_dir])
-            items.append(item)
-
-        return items
+        return Dataset.from_path(data_root).image_train_items(self.aspects)
 
     @staticmethod
     def unzip_all(path):
@@ -150,21 +62,6 @@ class DirectoryResolver(DataResolver):
         except Exception as e:
             logging.error(f"Error unzipping files {e}")
 
-    @staticmethod
-    def recurse_data_root(recurse_root):
-        for f in os.listdir(recurse_root):
-            current = os.path.join(recurse_root, f)
-
-            if os.path.isfile(current):
-                ext = os.path.splitext(f)[1].lower()
-                if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']:
-                    yield current
-
-        for d in os.listdir(recurse_root):
-            current = os.path.join(recurse_root, d)
-            if os.path.isdir(current):
-                yield from DirectoryResolver.recurse_data_root(current)
-
 def strategy(data_root: str) -> typing.Type[DataResolver]:
     """
     Determine the strategy to use for resolving the data.

@@ -10,6 +10,7 @@ ninja
 omegaconf==2.2.3
 piexif==1.1.3
 protobuf==3.20.3
+pyfakefs
 pynvml==11.5.0
 pyre-extensions==0.0.30
 pytorch-lightning==1.9.2

@@ -58,13 +58,13 @@ class TestResolve(unittest.TestCase):
 
     def test_directory_resolve_with_str(self):
         items = resolver.resolve(DATA_PATH, ARGS)
-        image_paths = [item.pathname for item in items]
+        image_paths = set(item.pathname for item in items)
         image_captions = [item.caption for item in items]
-        captions = [caption.get_caption() for caption in image_captions]
+        captions = set(caption.get_caption() for caption in image_captions)
 
         self.assertEqual(len(items), 3)
-        self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
+        self.assertEqual(image_paths, {IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH})
-        self.assertEqual(captions, ['caption for test1', 'test2', 'test3'])
+        self.assertEqual(captions, {'caption for test1', 'test2', 'test3'})
 
         undersized_images = list(filter(lambda i: i.is_undersized, items))
         self.assertEqual(len(undersized_images), 1)

@@ -0,0 +1,329 @@
import os
from data.dataset import Dataset, ImageConfig, Caption, Tag

from textwrap import dedent
from pyfakefs.fake_filesystem_unittest import TestCase

class TestResolve(TestCase):
    def setUp(self):
        self.setUpPyfakefs()

    def test_simple_image(self):
        self.fs.create_file("image, tag1, tag2.jpg")

        actual = Dataset.from_path(".").image_configs

        expected = {
            ImageConfig(
                image="./image, tag1, tag2.jpg",
                captions=frozenset([
                    Caption(main_prompt="image", tags=frozenset([Tag("tag1"), Tag("tag2")]))
                ]))
        }
        self.assertEqual(expected, actual)

    def test_image_types(self):
        self.fs.create_file("image_1.JPG")
        self.fs.create_file("image_2.jpeg")
        self.fs.create_file("image_3.png")
        self.fs.create_file("image_4.webp")
        self.fs.create_file("image_5.jfif")
        self.fs.create_file("image_6.bmp")

        actual = Dataset.from_path(".").image_configs

        captions = frozenset([Caption(main_prompt="image")])
        expected = {
            ImageConfig(image="./image_1.JPG", captions=captions),
            ImageConfig(image="./image_2.jpeg", captions=captions),
            ImageConfig(image="./image_3.png", captions=captions),
            ImageConfig(image="./image_4.webp", captions=captions),
            ImageConfig(image="./image_5.jfif", captions=captions),
            ImageConfig(image="./image_6.bmp", captions=captions),
        }
        self.assertEqual(expected, actual)

    def test_caption_file(self):
        self.fs.create_file("image_1.jpg")
        self.fs.create_file("image_1.txt", contents="an image, test, from .txt")
        self.fs.create_file("image_2.jpg")
        self.fs.create_file("image_2.caption", contents="an image, test, from .caption")

        actual = Dataset.from_path(".").image_configs

        expected = {
            ImageConfig(
                image="./image_1.jpg",
                captions=frozenset([
                    Caption(main_prompt="an image", tags=frozenset([Tag("test"), Tag("from .txt")]))
                ])),
            ImageConfig(
                image="./image_2.jpg",
                captions=frozenset([
                    Caption(main_prompt="an image", tags=frozenset([Tag("test"), Tag("from .caption")]))
                ]))
        }
        self.assertEqual(expected, actual)

    def test_image_yaml(self):
        self.fs.create_file("image_1.jpg")
        self.fs.create_file("image_1.yaml",
                            contents=dedent("""
                                multiply: 2
                                cond_dropout: 0.05
                                flip_p: 0.5
                                caption: "A simple caption, from .yaml"
                                """))
        self.fs.create_file("image_2.jpg")
        self.fs.create_file("image_2.yml",
                            contents=dedent("""
                                flip_p: 0.0
                                caption:
                                  main_prompt: A complex caption
                                  rating: 1.1
                                  max_caption_length: 1024
                                  tags:
                                    - tag: from .yml
                                    - tag: with weight
                                      weight: 0.5
                                """))

        actual = Dataset.from_path(".").image_configs

        expected = {
            ImageConfig(
                image="./image_1.jpg",
                multiply=2,
                cond_dropout=0.05,
                flip_p=0.5,
                captions=frozenset([
                    Caption(main_prompt="A simple caption", tags=frozenset([Tag("from .yaml")]))
                ])),
            ImageConfig(
                image="./image_2.jpg",
                flip_p=0.0,
                captions=frozenset([
                    Caption(main_prompt="A complex caption", rating=1.1,
                            max_caption_length=1024,
                            tags=frozenset([
                                Tag("from .yml"),
                                Tag("with weight", weight=0.5)
                            ]))
                ]))
        }
        self.assertEqual(expected, actual)

    def test_multi_caption(self):
        self.fs.create_file("image_1.jpg")
        self.fs.create_file("image_1.yaml", contents=dedent("""
            caption: "A simple caption, from .yaml"
            captions:
              - "Another simple caption"
              - main_prompt: A complex caption
            """))
        self.fs.create_file("image_1.txt", contents="A .txt caption")
        self.fs.create_file("image_1.caption", contents="A .caption caption")

        actual = Dataset.from_path(".").image_configs

        expected = {
            ImageConfig(
                image="./image_1.jpg",
                captions=frozenset([
                    Caption(main_prompt="A simple caption", tags=frozenset([Tag("from .yaml")])),
                    Caption(main_prompt="Another simple caption", tags=frozenset()),
                    Caption(main_prompt="A complex caption", tags=frozenset()),
                    Caption(main_prompt="A .txt caption", tags=frozenset()),
                    Caption(main_prompt="A .caption caption", tags=frozenset())
                ])
            ),
        }
        self.assertEqual(expected, actual)

    def test_globals_and_locals(self):
        self.fs.create_file("./people/global.yaml", contents=dedent("""\
            multiply: 1.0
            cond_dropout: 0.0
            flip_p: 0.0
            """))
        self.fs.create_file("./people/alice/local.yaml", contents="multiply: 1.5")
        self.fs.create_file("./people/alice/alice_1.png")
        self.fs.create_file("./people/alice/alice_1.yaml", contents="multiply: 2")
        self.fs.create_file("./people/alice/alice_2.png")

        self.fs.create_file("./people/bob/multiply.txt", contents="3")
        self.fs.create_file("./people/bob/cond_dropout.txt", contents="0.05")
        self.fs.create_file("./people/bob/flip_p.txt", contents="0.05")
        self.fs.create_file("./people/bob/bob.png")

        self.fs.create_file("./people/cleo/cleo.png")
        self.fs.create_file("./people/dan.png")

        self.fs.create_file("./other/dog/local.yaml", contents="caption: spike")
        self.fs.create_file("./other/dog/xyz.png")

        actual = Dataset.from_path(".").image_configs

        expected = {
            ImageConfig(
                image="./people/alice/alice_1.png",
                captions=frozenset([Caption(main_prompt="alice")]),
                multiply=2,
                cond_dropout=0.0,
                flip_p=0.0
                ),
            ImageConfig(
                image="./people/alice/alice_2.png",
                captions=frozenset([Caption(main_prompt="alice")]),
                multiply=1.5,
                cond_dropout=0.0,
                flip_p=0.0
                ),
            ImageConfig(
                image="./people/bob/bob.png",
                captions=frozenset([Caption(main_prompt="bob")]),
                multiply=3,
                cond_dropout=0.05,
                flip_p=0.05
                ),
            ImageConfig(
                image="./people/cleo/cleo.png",
                captions=frozenset([Caption(main_prompt="cleo")]),
                multiply=1.0,
                cond_dropout=0.0,
                flip_p=0.0
                ),
            ImageConfig(
                image="./people/dan.png",
                captions=frozenset([Caption(main_prompt="dan")]),
                multiply=1.0,
                cond_dropout=0.0,
                flip_p=0.0
                ),
            ImageConfig(
                image="./other/dog/xyz.png",
                captions=frozenset([Caption(main_prompt="spike")]),
                multiply=None,
                cond_dropout=None,
                flip_p=None
                )
        }
        self.assertEqual(expected, actual)

    def test_json_manifest(self):
        self.fs.create_file("./stuff/image_1.jpg")
        self.fs.create_file("./stuff/default.caption", contents="default caption")
        self.fs.create_file("./other/image_1.jpg")
        self.fs.create_file("./other/image_2.jpg")
        self.fs.create_file("./other/image_3.jpg")
        self.fs.create_file("./manifest.json", contents=dedent("""
            [
                { "image": "./stuff/image_1.jpg", "caption": "./stuff/default.caption" },
                { "image": "./other/image_1.jpg", "caption": "other caption" },
                {
                    "image": "./other/image_2.jpg",
                    "caption": {
                        "main_prompt": "complex caption",
                        "rating": 0.1,
                        "max_caption_length": 1000,
                        "tags": [
                            {"tag": "including"},
                            {"tag": "weighted tag", "weight": 999.9}
                        ]
                    }
                },
                {
                    "image": "./other/image_3.jpg",
                    "multiply": 2,
                    "flip_p": 0.5,
                    "cond_dropout": 0.01,
                    "captions": [
                        "first caption",
                        { "main_prompt": "second caption" }
                    ]
                }
            ]
            """))

        actual = Dataset.from_json("./manifest.json").image_configs
        expected = {
            ImageConfig(
                image="./stuff/image_1.jpg",
                captions=frozenset([Caption(main_prompt="default caption")])
                ),
            ImageConfig(
                image="./other/image_1.jpg",
                captions=frozenset([Caption(main_prompt="other caption")])
                ),
            ImageConfig(
                image="./other/image_2.jpg",
                captions=frozenset([
                    Caption(
                        main_prompt="complex caption",
                        rating=0.1,
                        max_caption_length=1000,
                        tags=frozenset([
                            Tag("including"),
                            Tag("weighted tag", 999.9)
                        ]))
                ])
                ),
            ImageConfig(
                image="./other/image_3.jpg",
                multiply=2,
                flip_p=0.5,
                cond_dropout=0.01,
                captions=frozenset([
                    Caption("first caption"),
                    Caption("second caption")
                ])
                )
        }
        self.assertEqual(expected, actual)

    def test_train_items(self):
        dataset = Dataset([
            ImageConfig(
                image="1.jpg",
                multiply=2,
                flip_p=0.1,
                cond_dropout=0.01,
                captions=frozenset([
                    Caption(
                        main_prompt="first caption",
                        rating=1.1,
                        max_caption_length=1024,
                        tags=frozenset([
                            Tag("tag"),
                            Tag("tag_2", 2.0)
                        ])),
                    Caption(main_prompt="second_caption")
                ])),
            ImageConfig(
                image="2.jpg",
                captions=frozenset([Caption(main_prompt="single caption")])
                )
        ])

        aspects = []
        actual = dataset.image_train_items(aspects)

        self.assertEqual(len(actual), 2)

        self.assertEqual(actual[0].pathname, os.path.abspath('1.jpg'))
        self.assertEqual(actual[0].multiplier, 2.0)
        self.assertEqual(actual[0].flip.p, 0.1)
        self.assertEqual(actual[0].cond_dropout, 0.01)
        self.assertEqual(actual[0].caption.get_caption(), "first caption, tag, tag_2")
        # Can't test this
        # self.assertTrue(actual[0].caption.__use_weights)

        self.assertEqual(actual[1].pathname, os.path.abspath('2.jpg'))
        self.assertEqual(actual[1].multiplier, 1.0)
        self.assertEqual(actual[1].flip.p, 0.0)
        self.assertIsNone(actual[1].cond_dropout)
        self.assertEqual(actual[1].caption.get_caption(), "single caption")
        # Can't test this
        # self.assertFalse(actual[1].caption.__use_weights)

@@ -32,40 +32,4 @@ class TestImageCaption(unittest.TestCase):
         self.assertEqual(caption.get_caption(), "hello world, one, two, three")
 
         caption = ImageCaption("hello world", 1.0, [], [], 2048, False)
         self.assertEqual(caption.get_caption(), "hello world")
 
-    def test_parse(self):
-        caption = ImageCaption.parse("hello world, one, two, three")
-
-        self.assertEqual(caption.get_caption(), "hello world, one, two, three")
-
-    def test_from_file_name(self):
-        caption = ImageCaption.from_file_name("foo bar_1_2_3.jpg")
-        self.assertEqual(caption.get_caption(), "foo bar")
-
-    def test_from_text_file(self):
-        caption = ImageCaption.from_text_file("test/data/test1.txt")
-        self.assertEqual(caption.get_caption(), "caption for test1")
-
-    def test_from_file(self):
-        caption = ImageCaption.from_file("test/data/test1.txt")
-        self.assertEqual(caption.get_caption(), "caption for test1")
-
-        caption = ImageCaption.from_file("test/data/test_caption.caption")
-        self.assertEqual(caption.get_caption(), "caption for test2")
-
-    def test_resolve(self):
-        caption = ImageCaption.resolve("test/data/test1.txt")
-        self.assertEqual(caption.get_caption(), "caption for test1")
-
-        caption = ImageCaption.resolve("test/data/test_caption.caption")
-        self.assertEqual(caption.get_caption(), "caption for test2")
-
-        caption = ImageCaption.resolve("hello world")
-        self.assertEqual(caption.get_caption(), "hello world")
-
-        caption = ImageCaption.resolve("test/data/test1.jpg")
-        self.assertEqual(caption.get_caption(), "caption for test1")
-
-        caption = ImageCaption.resolve("test/data/test2.jpg")
-        self.assertEqual(caption.get_caption(), "test2")

@@ -0,0 +1,46 @@
import logging
import os

def barename(file):
    (val, _) = os.path.splitext(os.path.basename(file))
    return val

def ext(file):
    (_, val) = os.path.splitext(os.path.basename(file))
    return val.lower()

def same_barename(lhs, rhs):
    return barename(lhs) == barename(rhs)

def is_image(file):
    return ext(file) in {'.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif'}

def read_text(file):
    try:
        with open(file, encoding='utf-8', mode='r') as stream:
            return stream.read().strip()
    except Exception as e:
        logging.warning(f" *** Error reading text file {file}: {e}")

def read_float(file):
    data = read_text(file)
    try:
        return float(data)
    except Exception as e:
        logging.warning(f" *** Could not parse '{data}' to float in file {file}: {e}")

def walk_and_visit(path, visit_fn, context=None):
    names = [entry.name for entry in os.scandir(path)]

    dirs = []
    files = []
    for name in names:
        fullname = os.path.join(path, name)
        if os.path.isdir(fullname) and not str(name).startswith('.'):
            dirs.append(fullname)
        else:
            files.append(fullname)

    subcontext = visit_fn(files, context)

    for subdir in dirs:
        walk_and_visit(subdir, visit_fn, subcontext)