Merge pull request #106 from qslug/enhanced-config

Add support for enhanced dataset configuration
This commit is contained in:
Victor Hall 2023-03-14 01:28:54 -04:00 committed by GitHub
commit 493afc3f20
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 667 additions and 348 deletions

View File

@ -51,10 +51,72 @@
},
{
"cell_type": "markdown",
"id": "3d9b0db8-c2b1-4f0a-b835-b6b2ef527019",
"id": "f15fcd56-0418-4be1-a5c3-38aa679b1aaf",
"metadata": {},
"source": [
"### HuggingFace Login\n",
"# Start Training\n",
"Naming your project will help you track what the heck you're doing when you're floating in checkpoint files later.\n",
"\n",
"You may wish to consider adding \"sd1\" or \"sd2v\" or similar to remember what the base was, as you'll also have to tell your inference app what you were using, as its difficult for programs to know what inference YAML to use automatically. For instance, Automatic1111 webui requires you to copy the v2 inference YAML and rename it to match your checkpoint name so it knows how to load the file, tough it assumes SD 1.x compatible. Something to keep in mind if you start training on SD2.1.\n",
"\n",
"`max_epochs`, `sample_steps`, and `save_every_n_epochs` should be tuned to your dataset. I like to generate one or two sets of samples per save, and aim for 5 (give or take 2) saved checkpoints.\n",
"\n",
"Next cell runs training. This will take a while depending on your number of images, repeats, and max_epochs.\n",
"\n",
"You can watch for test images in the logs folder.\n",
"\n",
"#### Weights and Balanaces\n",
"I you pass the `--wandb` flag you will be prompted for your W&B `API Key`. W&B is a free online logging utility. If you don't have a W&B account, you can create one for free at https://wandb.ai/site. Your key is on this page: https://wandb.ai/settings under \"Danger Zone\" \"API Keys\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f73fb86-ebef-41e2-9382-4aa11be84be6",
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"%run train.py --config train.json \\\n",
"--resume_ckpt \"panopstor/EveryDream\" \\\n",
"--project_name \"sd1_mymodel\" \\\n",
"--data_root \"input\" \\\n",
"--max_epochs 200 \\\n",
"--sample_steps 150 \\\n",
"--save_every_n_epochs 35 \\\n",
"--lr 1.2e-6 \\\n",
"--lr_scheduler constant \\\n",
"--save_full_precision\n"
]
},
{
"cell_type": "markdown",
"id": "ed464c6b-1a8d-48e4-9787-265e8acaac43",
"metadata": {},
"source": [
"### Optionally you can chain trainings together using multiple configurations combined with `resume_ckpt: findlast`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "492350d4-9b2f-4d2a-9641-1f723125b296",
"metadata": {},
"outputs": [],
"source": [
"%run train.py --config chain0.json --project_name \"sd1_chain_a\" --data_root \"input\" --resume_ckpt \"panopstor/EveryDream\"\n",
"%run train.py --config chain1.json --project_name \"sd1_chain_b\" --data_root \"input\" --resume_ckpt findlast\n",
"%run train.py --config chain2.json --project_name \"sd1_chain_c\" --data_root \"input\" --resume_ckpt findlast"
]
},
{
"cell_type": "markdown",
"id": "3c506e79-bf03-4e34-bf06-9371963d4d7d",
"metadata": {},
"source": [
"# HuggingFace download (Optional)\n",
"Run the cell below and paste your token into the prompt. You can get your token from your [huggingface account page](https://huggingface.co/settings/tokens).\n",
"\n",
"The token will not show on the screen, just press enter after you paste it."
@ -74,10 +136,10 @@
},
{
"cell_type": "markdown",
"id": "7a96f2af-8c93-4460-aa9e-2ff795fb06ea",
"id": "b252a308-49cf-443f-abbb-d08b471411fb",
"metadata": {},
"source": [
"#### Then run the following cell to download the base checkpoint (may take a minute)."
"Then run the following cell to download the base checkpoint (may take a minute)."
]
},
{
@ -111,68 +173,6 @@
"print(\"DONE\")"
]
},
{
"cell_type": "markdown",
"id": "f15fcd56-0418-4be1-a5c3-38aa679b1aaf",
"metadata": {},
"source": [
"# Start Training\n",
"Naming your project will help you track what the heck you're doing when you're floating in checkpoint files later.\n",
"\n",
"You may wish to consider adding \"sd1\" or \"sd2v\" or similar to remember what the base was, as you'll also have to tell your inference app what you were using, as its difficult for programs to know what inference YAML to use automatically. For instance, Automatic1111 webui requires you to copy the v2 inference YAML and rename it to match your checkpoint name so it knows how to load the file, tough it assumes SD 1.x compatible. Something to keep in mind if you start training on SD2.1.\n",
"\n",
"`max_epochs`, `sample_steps`, and `save_every_n_epochs` should be tuned to your dataset. I like to generate one or two sets of samples per save, and aim for 5 (give or take 2) saved checkpoints.\n",
"\n",
"Next cell runs training. This will take a while depending on your number of images, repeats, and max_epochs.\n",
"\n",
"You can watch for test images in the logs folder.\n",
"\n",
"## Weights and Balanaces\n",
"I you pass the `--wandb` flag you will be prompted for your W&B `API Key`. W&B is a free online logging utility. If you don't have a W&B account, you can create one for free at https://wandb.ai/site. Your key is on this page: https://wandb.ai/settings under \"Danger Zone\" \"API Keys\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f73fb86-ebef-41e2-9382-4aa11be84be6",
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"%run train.py --config train.json \\\n",
"--resume_ckpt \"sd_v1-5_vae\" \\\n",
"--project_name \"sd1_mymodel\" \\\n",
"--data_root \"input\" \\\n",
"--max_epochs 200 \\\n",
"--sample_steps 150 \\\n",
"--save_every_n_epochs 35 \\\n",
"--lr 1.2e-6 \\\n",
"--lr_scheduler constant \\\n",
"--save_full_precision\n"
]
},
{
"cell_type": "markdown",
"id": "ed464c6b-1a8d-48e4-9787-265e8acaac43",
"metadata": {},
"source": [
"### Optionally you can chain trainings together using multiple configurations combined with `resume_ckpt: findlast`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "492350d4-9b2f-4d2a-9641-1f723125b296",
"metadata": {},
"outputs": [],
"source": [
"%run train.py --config chain0.json --project_name \"sd1_chain_a\" --data_root \"input\" --resume_ckpt \"{ckpt_name}\"\n",
"%run train.py --config chain1.json --project_name \"sd1_chain_b\" --data_root \"input\" --resume_ckpt findlast\n",
"%run train.py --config chain2.json --project_name \"sd1_chain_c\" --data_root \"input\" --resume_ckpt findlast"
]
},
{
"cell_type": "markdown",
"id": "f24eee3d-f5df-45f3-9acc-ee0206cfe6b1",
@ -351,7 +351,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@ -365,7 +365,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.10"
"version": "3.10.6"
},
"vscode": {
"interpreter": {

249
data/dataset.py Normal file
View File

@ -0,0 +1,249 @@
import os
import logging
import yaml
import json
from functools import total_ordering
from attrs import define, field, Factory
from data.image_train_item import ImageCaption, ImageTrainItem
from utils.fs_helpers import *
from typing import TypeVar, Iterable
def overlay(overlay, base):
return overlay if overlay is not None else base
def safe_set(val):
if isinstance(val, str):
return {val} if val else {}
if isinstance(val, Iterable):
return {i for i in val if i is not None}
return val or {}
@define(frozen=True)
@total_ordering
class Tag:
value: str
weight: float = field(default=1.0, converter=lambda x: x if x is not None else 1.0)
@classmethod
def parse(cls, data):
if isinstance(data, str):
return Tag(data)
if isinstance(data, dict):
value = data.get("tag")
weight = data.get("weight")
if value:
return Tag(value, weight)
return None
def __lt__(self, other):
return self.weight < other.weight and self.value < other.value
@define
class ImageConfig:
# Captions
main_prompts: set[str] = field(factory=set, converter=safe_set)
rating: float = None
max_caption_length: int = None
tags: set[Tag] = field(factory=set, converter=safe_set)
# Options
multiply: float = None
cond_dropout: float = None
flip_p: float = None
def merge(self, other):
if other is None:
return self
return ImageConfig(
main_prompts=self.main_prompts.union(other.main_prompts),
rating=overlay(other.rating, self.rating),
max_caption_length=overlay(other.max_caption_length, self.max_caption_length),
tags=self.tags.union(other.tags),
multiply=overlay(other.multiply, self.multiply),
cond_dropout=overlay(other.cond_dropout, self.cond_dropout),
flip_p=overlay(other.flip_p, self.flip_p),
)
@classmethod
def from_dict(cls, data: dict):
# Parse standard yaml tag file (with options)
parsed_cfg = ImageConfig(
main_prompts=safe_set(data.get("main_prompt")),
rating=data.get("rating"),
max_caption_length=data.get("max_caption_length"),
tags=safe_set(map(Tag.parse, data.get("tags", []))),
multiply=data.get("multiply"),
cond_dropout=data.get("cond_dropout"),
flip_p=data.get("flip_p"),
)
# Alternatively parse from dedicated `caption` attribute
if cap_attr := data.get('caption'):
parsed_cfg = parsed_cfg.merge(ImageConfig.parse(cap_attr))
return parsed_cfg
@classmethod
def fold(cls, configs):
acc = ImageConfig()
for cfg in configs:
acc = acc.merge(cfg)
return acc
def ensure_caption(self):
return self
@classmethod
def from_caption_text(cls, text: str):
if not text:
return ImageConfig()
if os.path.isfile(text):
return ImageConfig.from_file(text)
split_caption = list(map(str.strip, text.split(",")))
return ImageConfig(
main_prompts=split_caption[0],
tags=map(Tag.parse, split_caption[1:])
)
@classmethod
def from_file(cls, file: str):
match ext(file):
case '.jpg' | '.jpeg' | '.png' | '.bmp' | '.webp' | '.jfif':
return ImageConfig(image=file)
case ".json":
return ImageConfig.from_dict(json.load(read_text(file)))
case ".yaml" | ".yml":
return ImageConfig.from_dict(yaml.safe_load(read_text(file)))
case ".txt" | ".caption":
return ImageConfig.from_caption_text(read_text(file))
case _:
return logging.warning(" *** Unrecognized config extension {ext}")
@classmethod
def parse(cls, input):
if isinstance(input, str):
if os.path.isfile(input):
return ImageConfig.from_file(input)
else:
return ImageConfig.from_caption_text(input)
elif isinstance(input, dict):
return ImageConfig.from_dict(input)
@define()
class Dataset:
image_configs: dict[str, ImageConfig]
def __global_cfg(files):
cfgs = []
for file in files:
match os.path.basename(file):
case 'global.yaml' | 'global.yml':
cfgs.append(ImageConfig.from_file(file))
return ImageConfig.fold(cfgs)
def __local_cfg(files):
cfgs = []
for file in files:
match os.path.basename(file):
case 'multiply.txt':
cfgs.append(ImageConfig(multiply=read_float(file)))
case 'cond_dropout.txt':
cfgs.append(ImageConfig(cond_dropout=read_float(file)))
case 'flip_p.txt':
cfgs.append(ImageConfig(flip_p=read_float(file)))
case 'local.yaml' | 'local.yml':
cfgs.append(ImageConfig.from_file(file))
return ImageConfig.fold(cfgs)
def __sidecar_cfg(imagepath, files):
cfgs = []
for file in files:
if same_barename(imagepath, file):
match ext(file):
case '.txt' | '.caption' | '.yml' | '.yaml':
cfgs.append(ImageConfig.from_file(file))
return ImageConfig.fold(cfgs)
# Use file name for caption only as a last resort
@classmethod
def __ensure_caption(cls, cfg: ImageConfig, file: str):
if cfg.main_prompts or cfg.tags:
return cfg
cap_cfg = ImageConfig.from_caption_text(barename(file).split("_")[0])
return cfg.merge(cap_cfg)
@classmethod
def from_path(cls, data_root):
# Create a visitor that maintains global config stack
# and accumulates image configs as it traverses dataset
image_configs = {}
def process_dir(files, parent_globals):
global_cfg = parent_globals.merge(Dataset.__global_cfg(files))
local_cfg = Dataset.__local_cfg(files)
for img in filter(is_image, files):
img_cfg = Dataset.__sidecar_cfg(img, files)
resolved_cfg = ImageConfig.fold([global_cfg, local_cfg, img_cfg])
image_configs[img] = Dataset.__ensure_caption(resolved_cfg, img)
return global_cfg
walk_and_visit(data_root, process_dir, ImageConfig())
return Dataset(image_configs)
@classmethod
def from_json(cls, json_path):
"""
Import a dataset definition from a JSON file
"""
image_configs = {}
with open(json_path, encoding='utf-8', mode='r') as stream:
for data in json.load(stream):
img = data.get("image")
cfg = Dataset.__ensure_caption(ImageConfig.parse(data), img)
if not img:
logging.warning(f" *** Error parsing json image entry in {json_path}: {data}")
continue
image_configs[img] = cfg
return Dataset(image_configs)
def image_train_items(self, aspects):
items = []
for image in self.image_configs:
config = self.image_configs[image]
if len(config.main_prompts) > 1:
logging.warning(f" *** Found multiple multiple main_prompts for image {image}, but only one will be applied: {config.main_prompts}")
tags = []
tag_weights = []
for tag in sorted(config.tags):
tags.append(tag.value)
tag_weights.append(tag.weight)
use_weights = len(set(tag_weights)) > 1
caption = ImageCaption(
main_prompt=next(iter(sorted(config.main_prompts))),
rating=config.rating or 1.0,
tags=tags,
tag_weights=tag_weights,
max_target_length=config.max_caption_length,
use_weights=use_weights)
item = ImageTrainItem(
image=None,
caption=caption,
aspects=aspects,
pathname=os.path.abspath(image),
flip_p=config.flip_p or 0.0,
multiplier=config.multiply or 1.0,
cond_dropout=config.cond_dropout
)
items.append(item)
return list(sorted(items, key=lambda ti: ti.pathname))

View File

@ -104,7 +104,7 @@ class EveryDreamBatch(Dataset):
example["image"] = image_transforms(train_item["image"])
if random.random() > self.conditional_dropout:
if random.random() > (train_item.get("cond_dropout", self.conditional_dropout)):
example["tokens"] = self.tokenizer(example["caption"],
truncation=True,
padding="max_length",
@ -132,6 +132,8 @@ class EveryDreamBatch(Dataset):
example["image"] = image_train_tmp.image.copy() # hack for now to avoid memory leak
image_train_tmp.image = None # hack for now to avoid memory leak
example["caption"] = image_train_tmp.caption
if image_train_tmp.cond_dropout is not None:
example["cond_dropout"] = image_train_tmp.cond_dropout
example["runt_size"] = image_train_tmp.runt_size
return example

View File

@ -113,136 +113,6 @@ class ImageCaption:
random.Random(seed).shuffle(tags)
return ", ".join(tags)
@staticmethod
def parse(string: str) -> 'ImageCaption':
"""
Parses a string to get the caption.
:param string: String to parse.
:return: `ImageCaption` object.
"""
split_caption = list(map(str.strip, string.split(",")))
main_prompt = split_caption[0]
tags = split_caption[1:]
tag_weights = [1.0] * len(tags)
return ImageCaption(main_prompt, 1.0, tags, tag_weights, DEFAULT_MAX_CAPTION_LENGTH, False)
@staticmethod
def from_file_name(file_path: str) -> 'ImageCaption':
"""
Parses the file name to get the caption.
:param file_path: Path to the image file.
:return: `ImageCaption` object.
"""
(file_name, _) = os.path.splitext(os.path.basename(file_path))
caption = file_name.split("_")[0]
return ImageCaption.parse(caption)
@staticmethod
def from_text_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
"""
Parses a text file to get the caption. Returns the default caption if
the file does not exist or is invalid.
:param file_path: Path to the text file.
:param default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid.
:return: `ImageCaption` object or `None`.
"""
try:
with open(file_path, encoding='utf-8', mode='r') as caption_file:
caption_text = caption_file.read()
return ImageCaption.parse(caption_text)
except:
logging.error(f" *** Error reading {file_path} to get caption")
return default_caption
@staticmethod
def from_yaml_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
"""
Parses a yaml file to get the caption. Returns the default caption if
the file does not exist or is invalid.
:param file_path: path to the yaml file
:param default_caption: caption to return if the file does not exist or is invalid
:return: `ImageCaption` object or `None`.
"""
try:
with open(file_path, "r") as stream:
file_content = yaml.safe_load(stream)
main_prompt = file_content.get("main_prompt", "")
rating = file_content.get("rating", 1.0)
unparsed_tags = file_content.get("tags", [])
max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH)
tags = []
tag_weights = []
last_weight = None
weights_differ = False
for unparsed_tag in unparsed_tags:
tag = unparsed_tag.get("tag", "").strip()
if len(tag) == 0:
continue
tags.append(tag)
tag_weight = unparsed_tag.get("weight", 1.0)
tag_weights.append(tag_weight)
if last_weight is not None and weights_differ is False:
weights_differ = last_weight != tag_weight
last_weight = tag_weight
return ImageCaption(main_prompt, rating, tags, tag_weights, max_caption_length, weights_differ)
except:
logging.error(f" *** Error reading {file_path} to get caption")
return default_caption
@staticmethod
def from_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
"""
Try to resolve a caption from a file path or return `default_caption`.
:string: The path to the file to parse.
:default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid.
:return: `ImageCaption` object or `None`.
"""
if os.path.exists(file_path):
(file_path_without_ext, ext) = os.path.splitext(file_path)
match ext:
case ".yaml" | ".yml":
return ImageCaption.from_yaml_file(file_path, default_caption)
case ".txt" | ".caption":
return ImageCaption.from_text_file(file_path, default_caption)
case '.jpg'| '.jpeg'| '.png'| '.bmp'| '.webp'| '.jfif':
for ext in [".yaml", ".yml", ".txt", ".caption"]:
file_path = file_path_without_ext + ext
image_caption = ImageCaption.from_file(file_path)
if image_caption is not None:
return image_caption
return ImageCaption.from_file_name(file_path)
case _:
return default_caption
else:
return default_caption
@staticmethod
def resolve(string: str) -> 'ImageCaption':
"""
Try to resolve a caption from a string. If the string is a file path,
the caption will be read from the file, otherwise the string will be
parsed as a caption.
:string: The string to resolve.
:return: `ImageCaption` object.
"""
return ImageCaption.from_file(string, None) or ImageCaption.parse(string)
class ImageTrainItem:
"""
@ -253,7 +123,7 @@ class ImageTrainItem:
flip_p: probability of flipping image (0.0 to 1.0)
rating: the relative rating of the images. The rating is measured in comparison to the other images.
"""
def __init__(self, image: PIL.Image, caption: ImageCaption, aspects: list[float], pathname: str, flip_p=0.0, multiplier: float=1.0):
def __init__(self, image: PIL.Image, caption: ImageCaption, aspects: list[float], pathname: str, flip_p=0.0, multiplier: float=1.0, cond_dropout=None):
self.caption = caption
self.aspects = aspects
self.pathname = pathname
@ -261,6 +131,7 @@ class ImageTrainItem:
self.cropped_img = None
self.runt_size = 0
self.multiplier = multiplier
self.cond_dropout = cond_dropout
self.image_size = None
if image is None or len(image) == 0:

View File

@ -4,6 +4,7 @@ import os
import typing
import zipfile
import argparse
from data.dataset import Dataset
import tqdm
from colorama import Fore, Style
@ -27,16 +28,6 @@ class DataResolver:
"""
raise NotImplementedError()
def image_train_item(self, image_path: str, caption: ImageCaption, multiplier: float=1) -> ImageTrainItem:
return ImageTrainItem(
image=None,
caption=caption,
aspects=self.aspects,
pathname=image_path,
flip_p=self.flip_p,
multiplier=multiplier
)
class JSONResolver(DataResolver):
def image_train_items(self, json_path: str) -> list[ImageTrainItem]:
"""
@ -45,62 +36,8 @@ class JSONResolver(DataResolver):
:param json_path: The path to the JSON file.
"""
items = []
with open(json_path, encoding='utf-8', mode='r') as f:
json_data = json.load(f)
for data in tqdm.tqdm(json_data):
caption = JSONResolver.image_caption(data)
if caption:
image_value = JSONResolver.get_image_value(data)
item = self.image_train_item(image_value, caption)
if item:
items.append(item)
return items
return Dataset.from_json(json_path).image_train_items(self.aspects)
@staticmethod
def get_image_value(json_data: dict) -> typing.Optional[str]:
"""
Get the image from the json data if possible.
:param json_data: The json data, a dict.
:return: The image, or None if not found.
"""
image_value = json_data.get("image", None)
if isinstance(image_value, str):
image_value = image_value.strip()
if os.path.exists(image_value):
return image_value
@staticmethod
def get_caption_value(json_data: dict) -> typing.Optional[str]:
"""
Get the caption from the json data if possible.
:param json_data: The json data, a dict.
:return: The caption, or None if not found.
"""
caption_value = json_data.get("caption", None)
if isinstance(caption_value, str):
return caption_value.strip()
@staticmethod
def image_caption(json_data: dict) -> typing.Optional[ImageCaption]:
"""
Get the caption from the json data if possible.
:param json_data: The json data, a dict.
:return: The `ImageCaption`, or None if not found.
"""
image_value = JSONResolver.get_image_value(json_data)
caption_value = JSONResolver.get_caption_value(json_data)
if image_value:
if caption_value:
return ImageCaption.resolve(caption_value)
return ImageCaption.from_file(image_value)
class DirectoryResolver(DataResolver):
def image_train_items(self, data_root: str) -> list[ImageTrainItem]:
"""
@ -111,32 +48,7 @@ class DirectoryResolver(DataResolver):
:param data_root: The root directory to recurse through
"""
DirectoryResolver.unzip_all(data_root)
image_paths = list(DirectoryResolver.recurse_data_root(data_root))
items = []
multipliers = {}
for pathname in tqdm.tqdm(image_paths):
current_dir = os.path.dirname(pathname)
if current_dir not in multipliers:
multiply_txt_path = os.path.join(current_dir, "multiply.txt")
if os.path.exists(multiply_txt_path):
try:
with open(multiply_txt_path, 'r') as f:
val = float(f.read().strip())
multipliers[current_dir] = val
logging.info(f" - multiply.txt in '{current_dir}' set to {val}")
except Exception as e:
logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
multipliers[current_dir] = 1.0
else:
multipliers[current_dir] = 1.0
caption = ImageCaption.resolve(pathname)
item = self.image_train_item(pathname, caption, multiplier=multipliers[current_dir])
items.append(item)
return items
return Dataset.from_path(data_root).image_train_items(self.aspects)
@staticmethod
def unzip_all(path):
@ -150,21 +62,6 @@ class DirectoryResolver(DataResolver):
except Exception as e:
logging.error(f"Error unzipping files {e}")
@staticmethod
def recurse_data_root(recurse_root):
for f in os.listdir(recurse_root):
current = os.path.join(recurse_root, f)
if os.path.isfile(current):
ext = os.path.splitext(f)[1].lower()
if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']:
yield current
for d in os.listdir(recurse_root):
current = os.path.join(recurse_root, d)
if os.path.isdir(current):
yield from DirectoryResolver.recurse_data_root(current)
def strategy(data_root: str) -> typing.Type[DataResolver]:
"""
Determine the strategy to use for resolving the data.

View File

@ -29,7 +29,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m venv ${VIRTUAL_ENV} && \
pip install -U -I torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url "https://download.pytorch.org/whl/cu117" && \
pip install -r requirements.txt && \
pip install --pre --no-deps xformers==0.0.17.dev451
pip install --pre -U --no-deps xformers
# In case of emergency, build xformers from scratch
# export FORCE_CUDA=1 && export TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6" && export CUDA_VISIBLE_DEVICES=0 && \
# pip install --no-deps git+https://github.com/facebookresearch/xformers.git@48a77cc#egg=xformers

View File

@ -10,6 +10,7 @@ ninja
omegaconf==2.2.3
piexif==1.1.3
protobuf==3.20.3
pyfakefs
pynvml==11.5.0
pyre-extensions==0.0.30
pytorch-lightning==1.9.2

View File

@ -58,13 +58,13 @@ class TestResolve(unittest.TestCase):
def test_directory_resolve_with_str(self):
items = resolver.resolve(DATA_PATH, ARGS)
image_paths = [item.pathname for item in items]
image_paths = set(item.pathname for item in items)
image_captions = [item.caption for item in items]
captions = [caption.get_caption() for caption in image_captions]
captions = set(caption.get_caption() for caption in image_captions)
self.assertEqual(len(items), 3)
self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
self.assertEqual(captions, ['caption for test1', 'test2', 'test3'])
self.assertEqual(image_paths, {IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH})
self.assertEqual(captions, {'caption for test1', 'test2', 'test3'})
undersized_images = list(filter(lambda i: i.is_undersized, items))
self.assertEqual(len(undersized_images), 1)

289
test/test_dataset.py Normal file
View File

@ -0,0 +1,289 @@
import os
from data.dataset import Dataset, ImageConfig, Tag
from textwrap import dedent
from pyfakefs.fake_filesystem_unittest import TestCase
class TestDataset(TestCase):
def setUp(self):
self.maxDiff = None
self.setUpPyfakefs()
def test_a_caption_is_generated_from_image_given_no_other_config(self):
self.fs.create_file("image, tag1, tag2.jpg")
actual = Dataset.from_path(".").image_configs
expected = {
"./image, tag1, tag2.jpg": ImageConfig(main_prompts="image", tags=frozenset([Tag("tag1"), Tag("tag2")]))
}
self.assertEqual(expected, actual)
def test_several_image_formats_are_supported(self):
self.fs.create_file("image.JPG")
self.fs.create_file("image.jpeg")
self.fs.create_file("image.png")
self.fs.create_file("image.webp")
self.fs.create_file("image.jfif")
self.fs.create_file("image.bmp")
actual = Dataset.from_path(".").image_configs
common_cfg = ImageConfig(main_prompts="image")
expected = {
"./image.JPG": common_cfg,
"./image.jpeg": common_cfg,
"./image.png": common_cfg,
"./image.webp": common_cfg,
"./image.jfif": common_cfg,
"./image.bmp": common_cfg,
}
self.assertEqual(expected, actual)
def test_captions_can_be_read_from_txt_or_caption_sidecar(self):
self.fs.create_file("image_1.jpg")
self.fs.create_file("image_1.txt", contents="an image, test, from .txt")
self.fs.create_file("image_2.jpg")
self.fs.create_file("image_2.caption", contents="an image, test, from .caption")
actual = Dataset.from_path(".").image_configs
expected = {
"./image_1.jpg": ImageConfig(main_prompts="an image", tags=frozenset([Tag("test"), Tag("from .txt")])),
"./image_2.jpg": ImageConfig(main_prompts="an image", tags=frozenset([Tag("test"), Tag("from .caption")]))
}
self.assertEqual(expected, actual)
def test_captions_and_options_can_be_read_from_yaml_sidecar(self):
self.fs.create_file("image_1.jpg")
self.fs.create_file("image_1.yaml",
contents=dedent("""
multiply: 2
cond_dropout: 0.05
flip_p: 0.5
caption: "A simple caption, from .yaml"
"""))
self.fs.create_file("image_2.jpg")
self.fs.create_file("image_2.yml",
contents=dedent("""
flip_p: 0.0
caption:
main_prompt: A complex caption
rating: 1.1
max_caption_length: 1024
tags:
- tag: from .yml
- tag: with weight
weight: 0.5
"""))
actual = Dataset.from_path(".").image_configs
expected = {
"./image_1.jpg": ImageConfig(
multiply=2,
cond_dropout=0.05,
flip_p=0.5,
main_prompts="A simple caption",
tags= { Tag("from .yaml") }
),
"./image_2.jpg": ImageConfig(
flip_p=0.0,
rating=1.1,
max_caption_length=1024,
main_prompts="A complex caption",
tags= { Tag("from .yml"), Tag("with weight", weight=0.5) }
)
}
self.assertEqual(expected, actual)
def test_multiple_prompts_and_tags_from_multiple_sidecars_are_supported(self):
self.fs.create_file("image_1.jpg")
self.fs.create_file("image_1.yaml", contents=dedent("""
main_prompt:
- unique prompt
- dupe prompt
tags:
- from .yaml
- dupe tag
"""))
self.fs.create_file("image_1.txt", contents="also unique prompt, from .txt, dupe tag")
self.fs.create_file("image_1.caption", contents="dupe prompt, from .caption")
actual = Dataset.from_path(".").image_configs
expected = {
"./image_1.jpg": ImageConfig(
main_prompts={ "unique prompt", "also unique prompt", "dupe prompt" },
tags={ Tag("from .yaml"), Tag("from .txt"), Tag("from .caption"), Tag("dupe tag") }
)
}
self.assertEqual(expected, actual)
def test_sidecars_can_also_be_attached_to_local_and_recursive_folders(self):
self.fs.create_file("./global.yaml",
contents=dedent("""\
main_prompt: global prompt
tags:
- global tag
flip_p: 0.0
"""))
self.fs.create_file("./local.yaml",
contents=dedent("""
main_prompt: local prompt
tags:
- tag: local tag
"""))
self.fs.create_file("./arbitrary filename.png")
self.fs.create_file("./sub/sub arbitrary filename.png")
self.fs.create_file("./sub/sidecar.png")
self.fs.create_file("./sub/sidecar.txt",
contents="sidecar prompt, sidecar tag")
self.fs.create_file("./optfile/optfile.png")
self.fs.create_file("./optfile/flip_p.txt",
contents="0.1234")
self.fs.create_file("./sub/sub2/global.yaml",
contents=dedent("""
tags:
- tag: sub global tag
"""))
self.fs.create_file("./sub/sub2/local.yaml",
contents=dedent("""
tags:
- This tag wil not apply to any files
"""))
self.fs.create_file("./sub/sub2/sub3/xyz.png")
actual = Dataset.from_path(".").image_configs
expected = {
"./arbitrary filename.png": ImageConfig(
main_prompts={ 'global prompt', 'local prompt' },
tags={ Tag("global tag"), Tag("local tag") },
flip_p=0.0
),
"./sub/sub arbitrary filename.png": ImageConfig(
main_prompts={ 'global prompt' },
tags={ Tag("global tag") },
flip_p=0.0
),
"./sub/sidecar.png": ImageConfig(
main_prompts={ 'global prompt', 'sidecar prompt' },
tags={ Tag("global tag"), Tag("sidecar tag") },
flip_p=0.0
),
"./optfile/optfile.png": ImageConfig(
main_prompts={ 'global prompt' },
tags={ Tag("global tag") },
flip_p=0.1234
),
"./sub/sub2/sub3/xyz.png": ImageConfig(
main_prompts={ 'global prompt' },
tags={ Tag("global tag"), Tag("sub global tag") },
flip_p=0.0
)
}
self.assertEqual(expected, actual)
def test_can_load_dataset_from_json_manifest(self):
self.fs.create_file("./stuff/image_1.jpg")
self.fs.create_file("./stuff/default.caption", contents= "default caption")
self.fs.create_file("./other/image_1.jpg")
self.fs.create_file("./other/image_2.jpg")
self.fs.create_file("./other/image_3.jpg")
self.fs.create_file("./manifest.json", contents=dedent("""
[
{ "image": "./stuff/image_1.jpg", "caption": "./stuff/default.caption" },
{ "image": "./other/image_1.jpg", "caption": "other caption" },
{
"image": "./other/image_2.jpg",
"caption": {
"main_prompt": "complex caption",
"rating": 0.1,
"max_caption_length": 1000,
"tags": [
{"tag": "including"},
{"tag": "weighted tag", "weight": 999.9}
]
}
},
{
"image": "./other/image_3.jpg",
"multiply": 2,
"flip_p": 0.5,
"cond_dropout": 0.01,
"main_prompt": [
"first caption",
"second caption"
]
}
]
"""))
actual = Dataset.from_json("./manifest.json").image_configs
expected = {
"./stuff/image_1.jpg": ImageConfig( main_prompts={"default caption"} ),
"./other/image_1.jpg": ImageConfig( main_prompts={"other caption"} ),
"./other/image_2.jpg": ImageConfig(
main_prompts={ "complex caption" },
rating=0.1,
max_caption_length=1000,
tags={
Tag("including"),
Tag("weighted tag", 999.9)
}
),
"./other/image_3.jpg": ImageConfig(
main_prompts={ "first caption", "second caption" },
multiply=2,
flip_p=0.5,
cond_dropout=0.01
)
}
self.assertEqual(expected, actual)
def test_dataset_can_produce_train_items(self):
dataset = Dataset({
"1.jpg": ImageConfig(
multiply=2,
flip_p=0.1,
cond_dropout=0.01,
main_prompts=["first caption","second caption"],
rating = 1.1,
max_caption_length=1024,
tags=frozenset([
Tag("tag"),
Tag("tag_2", 2.0)
])),
"2.jpg": ImageConfig( main_prompts="single caption")
})
aspects = []
actual = dataset.image_train_items(aspects)
self.assertEqual(len(actual), 2)
self.assertEqual(actual[0].pathname, os.path.abspath('1.jpg'))
self.assertEqual(actual[0].multiplier, 2.0)
self.assertEqual(actual[0].flip.p, 0.1)
self.assertEqual(actual[0].cond_dropout, 0.01)
self.assertEqual(actual[0].caption.rating(), 1.1)
self.assertEqual(actual[0].caption.get_caption(), "first caption, tag, tag_2")
# Can't test this
# self.assertTrue(actual[0].caption.__use_weights)
self.assertEqual(actual[1].pathname, os.path.abspath('2.jpg'))
self.assertEqual(actual[1].multiplier, 1.0)
self.assertEqual(actual[1].flip.p, 0.0)
self.assertIsNone(actual[1].cond_dropout)
self.assertEqual(actual[1].caption.rating(), 1.0)
self.assertEqual(actual[1].caption.get_caption(), "single caption")
# Can't test this
# self.assertFalse(actual[1].caption.__use_weights)

View File

@ -32,40 +32,4 @@ class TestImageCaption(unittest.TestCase):
self.assertEqual(caption.get_caption(), "hello world, one, two, three")
caption = ImageCaption("hello world", 1.0, [], [], 2048, False)
self.assertEqual(caption.get_caption(), "hello world")
def test_parse(self):
caption = ImageCaption.parse("hello world, one, two, three")
self.assertEqual(caption.get_caption(), "hello world, one, two, three")
def test_from_file_name(self):
caption = ImageCaption.from_file_name("foo bar_1_2_3.jpg")
self.assertEqual(caption.get_caption(), "foo bar")
def test_from_text_file(self):
caption = ImageCaption.from_text_file("test/data/test1.txt")
self.assertEqual(caption.get_caption(), "caption for test1")
def test_from_file(self):
caption = ImageCaption.from_file("test/data/test1.txt")
self.assertEqual(caption.get_caption(), "caption for test1")
caption = ImageCaption.from_file("test/data/test_caption.caption")
self.assertEqual(caption.get_caption(), "caption for test2")
def test_resolve(self):
caption = ImageCaption.resolve("test/data/test1.txt")
self.assertEqual(caption.get_caption(), "caption for test1")
caption = ImageCaption.resolve("test/data/test_caption.caption")
self.assertEqual(caption.get_caption(), "caption for test2")
caption = ImageCaption.resolve("hello world")
self.assertEqual(caption.get_caption(), "hello world")
caption = ImageCaption.resolve("test/data/test1.jpg")
self.assertEqual(caption.get_caption(), "caption for test1")
caption = ImageCaption.resolve("test/data/test2.jpg")
self.assertEqual(caption.get_caption(), "test2")
self.assertEqual(caption.get_caption(), "hello world")

46
utils/fs_helpers.py Normal file
View File

@ -0,0 +1,46 @@
def barename(file):
(val, _) = os.path.splitext(os.path.basename(file))
return val
def ext(file):
(_, val) = os.path.splitext(os.path.basename(file))
return val.lower()
def same_barename(lhs, rhs):
return barename(lhs) == barename(rhs)
def is_image(file):
return ext(file) in {'.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif'}
def read_text(file):
try:
with open(file, encoding='utf-8', mode='r') as stream:
return stream.read().strip()
except Exception as e:
logging.warning(f" *** Error reading text file {file}: {e}")
def read_float(file):
try:
return float(read_text(file))
except Exception as e:
logging.warning(f" *** Could not parse '{data}' to float in file {file}: {e}")
import os
def walk_and_visit(path, visit_fn, context=None):
names = [entry.name for entry in os.scandir(path)]
dirs = []
files = []
for name in names:
fullname = os.path.join(path, name)
if os.path.isdir(fullname) and not str(name).startswith('.'):
dirs.append(fullname)
else:
files.append(fullname)
subcontext = visit_fn(files, context)
for subdir in dirs:
walk_and_visit(subdir, visit_fn, subcontext)