conf
This commit is contained in:
commit
605716a646
|
@ -0,0 +1,36 @@
|
|||
---
|
||||
name: Bug report
|
||||
about: For bugs that are NOT ERRORS
|
||||
title: "[BUG]"
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**If you are getting an error, use the ERROR template, not this BUG template. **
|
||||
|
||||
Have you joined discord? You're far more likely to get a response there: https://discord.gg/uheqxU6sXN
|
||||
|
||||
**To Reproduce**
|
||||
Steps to reproduce the undesired behavior:
|
||||
1. Configure training in this way (please attach _cfg.json)
|
||||
2. Execute training with "this" command
|
||||
|
||||
**Describe expected behavior and actual behavior**
|
||||
ex. It does XYZ but should do ABC instead
|
||||
ex. It does not do ABC when it should do ABC
|
||||
ex. It should not do XYZ at all
|
||||
|
||||
**Describe why you think this is a bug**
|
||||
It should do ABC because... XYZ is wrong because...
|
||||
|
||||
**Attach log and cfg**
|
||||
*PLEASE* attach the ".log" and "_cfg.json" from your logs folder for the run. These are in the "logs" folder under "project_name_timestamp" subfolder. This will assist greatly in identifying problems with configurations or system problems that may be causing your problems.
|
||||
|
||||
**Runtime environment (please complete the following information):**
|
||||
- OS: [e.g. Windows 10, Ubuntu Linux 22.04, etc]
|
||||
- Is this your local computer or a cloud host? Please list the cloud host (Vast, Google Colab, etc)
|
||||
- GPU [e.g. 3090 24GB, A100 40GB, 2080 Ti 11GB, etc]
|
||||
|
||||
**Additional context**
|
||||
Add any other context about the problem here.
|
|
@ -0,0 +1,22 @@
|
|||
---
|
||||
name: Error Report
|
||||
about: For ERRORS that halt training
|
||||
title: "[ERROR]"
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Attach log and cfg**
|
||||
*PLEASE* attach the ".log" and "_cfg.json" from your logs folder for the run that failed. This is absolutely critical to providing assistance. These files are always generated and saved in the "logs" folder under project_name_timestamp folder every run. For cloud hosts, you can download the files. For Google Colab these are likely being saved to your Gdrive.
|
||||
|
||||
**Runtime environment (please complete the following information):**
|
||||
- OS: [e.g. Windows 10, Ubuntu Linux 22.04, etc]
|
||||
- Is this your local computer or a cloud host? Please list the cloud host (Vast, Google Colab, etc)
|
||||
- GPU [e.g. 3090 24GB, A100 40GB]
|
||||
|
||||
**Additional context**
|
||||
Add any other context about the problem here.
|
||||
|
||||
Also consider posting your cfg and log to the Discord #help channel here instead, there are far more people there than will read your issue here on Github:
|
||||
Have you joined discord? You're far more likely to get a response there: https://discord.gg/uheqxU6sXN
|
|
@ -446,7 +446,7 @@
|
|||
"\n",
|
||||
"\n",
|
||||
"if Disconnect_after_training :\n",
|
||||
" time.sleep(3)\n",
|
||||
" time.sleep(40)\n",
|
||||
" runtime.unassign()"
|
||||
]
|
||||
},
|
||||
|
|
|
@ -51,10 +51,72 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3d9b0db8-c2b1-4f0a-b835-b6b2ef527019",
|
||||
"id": "f15fcd56-0418-4be1-a5c3-38aa679b1aaf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### HuggingFace Login\n",
|
||||
"# Start Training\n",
|
||||
"Naming your project will help you track what the heck you're doing when you're floating in checkpoint files later.\n",
|
||||
"\n",
|
||||
"You may wish to consider adding \"sd1\" or \"sd2v\" or similar to remember what the base was, as you'll also have to tell your inference app what you were using, as its difficult for programs to know what inference YAML to use automatically. For instance, Automatic1111 webui requires you to copy the v2 inference YAML and rename it to match your checkpoint name so it knows how to load the file, tough it assumes SD 1.x compatible. Something to keep in mind if you start training on SD2.1.\n",
|
||||
"\n",
|
||||
"`max_epochs`, `sample_steps`, and `save_every_n_epochs` should be tuned to your dataset. I like to generate one or two sets of samples per save, and aim for 5 (give or take 2) saved checkpoints.\n",
|
||||
"\n",
|
||||
"Next cell runs training. This will take a while depending on your number of images, repeats, and max_epochs.\n",
|
||||
"\n",
|
||||
"You can watch for test images in the logs folder.\n",
|
||||
"\n",
|
||||
"#### Weights and Balanaces\n",
|
||||
"I you pass the `--wandb` flag you will be prompted for your W&B `API Key`. W&B is a free online logging utility. If you don't have a W&B account, you can create one for free at https://wandb.ai/site. Your key is on this page: https://wandb.ai/settings under \"Danger Zone\" \"API Keys\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6f73fb86-ebef-41e2-9382-4aa11be84be6",
|
||||
"metadata": {
|
||||
"scrolled": true,
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%run train.py --config train.json \\\n",
|
||||
"--resume_ckpt \"panopstor/EveryDream\" \\\n",
|
||||
"--project_name \"sd1_mymodel\" \\\n",
|
||||
"--data_root \"input\" \\\n",
|
||||
"--max_epochs 200 \\\n",
|
||||
"--sample_steps 150 \\\n",
|
||||
"--save_every_n_epochs 35 \\\n",
|
||||
"--lr 1.2e-6 \\\n",
|
||||
"--lr_scheduler constant \\\n",
|
||||
"--save_full_precision\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ed464c6b-1a8d-48e4-9787-265e8acaac43",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Optionally you can chain trainings together using multiple configurations combined with `resume_ckpt: findlast`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "492350d4-9b2f-4d2a-9641-1f723125b296",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%run train.py --config chain0.json --project_name \"sd1_chain_a\" --data_root \"input\" --resume_ckpt \"panopstor/EveryDream\"\n",
|
||||
"%run train.py --config chain1.json --project_name \"sd1_chain_b\" --data_root \"input\" --resume_ckpt findlast\n",
|
||||
"%run train.py --config chain2.json --project_name \"sd1_chain_c\" --data_root \"input\" --resume_ckpt findlast"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3c506e79-bf03-4e34-bf06-9371963d4d7d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# HuggingFace download (Optional)\n",
|
||||
"Run the cell below and paste your token into the prompt. You can get your token from your [huggingface account page](https://huggingface.co/settings/tokens).\n",
|
||||
"\n",
|
||||
"The token will not show on the screen, just press enter after you paste it."
|
||||
|
@ -74,10 +136,10 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7a96f2af-8c93-4460-aa9e-2ff795fb06ea",
|
||||
"id": "b252a308-49cf-443f-abbb-d08b471411fb",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Then run the following cell to download the base checkpoint (may take a minute)."
|
||||
"Then run the following cell to download the base checkpoint (may take a minute)."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -111,68 +173,6 @@
|
|||
"print(\"DONE\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f15fcd56-0418-4be1-a5c3-38aa679b1aaf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Start Training\n",
|
||||
"Naming your project will help you track what the heck you're doing when you're floating in checkpoint files later.\n",
|
||||
"\n",
|
||||
"You may wish to consider adding \"sd1\" or \"sd2v\" or similar to remember what the base was, as you'll also have to tell your inference app what you were using, as its difficult for programs to know what inference YAML to use automatically. For instance, Automatic1111 webui requires you to copy the v2 inference YAML and rename it to match your checkpoint name so it knows how to load the file, tough it assumes SD 1.x compatible. Something to keep in mind if you start training on SD2.1.\n",
|
||||
"\n",
|
||||
"`max_epochs`, `sample_steps`, and `save_every_n_epochs` should be tuned to your dataset. I like to generate one or two sets of samples per save, and aim for 5 (give or take 2) saved checkpoints.\n",
|
||||
"\n",
|
||||
"Next cell runs training. This will take a while depending on your number of images, repeats, and max_epochs.\n",
|
||||
"\n",
|
||||
"You can watch for test images in the logs folder.\n",
|
||||
"\n",
|
||||
"## Weights and Balanaces\n",
|
||||
"I you pass the `--wandb` flag you will be prompted for your W&B `API Key`. W&B is a free online logging utility. If you don't have a W&B account, you can create one for free at https://wandb.ai/site. Your key is on this page: https://wandb.ai/settings under \"Danger Zone\" \"API Keys\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6f73fb86-ebef-41e2-9382-4aa11be84be6",
|
||||
"metadata": {
|
||||
"scrolled": true,
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%run train.py --config train.json \\\n",
|
||||
"--resume_ckpt \"sd_v1-5_vae\" \\\n",
|
||||
"--project_name \"sd1_mymodel\" \\\n",
|
||||
"--data_root \"input\" \\\n",
|
||||
"--max_epochs 200 \\\n",
|
||||
"--sample_steps 150 \\\n",
|
||||
"--save_every_n_epochs 35 \\\n",
|
||||
"--lr 1.2e-6 \\\n",
|
||||
"--lr_scheduler constant \\\n",
|
||||
"--save_full_precision\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ed464c6b-1a8d-48e4-9787-265e8acaac43",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Optionally you can chain trainings together using multiple configurations combined with `resume_ckpt: findlast`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "492350d4-9b2f-4d2a-9641-1f723125b296",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%run train.py --config chain0.json --project_name \"sd1_chain_a\" --data_root \"input\" --resume_ckpt \"{ckpt_name}\"\n",
|
||||
"%run train.py --config chain1.json --project_name \"sd1_chain_b\" --data_root \"input\" --resume_ckpt findlast\n",
|
||||
"%run train.py --config chain2.json --project_name \"sd1_chain_c\" --data_root \"input\" --resume_ckpt findlast"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f24eee3d-f5df-45f3-9acc-ee0206cfe6b1",
|
||||
|
@ -351,7 +351,7 @@
|
|||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
@ -365,7 +365,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.10"
|
||||
"version": "3.10.6"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
|
|
|
@ -0,0 +1,249 @@
|
|||
import os
|
||||
import logging
|
||||
import yaml
|
||||
import json
|
||||
|
||||
from functools import total_ordering
|
||||
from attrs import define, field, Factory
|
||||
from data.image_train_item import ImageCaption, ImageTrainItem
|
||||
from utils.fs_helpers import *
|
||||
from typing import TypeVar, Iterable
|
||||
|
||||
|
||||
def overlay(overlay, base):
|
||||
return overlay if overlay is not None else base
|
||||
|
||||
def safe_set(val):
|
||||
if isinstance(val, str):
|
||||
return {val} if val else {}
|
||||
|
||||
if isinstance(val, Iterable):
|
||||
return {i for i in val if i is not None}
|
||||
|
||||
return val or {}
|
||||
|
||||
@define(frozen=True)
|
||||
@total_ordering
|
||||
class Tag:
|
||||
value: str
|
||||
weight: float = field(default=1.0, converter=lambda x: x if x is not None else 1.0)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, data):
|
||||
if isinstance(data, str):
|
||||
return Tag(data)
|
||||
|
||||
if isinstance(data, dict):
|
||||
value = data.get("tag")
|
||||
weight = data.get("weight")
|
||||
if value:
|
||||
return Tag(value, weight)
|
||||
|
||||
return None
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.weight < other.weight and self.value < other.value
|
||||
|
||||
@define
|
||||
class ImageConfig:
|
||||
# Captions
|
||||
main_prompts: set[str] = field(factory=set, converter=safe_set)
|
||||
rating: float = None
|
||||
max_caption_length: int = None
|
||||
tags: set[Tag] = field(factory=set, converter=safe_set)
|
||||
|
||||
# Options
|
||||
multiply: float = None
|
||||
cond_dropout: float = None
|
||||
flip_p: float = None
|
||||
|
||||
def merge(self, other):
|
||||
if other is None:
|
||||
return self
|
||||
|
||||
return ImageConfig(
|
||||
main_prompts=self.main_prompts.union(other.main_prompts),
|
||||
rating=overlay(other.rating, self.rating),
|
||||
max_caption_length=overlay(other.max_caption_length, self.max_caption_length),
|
||||
tags=self.tags.union(other.tags),
|
||||
multiply=overlay(other.multiply, self.multiply),
|
||||
cond_dropout=overlay(other.cond_dropout, self.cond_dropout),
|
||||
flip_p=overlay(other.flip_p, self.flip_p),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict):
|
||||
# Parse standard yaml tag file (with options)
|
||||
parsed_cfg = ImageConfig(
|
||||
main_prompts=safe_set(data.get("main_prompt")),
|
||||
rating=data.get("rating"),
|
||||
max_caption_length=data.get("max_caption_length"),
|
||||
tags=safe_set(map(Tag.parse, data.get("tags", []))),
|
||||
multiply=data.get("multiply"),
|
||||
cond_dropout=data.get("cond_dropout"),
|
||||
flip_p=data.get("flip_p"),
|
||||
)
|
||||
|
||||
# Alternatively parse from dedicated `caption` attribute
|
||||
if cap_attr := data.get('caption'):
|
||||
parsed_cfg = parsed_cfg.merge(ImageConfig.parse(cap_attr))
|
||||
|
||||
return parsed_cfg
|
||||
|
||||
@classmethod
|
||||
def fold(cls, configs):
|
||||
acc = ImageConfig()
|
||||
for cfg in configs:
|
||||
acc = acc.merge(cfg)
|
||||
return acc
|
||||
|
||||
def ensure_caption(self):
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_caption_text(cls, text: str):
|
||||
if not text:
|
||||
return ImageConfig()
|
||||
if os.path.isfile(text):
|
||||
return ImageConfig.from_file(text)
|
||||
|
||||
split_caption = list(map(str.strip, text.split(",")))
|
||||
return ImageConfig(
|
||||
main_prompts=split_caption[0],
|
||||
tags=map(Tag.parse, split_caption[1:])
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, file: str):
|
||||
match ext(file):
|
||||
case '.jpg' | '.jpeg' | '.png' | '.bmp' | '.webp' | '.jfif':
|
||||
return ImageConfig(image=file)
|
||||
case ".json":
|
||||
return ImageConfig.from_dict(json.load(read_text(file)))
|
||||
case ".yaml" | ".yml":
|
||||
return ImageConfig.from_dict(yaml.safe_load(read_text(file)))
|
||||
case ".txt" | ".caption":
|
||||
return ImageConfig.from_caption_text(read_text(file))
|
||||
case _:
|
||||
return logging.warning(" *** Unrecognized config extension {ext}")
|
||||
|
||||
@classmethod
|
||||
def parse(cls, input):
|
||||
if isinstance(input, str):
|
||||
if os.path.isfile(input):
|
||||
return ImageConfig.from_file(input)
|
||||
else:
|
||||
return ImageConfig.from_caption_text(input)
|
||||
elif isinstance(input, dict):
|
||||
return ImageConfig.from_dict(input)
|
||||
|
||||
|
||||
@define()
|
||||
class Dataset:
|
||||
image_configs: dict[str, ImageConfig]
|
||||
|
||||
def __global_cfg(files):
|
||||
cfgs = []
|
||||
for file in files:
|
||||
match os.path.basename(file):
|
||||
case 'global.yaml' | 'global.yml':
|
||||
cfgs.append(ImageConfig.from_file(file))
|
||||
return ImageConfig.fold(cfgs)
|
||||
|
||||
def __local_cfg(files):
|
||||
cfgs = []
|
||||
for file in files:
|
||||
match os.path.basename(file):
|
||||
case 'multiply.txt':
|
||||
cfgs.append(ImageConfig(multiply=read_float(file)))
|
||||
case 'cond_dropout.txt':
|
||||
cfgs.append(ImageConfig(cond_dropout=read_float(file)))
|
||||
case 'flip_p.txt':
|
||||
cfgs.append(ImageConfig(flip_p=read_float(file)))
|
||||
case 'local.yaml' | 'local.yml':
|
||||
cfgs.append(ImageConfig.from_file(file))
|
||||
return ImageConfig.fold(cfgs)
|
||||
|
||||
def __sidecar_cfg(imagepath, files):
|
||||
cfgs = []
|
||||
for file in files:
|
||||
if same_barename(imagepath, file):
|
||||
match ext(file):
|
||||
case '.txt' | '.caption' | '.yml' | '.yaml':
|
||||
cfgs.append(ImageConfig.from_file(file))
|
||||
return ImageConfig.fold(cfgs)
|
||||
|
||||
# Use file name for caption only as a last resort
|
||||
@classmethod
|
||||
def __ensure_caption(cls, cfg: ImageConfig, file: str):
|
||||
if cfg.main_prompts or cfg.tags:
|
||||
return cfg
|
||||
cap_cfg = ImageConfig.from_caption_text(barename(file).split("_")[0])
|
||||
return cfg.merge(cap_cfg)
|
||||
|
||||
@classmethod
|
||||
def from_path(cls, data_root):
|
||||
# Create a visitor that maintains global config stack
|
||||
# and accumulates image configs as it traverses dataset
|
||||
image_configs = {}
|
||||
def process_dir(files, parent_globals):
|
||||
global_cfg = parent_globals.merge(Dataset.__global_cfg(files))
|
||||
local_cfg = Dataset.__local_cfg(files)
|
||||
for img in filter(is_image, files):
|
||||
img_cfg = Dataset.__sidecar_cfg(img, files)
|
||||
resolved_cfg = ImageConfig.fold([global_cfg, local_cfg, img_cfg])
|
||||
image_configs[img] = Dataset.__ensure_caption(resolved_cfg, img)
|
||||
return global_cfg
|
||||
|
||||
walk_and_visit(data_root, process_dir, ImageConfig())
|
||||
return Dataset(image_configs)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_path):
|
||||
"""
|
||||
Import a dataset definition from a JSON file
|
||||
"""
|
||||
image_configs = {}
|
||||
with open(json_path, encoding='utf-8', mode='r') as stream:
|
||||
for data in json.load(stream):
|
||||
img = data.get("image")
|
||||
cfg = Dataset.__ensure_caption(ImageConfig.parse(data), img)
|
||||
if not img:
|
||||
logging.warning(f" *** Error parsing json image entry in {json_path}: {data}")
|
||||
continue
|
||||
image_configs[img] = cfg
|
||||
return Dataset(image_configs)
|
||||
|
||||
def image_train_items(self, aspects):
|
||||
items = []
|
||||
for image in self.image_configs:
|
||||
config = self.image_configs[image]
|
||||
if len(config.main_prompts) > 1:
|
||||
logging.warning(f" *** Found multiple multiple main_prompts for image {image}, but only one will be applied: {config.main_prompts}")
|
||||
|
||||
tags = []
|
||||
tag_weights = []
|
||||
for tag in sorted(config.tags):
|
||||
tags.append(tag.value)
|
||||
tag_weights.append(tag.weight)
|
||||
use_weights = len(set(tag_weights)) > 1
|
||||
|
||||
caption = ImageCaption(
|
||||
main_prompt=next(iter(sorted(config.main_prompts))),
|
||||
rating=config.rating or 1.0,
|
||||
tags=tags,
|
||||
tag_weights=tag_weights,
|
||||
max_target_length=config.max_caption_length,
|
||||
use_weights=use_weights)
|
||||
|
||||
item = ImageTrainItem(
|
||||
image=None,
|
||||
caption=caption,
|
||||
aspects=aspects,
|
||||
pathname=os.path.abspath(image),
|
||||
flip_p=config.flip_p or 0.0,
|
||||
multiplier=config.multiply or 1.0,
|
||||
cond_dropout=config.cond_dropout
|
||||
)
|
||||
items.append(item)
|
||||
return list(sorted(items, key=lambda ti: ti.pathname))
|
|
@ -104,7 +104,7 @@ class EveryDreamBatch(Dataset):
|
|||
|
||||
example["image"] = image_transforms(train_item["image"])
|
||||
|
||||
if random.random() > self.conditional_dropout:
|
||||
if random.random() > (train_item.get("cond_dropout", self.conditional_dropout)):
|
||||
example["tokens"] = self.tokenizer(example["caption"],
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
|
@ -132,6 +132,8 @@ class EveryDreamBatch(Dataset):
|
|||
example["image"] = image_train_tmp.image.copy() # hack for now to avoid memory leak
|
||||
image_train_tmp.image = None # hack for now to avoid memory leak
|
||||
example["caption"] = image_train_tmp.caption
|
||||
if image_train_tmp.cond_dropout is not None:
|
||||
example["cond_dropout"] = image_train_tmp.cond_dropout
|
||||
example["runt_size"] = image_train_tmp.runt_size
|
||||
|
||||
return example
|
||||
|
|
|
@ -37,7 +37,6 @@ class ImageCaption:
|
|||
"""
|
||||
Represents the various parts of an image caption
|
||||
"""
|
||||
|
||||
def __init__(self, main_prompt: str, rating: float, tags: list[str], tag_weights: list[float], max_target_length: int, use_weights: bool):
|
||||
"""
|
||||
:param main_prompt: The part of the caption which should always be included
|
||||
|
@ -114,137 +113,6 @@ class ImageCaption:
|
|||
random.Random(seed).shuffle(tags)
|
||||
return ", ".join(tags)
|
||||
|
||||
@staticmethod
|
||||
def parse(string: str) -> 'ImageCaption':
|
||||
"""
|
||||
Parses a string to get the caption.
|
||||
|
||||
:param string: String to parse.
|
||||
:return: `ImageCaption` object.
|
||||
"""
|
||||
split_caption = list(map(str.strip, string.split(",")))
|
||||
main_prompt = split_caption[0]
|
||||
tags = split_caption[1:]
|
||||
tag_weights = [1.0] * len(tags)
|
||||
|
||||
return ImageCaption(main_prompt, 1.0, tags, tag_weights, DEFAULT_MAX_CAPTION_LENGTH, False)
|
||||
|
||||
@staticmethod
|
||||
def from_file_name(file_path: str) -> 'ImageCaption':
|
||||
"""
|
||||
Parses the file name to get the caption.
|
||||
|
||||
:param file_path: Path to the image file.
|
||||
:return: `ImageCaption` object.
|
||||
"""
|
||||
(file_name, _) = os.path.splitext(os.path.basename(file_path))
|
||||
caption = file_name.split("_")[0]
|
||||
return ImageCaption.parse(caption)
|
||||
|
||||
@staticmethod
|
||||
def from_text_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
|
||||
"""
|
||||
Parses a text file to get the caption. Returns the default caption if
|
||||
the file does not exist or is invalid.
|
||||
|
||||
:param file_path: Path to the text file.
|
||||
:param default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid.
|
||||
:return: `ImageCaption` object or `None`.
|
||||
"""
|
||||
try:
|
||||
with open(file_path, encoding='utf-8', mode='r') as caption_file:
|
||||
caption_text = caption_file.read()
|
||||
return ImageCaption.parse(caption_text)
|
||||
except Exception as e:
|
||||
logging.error(f" *** Error reading {file_path} to get caption {e}")
|
||||
return default_caption
|
||||
|
||||
@staticmethod
|
||||
def from_yaml_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
|
||||
"""
|
||||
Parses a yaml file to get the caption. Returns the default caption if
|
||||
the file does not exist or is invalid.
|
||||
|
||||
:param file_path: path to the yaml file
|
||||
:param default_caption: caption to return if the file does not exist or is invalid
|
||||
:return: `ImageCaption` object or `None`.
|
||||
"""
|
||||
try:
|
||||
with open(file_path, "r") as stream:
|
||||
file_content = yaml.safe_load(stream)
|
||||
main_prompt = file_content.get("main_prompt", "")
|
||||
rating = file_content.get("rating", 1.0)
|
||||
unparsed_tags = file_content.get("tags", [])
|
||||
|
||||
max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH)
|
||||
|
||||
tags = []
|
||||
tag_weights = []
|
||||
last_weight = None
|
||||
weights_differ = False
|
||||
for unparsed_tag in unparsed_tags:
|
||||
tag = unparsed_tag.get("tag", "").strip()
|
||||
if len(tag) == 0:
|
||||
continue
|
||||
|
||||
tags.append(tag)
|
||||
tag_weight = unparsed_tag.get("weight", 1.0)
|
||||
tag_weights.append(tag_weight)
|
||||
|
||||
if last_weight is not None and weights_differ is False:
|
||||
weights_differ = last_weight != tag_weight
|
||||
|
||||
last_weight = tag_weight
|
||||
|
||||
return ImageCaption(main_prompt, rating, tags, tag_weights, max_caption_length, weights_differ)
|
||||
except:
|
||||
logging.error(f" *** Error reading {file_path} to get caption")
|
||||
return default_caption
|
||||
|
||||
@staticmethod
|
||||
def from_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
|
||||
"""
|
||||
Try to resolve a caption from a file path or return `default_caption`.
|
||||
|
||||
:string: The path to the file to parse.
|
||||
:default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid.
|
||||
:return: `ImageCaption` object or `None`.
|
||||
"""
|
||||
if os.path.exists(file_path):
|
||||
(file_path_without_ext, ext) = os.path.splitext(file_path)
|
||||
match ext:
|
||||
case ".yaml" | ".yml":
|
||||
return ImageCaption.from_yaml_file(file_path, default_caption)
|
||||
|
||||
case ".txt" | ".caption":
|
||||
return ImageCaption.from_text_file(file_path, default_caption)
|
||||
|
||||
case '.jpg'| '.jpeg'| '.png'| '.bmp'| '.webp'| '.jfif':
|
||||
for ext in [".yaml", ".yml", ".txt", ".caption"]:
|
||||
file_path = file_path_without_ext + ext
|
||||
image_caption = ImageCaption.from_file(file_path)
|
||||
if image_caption is not None:
|
||||
return image_caption
|
||||
return ImageCaption.from_file_name(file_path)
|
||||
|
||||
case _:
|
||||
return default_caption
|
||||
else:
|
||||
return default_caption
|
||||
|
||||
@staticmethod
|
||||
def resolve(string: str) -> 'ImageCaption':
|
||||
"""
|
||||
Try to resolve a caption from a string. If the string is a file path,
|
||||
the caption will be read from the file, otherwise the string will be
|
||||
parsed as a caption.
|
||||
|
||||
:string: The string to resolve.
|
||||
:return: `ImageCaption` object.
|
||||
"""
|
||||
return ImageCaption.from_file(string, None) or ImageCaption.parse(string)
|
||||
|
||||
|
||||
class ImageTrainItem:
|
||||
"""
|
||||
image: PIL.Image
|
||||
|
@ -254,7 +122,7 @@ class ImageTrainItem:
|
|||
flip_p: probability of flipping image (0.0 to 1.0)
|
||||
rating: the relative rating of the images. The rating is measured in comparison to the other images.
|
||||
"""
|
||||
def __init__(self, image: PIL.Image, caption: ImageCaption, aspects: list[float], pathname: str, flip_p=0.0, multiplier: float=1.0):
|
||||
def __init__(self, image: PIL.Image, caption: ImageCaption, aspects: list[float], pathname: str, flip_p=0.0, multiplier: float=1.0, cond_dropout=None):
|
||||
self.caption = caption
|
||||
self.aspects = aspects
|
||||
self.pathname = pathname
|
||||
|
@ -262,6 +130,7 @@ class ImageTrainItem:
|
|||
self.cropped_img = None
|
||||
self.runt_size = 0
|
||||
self.multiplier = multiplier
|
||||
self.cond_dropout = cond_dropout
|
||||
|
||||
self.image_size = None
|
||||
if image is None or len(image) == 0:
|
||||
|
|
109
data/resolver.py
109
data/resolver.py
|
@ -4,6 +4,7 @@ import os
|
|||
import typing
|
||||
import zipfile
|
||||
import argparse
|
||||
from data.dataset import Dataset
|
||||
|
||||
import tqdm
|
||||
from colorama import Fore, Style
|
||||
|
@ -27,16 +28,6 @@ class DataResolver:
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def image_train_item(self, image_path: str, caption: ImageCaption, multiplier: float=1) -> ImageTrainItem:
|
||||
return ImageTrainItem(
|
||||
image=None,
|
||||
caption=caption,
|
||||
aspects=self.aspects,
|
||||
pathname=image_path,
|
||||
flip_p=self.flip_p,
|
||||
multiplier=multiplier
|
||||
)
|
||||
|
||||
class JSONResolver(DataResolver):
|
||||
def image_train_items(self, json_path: str) -> list[ImageTrainItem]:
|
||||
"""
|
||||
|
@ -45,62 +36,8 @@ class JSONResolver(DataResolver):
|
|||
|
||||
:param json_path: The path to the JSON file.
|
||||
"""
|
||||
items = []
|
||||
with open(json_path, encoding='utf-8', mode='r') as f:
|
||||
json_data = json.load(f)
|
||||
|
||||
for data in tqdm.tqdm(json_data):
|
||||
caption = JSONResolver.image_caption(data)
|
||||
if caption:
|
||||
image_value = JSONResolver.get_image_value(data)
|
||||
item = self.image_train_item(image_value, caption)
|
||||
if item:
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
return Dataset.from_json(json_path).image_train_items(self.aspects)
|
||||
|
||||
@staticmethod
|
||||
def get_image_value(json_data: dict) -> typing.Optional[str]:
|
||||
"""
|
||||
Get the image from the json data if possible.
|
||||
|
||||
:param json_data: The json data, a dict.
|
||||
:return: The image, or None if not found.
|
||||
"""
|
||||
image_value = json_data.get("image", None)
|
||||
if isinstance(image_value, str):
|
||||
image_value = image_value.strip()
|
||||
if os.path.exists(image_value):
|
||||
return image_value
|
||||
|
||||
@staticmethod
|
||||
def get_caption_value(json_data: dict) -> typing.Optional[str]:
|
||||
"""
|
||||
Get the caption from the json data if possible.
|
||||
|
||||
:param json_data: The json data, a dict.
|
||||
:return: The caption, or None if not found.
|
||||
"""
|
||||
caption_value = json_data.get("caption", None)
|
||||
if isinstance(caption_value, str):
|
||||
return caption_value.strip()
|
||||
|
||||
@staticmethod
|
||||
def image_caption(json_data: dict) -> typing.Optional[ImageCaption]:
|
||||
"""
|
||||
Get the caption from the json data if possible.
|
||||
|
||||
:param json_data: The json data, a dict.
|
||||
:return: The `ImageCaption`, or None if not found.
|
||||
"""
|
||||
image_value = JSONResolver.get_image_value(json_data)
|
||||
caption_value = JSONResolver.get_caption_value(json_data)
|
||||
if image_value:
|
||||
if caption_value:
|
||||
return ImageCaption.resolve(caption_value)
|
||||
return ImageCaption.from_file(image_value)
|
||||
|
||||
|
||||
class DirectoryResolver(DataResolver):
|
||||
def image_train_items(self, data_root: str) -> list[ImageTrainItem]:
|
||||
"""
|
||||
|
@ -111,32 +48,7 @@ class DirectoryResolver(DataResolver):
|
|||
:param data_root: The root directory to recurse through
|
||||
"""
|
||||
DirectoryResolver.unzip_all(data_root)
|
||||
image_paths = list(DirectoryResolver.recurse_data_root(data_root))
|
||||
items = []
|
||||
multipliers = {}
|
||||
|
||||
for pathname in tqdm.tqdm(image_paths):
|
||||
current_dir = os.path.dirname(pathname)
|
||||
|
||||
if current_dir not in multipliers:
|
||||
multiply_txt_path = os.path.join(current_dir, "multiply.txt")
|
||||
if os.path.exists(multiply_txt_path):
|
||||
try:
|
||||
with open(multiply_txt_path, 'r') as f:
|
||||
val = float(f.read().strip())
|
||||
multipliers[current_dir] = val
|
||||
logging.info(f" - multiply.txt in '{current_dir}' set to {val}")
|
||||
except Exception as e:
|
||||
logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
|
||||
multipliers[current_dir] = 1.0
|
||||
else:
|
||||
multipliers[current_dir] = 1.0
|
||||
|
||||
caption = ImageCaption.resolve(pathname)
|
||||
item = self.image_train_item(pathname, caption, multiplier=multipliers[current_dir])
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
return Dataset.from_path(data_root).image_train_items(self.aspects)
|
||||
|
||||
@staticmethod
|
||||
def unzip_all(path):
|
||||
|
@ -150,21 +62,6 @@ class DirectoryResolver(DataResolver):
|
|||
except Exception as e:
|
||||
logging.error(f"Error unzipping files {e}")
|
||||
|
||||
@staticmethod
|
||||
def recurse_data_root(recurse_root):
|
||||
for f in os.listdir(recurse_root):
|
||||
current = os.path.join(recurse_root, f)
|
||||
|
||||
if os.path.isfile(current):
|
||||
ext = os.path.splitext(f)[1].lower()
|
||||
if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']:
|
||||
yield current
|
||||
|
||||
for d in os.listdir(recurse_root):
|
||||
current = os.path.join(recurse_root, d)
|
||||
if os.path.isdir(current):
|
||||
yield from DirectoryResolver.recurse_data_root(current)
|
||||
|
||||
def strategy(data_root: str) -> typing.Type[DataResolver]:
|
||||
"""
|
||||
Determine the strategy to use for resolving the data.
|
||||
|
|
|
@ -29,7 +29,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
|||
python3 -m venv ${VIRTUAL_ENV} && \
|
||||
pip install -U -I torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url "https://download.pytorch.org/whl/cu117" && \
|
||||
pip install -r requirements.txt && \
|
||||
pip install --pre --no-deps xformers==0.0.17.dev451
|
||||
pip install --pre -U --no-deps xformers
|
||||
# In case of emergency, build xformers from scratch
|
||||
# export FORCE_CUDA=1 && export TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6" && export CUDA_VISIBLE_DEVICES=0 && \
|
||||
# pip install --no-deps git+https://github.com/facebookresearch/xformers.git@48a77cc#egg=xformers
|
||||
|
|
|
@ -10,6 +10,7 @@ ninja
|
|||
omegaconf==2.2.3
|
||||
piexif==1.1.3
|
||||
protobuf==3.20.3
|
||||
pyfakefs
|
||||
pynvml==11.5.0
|
||||
pyre-extensions==0.0.30
|
||||
pytorch-lightning==1.9.2
|
||||
|
|
|
@ -58,13 +58,13 @@ class TestResolve(unittest.TestCase):
|
|||
|
||||
def test_directory_resolve_with_str(self):
|
||||
items = resolver.resolve(DATA_PATH, ARGS)
|
||||
image_paths = [item.pathname for item in items]
|
||||
image_paths = set(item.pathname for item in items)
|
||||
image_captions = [item.caption for item in items]
|
||||
captions = [caption.get_caption() for caption in image_captions]
|
||||
captions = set(caption.get_caption() for caption in image_captions)
|
||||
|
||||
self.assertEqual(len(items), 3)
|
||||
self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
|
||||
self.assertEqual(captions, ['caption for test1', 'test2', 'test3'])
|
||||
self.assertEqual(image_paths, {IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH})
|
||||
self.assertEqual(captions, {'caption for test1', 'test2', 'test3'})
|
||||
|
||||
undersized_images = list(filter(lambda i: i.is_undersized, items))
|
||||
self.assertEqual(len(undersized_images), 1)
|
||||
|
|
|
@ -0,0 +1,289 @@
|
|||
import os
|
||||
from data.dataset import Dataset, ImageConfig, Tag
|
||||
|
||||
from textwrap import dedent
|
||||
from pyfakefs.fake_filesystem_unittest import TestCase
|
||||
|
||||
class TestDataset(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.maxDiff = None
|
||||
self.setUpPyfakefs()
|
||||
|
||||
def test_a_caption_is_generated_from_image_given_no_other_config(self):
|
||||
self.fs.create_file("image, tag1, tag2.jpg")
|
||||
|
||||
actual = Dataset.from_path(".").image_configs
|
||||
|
||||
expected = {
|
||||
"./image, tag1, tag2.jpg": ImageConfig(main_prompts="image", tags=frozenset([Tag("tag1"), Tag("tag2")]))
|
||||
}
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_several_image_formats_are_supported(self):
|
||||
self.fs.create_file("image.JPG")
|
||||
self.fs.create_file("image.jpeg")
|
||||
self.fs.create_file("image.png")
|
||||
self.fs.create_file("image.webp")
|
||||
self.fs.create_file("image.jfif")
|
||||
self.fs.create_file("image.bmp")
|
||||
|
||||
actual = Dataset.from_path(".").image_configs
|
||||
|
||||
common_cfg = ImageConfig(main_prompts="image")
|
||||
expected = {
|
||||
"./image.JPG": common_cfg,
|
||||
"./image.jpeg": common_cfg,
|
||||
"./image.png": common_cfg,
|
||||
"./image.webp": common_cfg,
|
||||
"./image.jfif": common_cfg,
|
||||
"./image.bmp": common_cfg,
|
||||
}
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_captions_can_be_read_from_txt_or_caption_sidecar(self):
|
||||
self.fs.create_file("image_1.jpg")
|
||||
self.fs.create_file("image_1.txt", contents="an image, test, from .txt")
|
||||
self.fs.create_file("image_2.jpg")
|
||||
self.fs.create_file("image_2.caption", contents="an image, test, from .caption")
|
||||
|
||||
actual = Dataset.from_path(".").image_configs
|
||||
|
||||
expected = {
|
||||
"./image_1.jpg": ImageConfig(main_prompts="an image", tags=frozenset([Tag("test"), Tag("from .txt")])),
|
||||
"./image_2.jpg": ImageConfig(main_prompts="an image", tags=frozenset([Tag("test"), Tag("from .caption")]))
|
||||
}
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
|
||||
def test_captions_and_options_can_be_read_from_yaml_sidecar(self):
|
||||
self.fs.create_file("image_1.jpg")
|
||||
self.fs.create_file("image_1.yaml",
|
||||
contents=dedent("""
|
||||
multiply: 2
|
||||
cond_dropout: 0.05
|
||||
flip_p: 0.5
|
||||
caption: "A simple caption, from .yaml"
|
||||
"""))
|
||||
self.fs.create_file("image_2.jpg")
|
||||
self.fs.create_file("image_2.yml",
|
||||
contents=dedent("""
|
||||
flip_p: 0.0
|
||||
caption:
|
||||
main_prompt: A complex caption
|
||||
rating: 1.1
|
||||
max_caption_length: 1024
|
||||
tags:
|
||||
- tag: from .yml
|
||||
- tag: with weight
|
||||
weight: 0.5
|
||||
"""))
|
||||
|
||||
actual = Dataset.from_path(".").image_configs
|
||||
|
||||
expected = {
|
||||
"./image_1.jpg": ImageConfig(
|
||||
multiply=2,
|
||||
cond_dropout=0.05,
|
||||
flip_p=0.5,
|
||||
main_prompts="A simple caption",
|
||||
tags= { Tag("from .yaml") }
|
||||
),
|
||||
"./image_2.jpg": ImageConfig(
|
||||
flip_p=0.0,
|
||||
rating=1.1,
|
||||
max_caption_length=1024,
|
||||
main_prompts="A complex caption",
|
||||
tags= { Tag("from .yml"), Tag("with weight", weight=0.5) }
|
||||
)
|
||||
}
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
|
||||
def test_multiple_prompts_and_tags_from_multiple_sidecars_are_supported(self):
|
||||
self.fs.create_file("image_1.jpg")
|
||||
self.fs.create_file("image_1.yaml", contents=dedent("""
|
||||
main_prompt:
|
||||
- unique prompt
|
||||
- dupe prompt
|
||||
tags:
|
||||
- from .yaml
|
||||
- dupe tag
|
||||
"""))
|
||||
self.fs.create_file("image_1.txt", contents="also unique prompt, from .txt, dupe tag")
|
||||
self.fs.create_file("image_1.caption", contents="dupe prompt, from .caption")
|
||||
|
||||
actual = Dataset.from_path(".").image_configs
|
||||
|
||||
expected = {
|
||||
"./image_1.jpg": ImageConfig(
|
||||
main_prompts={ "unique prompt", "also unique prompt", "dupe prompt" },
|
||||
tags={ Tag("from .yaml"), Tag("from .txt"), Tag("from .caption"), Tag("dupe tag") }
|
||||
)
|
||||
}
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_sidecars_can_also_be_attached_to_local_and_recursive_folders(self):
|
||||
self.fs.create_file("./global.yaml",
|
||||
contents=dedent("""\
|
||||
main_prompt: global prompt
|
||||
tags:
|
||||
- global tag
|
||||
flip_p: 0.0
|
||||
"""))
|
||||
|
||||
self.fs.create_file("./local.yaml",
|
||||
contents=dedent("""
|
||||
main_prompt: local prompt
|
||||
tags:
|
||||
- tag: local tag
|
||||
"""))
|
||||
|
||||
self.fs.create_file("./arbitrary filename.png")
|
||||
self.fs.create_file("./sub/sub arbitrary filename.png")
|
||||
self.fs.create_file("./sub/sidecar.png")
|
||||
self.fs.create_file("./sub/sidecar.txt",
|
||||
contents="sidecar prompt, sidecar tag")
|
||||
|
||||
self.fs.create_file("./optfile/optfile.png")
|
||||
self.fs.create_file("./optfile/flip_p.txt",
|
||||
contents="0.1234")
|
||||
|
||||
self.fs.create_file("./sub/sub2/global.yaml",
|
||||
contents=dedent("""
|
||||
tags:
|
||||
- tag: sub global tag
|
||||
"""))
|
||||
self.fs.create_file("./sub/sub2/local.yaml",
|
||||
contents=dedent("""
|
||||
tags:
|
||||
- This tag wil not apply to any files
|
||||
"""))
|
||||
self.fs.create_file("./sub/sub2/sub3/xyz.png")
|
||||
|
||||
actual = Dataset.from_path(".").image_configs
|
||||
|
||||
expected = {
|
||||
"./arbitrary filename.png": ImageConfig(
|
||||
main_prompts={ 'global prompt', 'local prompt' },
|
||||
tags={ Tag("global tag"), Tag("local tag") },
|
||||
flip_p=0.0
|
||||
),
|
||||
"./sub/sub arbitrary filename.png": ImageConfig(
|
||||
main_prompts={ 'global prompt' },
|
||||
tags={ Tag("global tag") },
|
||||
flip_p=0.0
|
||||
),
|
||||
"./sub/sidecar.png": ImageConfig(
|
||||
main_prompts={ 'global prompt', 'sidecar prompt' },
|
||||
tags={ Tag("global tag"), Tag("sidecar tag") },
|
||||
flip_p=0.0
|
||||
),
|
||||
"./optfile/optfile.png": ImageConfig(
|
||||
main_prompts={ 'global prompt' },
|
||||
tags={ Tag("global tag") },
|
||||
flip_p=0.1234
|
||||
),
|
||||
"./sub/sub2/sub3/xyz.png": ImageConfig(
|
||||
main_prompts={ 'global prompt' },
|
||||
tags={ Tag("global tag"), Tag("sub global tag") },
|
||||
flip_p=0.0
|
||||
)
|
||||
}
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_can_load_dataset_from_json_manifest(self):
|
||||
self.fs.create_file("./stuff/image_1.jpg")
|
||||
self.fs.create_file("./stuff/default.caption", contents= "default caption")
|
||||
self.fs.create_file("./other/image_1.jpg")
|
||||
self.fs.create_file("./other/image_2.jpg")
|
||||
self.fs.create_file("./other/image_3.jpg")
|
||||
self.fs.create_file("./manifest.json", contents=dedent("""
|
||||
[
|
||||
{ "image": "./stuff/image_1.jpg", "caption": "./stuff/default.caption" },
|
||||
{ "image": "./other/image_1.jpg", "caption": "other caption" },
|
||||
{
|
||||
"image": "./other/image_2.jpg",
|
||||
"caption": {
|
||||
"main_prompt": "complex caption",
|
||||
"rating": 0.1,
|
||||
"max_caption_length": 1000,
|
||||
"tags": [
|
||||
{"tag": "including"},
|
||||
{"tag": "weighted tag", "weight": 999.9}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"image": "./other/image_3.jpg",
|
||||
"multiply": 2,
|
||||
"flip_p": 0.5,
|
||||
"cond_dropout": 0.01,
|
||||
"main_prompt": [
|
||||
"first caption",
|
||||
"second caption"
|
||||
]
|
||||
}
|
||||
]
|
||||
"""))
|
||||
|
||||
actual = Dataset.from_json("./manifest.json").image_configs
|
||||
expected = {
|
||||
"./stuff/image_1.jpg": ImageConfig( main_prompts={"default caption"} ),
|
||||
"./other/image_1.jpg": ImageConfig( main_prompts={"other caption"} ),
|
||||
"./other/image_2.jpg": ImageConfig(
|
||||
main_prompts={ "complex caption" },
|
||||
rating=0.1,
|
||||
max_caption_length=1000,
|
||||
tags={
|
||||
Tag("including"),
|
||||
Tag("weighted tag", 999.9)
|
||||
}
|
||||
),
|
||||
"./other/image_3.jpg": ImageConfig(
|
||||
main_prompts={ "first caption", "second caption" },
|
||||
multiply=2,
|
||||
flip_p=0.5,
|
||||
cond_dropout=0.01
|
||||
)
|
||||
}
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_dataset_can_produce_train_items(self):
|
||||
dataset = Dataset({
|
||||
"1.jpg": ImageConfig(
|
||||
multiply=2,
|
||||
flip_p=0.1,
|
||||
cond_dropout=0.01,
|
||||
main_prompts=["first caption","second caption"],
|
||||
rating = 1.1,
|
||||
max_caption_length=1024,
|
||||
tags=frozenset([
|
||||
Tag("tag"),
|
||||
Tag("tag_2", 2.0)
|
||||
])),
|
||||
"2.jpg": ImageConfig( main_prompts="single caption")
|
||||
})
|
||||
|
||||
aspects = []
|
||||
actual = dataset.image_train_items(aspects)
|
||||
|
||||
self.assertEqual(len(actual), 2)
|
||||
|
||||
self.assertEqual(actual[0].pathname, os.path.abspath('1.jpg'))
|
||||
self.assertEqual(actual[0].multiplier, 2.0)
|
||||
self.assertEqual(actual[0].flip.p, 0.1)
|
||||
self.assertEqual(actual[0].cond_dropout, 0.01)
|
||||
self.assertEqual(actual[0].caption.rating(), 1.1)
|
||||
self.assertEqual(actual[0].caption.get_caption(), "first caption, tag, tag_2")
|
||||
# Can't test this
|
||||
# self.assertTrue(actual[0].caption.__use_weights)
|
||||
|
||||
self.assertEqual(actual[1].pathname, os.path.abspath('2.jpg'))
|
||||
self.assertEqual(actual[1].multiplier, 1.0)
|
||||
self.assertEqual(actual[1].flip.p, 0.0)
|
||||
self.assertIsNone(actual[1].cond_dropout)
|
||||
self.assertEqual(actual[1].caption.rating(), 1.0)
|
||||
self.assertEqual(actual[1].caption.get_caption(), "single caption")
|
||||
# Can't test this
|
||||
# self.assertFalse(actual[1].caption.__use_weights)
|
|
@ -32,40 +32,4 @@ class TestImageCaption(unittest.TestCase):
|
|||
self.assertEqual(caption.get_caption(), "hello world, one, two, three")
|
||||
|
||||
caption = ImageCaption("hello world", 1.0, [], [], 2048, False)
|
||||
self.assertEqual(caption.get_caption(), "hello world")
|
||||
|
||||
def test_parse(self):
|
||||
caption = ImageCaption.parse("hello world, one, two, three")
|
||||
|
||||
self.assertEqual(caption.get_caption(), "hello world, one, two, three")
|
||||
|
||||
def test_from_file_name(self):
|
||||
caption = ImageCaption.from_file_name("foo bar_1_2_3.jpg")
|
||||
self.assertEqual(caption.get_caption(), "foo bar")
|
||||
|
||||
def test_from_text_file(self):
|
||||
caption = ImageCaption.from_text_file("test/data/test1.txt")
|
||||
self.assertEqual(caption.get_caption(), "caption for test1")
|
||||
|
||||
def test_from_file(self):
|
||||
caption = ImageCaption.from_file("test/data/test1.txt")
|
||||
self.assertEqual(caption.get_caption(), "caption for test1")
|
||||
|
||||
caption = ImageCaption.from_file("test/data/test_caption.caption")
|
||||
self.assertEqual(caption.get_caption(), "caption for test2")
|
||||
|
||||
def test_resolve(self):
|
||||
caption = ImageCaption.resolve("test/data/test1.txt")
|
||||
self.assertEqual(caption.get_caption(), "caption for test1")
|
||||
|
||||
caption = ImageCaption.resolve("test/data/test_caption.caption")
|
||||
self.assertEqual(caption.get_caption(), "caption for test2")
|
||||
|
||||
caption = ImageCaption.resolve("hello world")
|
||||
self.assertEqual(caption.get_caption(), "hello world")
|
||||
|
||||
caption = ImageCaption.resolve("test/data/test1.jpg")
|
||||
self.assertEqual(caption.get_caption(), "caption for test1")
|
||||
|
||||
caption = ImageCaption.resolve("test/data/test2.jpg")
|
||||
self.assertEqual(caption.get_caption(), "test2")
|
||||
self.assertEqual(caption.get_caption(), "hello world")
|
|
@ -0,0 +1,46 @@
|
|||
|
||||
def barename(file):
|
||||
(val, _) = os.path.splitext(os.path.basename(file))
|
||||
return val
|
||||
|
||||
def ext(file):
|
||||
(_, val) = os.path.splitext(os.path.basename(file))
|
||||
return val.lower()
|
||||
|
||||
def same_barename(lhs, rhs):
|
||||
return barename(lhs) == barename(rhs)
|
||||
|
||||
def is_image(file):
|
||||
return ext(file) in {'.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif'}
|
||||
|
||||
def read_text(file):
|
||||
try:
|
||||
with open(file, encoding='utf-8', mode='r') as stream:
|
||||
return stream.read().strip()
|
||||
except Exception as e:
|
||||
logging.warning(f" *** Error reading text file {file}: {e}")
|
||||
|
||||
def read_float(file):
|
||||
try:
|
||||
return float(read_text(file))
|
||||
except Exception as e:
|
||||
logging.warning(f" *** Could not parse '{data}' to float in file {file}: {e}")
|
||||
|
||||
import os
|
||||
|
||||
def walk_and_visit(path, visit_fn, context=None):
|
||||
names = [entry.name for entry in os.scandir(path)]
|
||||
|
||||
dirs = []
|
||||
files = []
|
||||
for name in names:
|
||||
fullname = os.path.join(path, name)
|
||||
if os.path.isdir(fullname) and not str(name).startswith('.'):
|
||||
dirs.append(fullname)
|
||||
else:
|
||||
files.append(fullname)
|
||||
|
||||
subcontext = visit_fn(files, context)
|
||||
|
||||
for subdir in dirs:
|
||||
walk_and_visit(subdir, visit_fn, subcontext)
|
Loading…
Reference in New Issue