Merge pull request #7 from JanGerritsen/yaml_caption_files
Support more control regarding caption tag shuffling using YAML files
This commit is contained in:
commit
99d8c6bc32
|
@ -34,3 +34,5 @@ Behind the scenes look at how the trainer handles multiaspect and crop jitter
|
||||||
[Advanced Tweaking](doc/ATWEAKING.md)
|
[Advanced Tweaking](doc/ATWEAKING.md)
|
||||||
|
|
||||||
[Chaining training sessions](doc/CHAINING.md)
|
[Chaining training sessions](doc/CHAINING.md)
|
||||||
|
|
||||||
|
[Shuffling Tags](doc/SHUFFLING_TAGS.md)
|
||||||
|
|
|
@ -16,9 +16,11 @@ limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import yaml
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import random
|
import random
|
||||||
from data.image_train_item import ImageTrainItem
|
from data.image_train_item import ImageTrainItem, ImageCaption
|
||||||
import data.aspects as aspects
|
import data.aspects as aspects
|
||||||
from colorama import Fore, Style
|
from colorama import Fore, Style
|
||||||
import zipfile
|
import zipfile
|
||||||
|
@ -27,6 +29,8 @@ import PIL
|
||||||
|
|
||||||
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
|
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
|
||||||
|
|
||||||
|
DEFAULT_MAX_CAPTION_LENGTH = 2048
|
||||||
|
|
||||||
class DataLoaderMultiAspect():
|
class DataLoaderMultiAspect():
|
||||||
"""
|
"""
|
||||||
Data loader for multi-aspect-ratio training and bucketing
|
Data loader for multi-aspect-ratio training and bucketing
|
||||||
|
@ -76,17 +80,64 @@ class DataLoaderMultiAspect():
|
||||||
return self.image_caption_pairs
|
return self.image_caption_pairs
|
||||||
|
|
||||||
@staticmethod
def __read_caption_from_file(file_path, fallback_caption: ImageCaption) -> ImageCaption:
    """
    Reads a plain-text caption file (.txt or .caption) and parses it into an ImageCaption.

    :param file_path: path of the caption file to read
    :param fallback_caption: caption to return when the file cannot be read
    :return: the parsed ImageCaption, or fallback_caption on read failure
    """
    try:
        with open(file_path, encoding='utf-8', mode='r') as caption_file:
            caption_text = caption_file.read()
    except (OSError, UnicodeError):
        # Was a bare "except:", which would also swallow KeyboardInterrupt and
        # SystemExit; only I/O and decoding problems should trigger the fallback.
        logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
        return fallback_caption
    return DataLoaderMultiAspect.__split_caption_into_tags(caption_text)
|
||||||
|
|
||||||
|
@staticmethod
def __read_caption_from_yaml(file_path: str, fallback_caption: ImageCaption) -> ImageCaption:
    """
    Reads a YAML caption definition (main_prompt, weighted tags, max_caption_length)
    and builds an ImageCaption from it.

    :param file_path: path of the .yaml caption file
    :param fallback_caption: caption to return when the file cannot be read or parsed
    :return: the parsed ImageCaption, or fallback_caption on failure
    """
    try:
        # open() must be inside the try: previously an unreadable file raised
        # straight through and aborted the whole image scan instead of
        # falling back to the filename caption.
        with open(file_path, "r") as stream:
            file_content = yaml.safe_load(stream)

        main_prompt = file_content.get("main_prompt", "")
        unparsed_tags = file_content.get("tags", [])
        max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH)

        tags = []
        tag_weights = []
        last_weight = None
        weights_differ = False
        for unparsed_tag in unparsed_tags:
            tag = unparsed_tag.get("tag", "").strip()
            if not tag:
                continue

            tags.append(tag)
            tag_weight = unparsed_tag.get("weight", 1.0)
            tag_weights.append(tag_weight)

            # Weighted shuffling is only worthwhile once at least two tags
            # carry different weights.
            if last_weight is not None and not weights_differ:
                weights_differ = last_weight != tag_weight
            last_weight = tag_weight

        return ImageCaption(main_prompt, tags, tag_weights, max_caption_length, weights_differ)
    except Exception:
        # yaml.safe_load raises YAMLError, file access raises OSError, and a
        # malformed document (e.g. a scalar instead of a mapping) raises
        # AttributeError/TypeError on .get(); fall back in every case rather
        # than abort the scan. Narrowed from a bare "except:".
        logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
        return fallback_caption
|
||||||
|
|
||||||
|
@staticmethod
def __split_caption_into_tags(caption_string: str) -> ImageCaption:
    """
    Splits a string by "," into the main prompt and additional tags with equal weights
    """
    main_prompt, *raw_tags = caption_string.split(",")
    tags = [raw_tag.strip() for raw_tag in raw_tags]
    # Every tag gets the default weight 1.0 and weighted shuffling stays off.
    return ImageCaption(main_prompt.strip(), tags, [1.0] * len(tags), DEFAULT_MAX_CAPTION_LENGTH, False)
|
||||||
|
|
||||||
def __prescan_images(self, image_paths: list, flip_p=0.0):
|
def __prescan_images(self, image_paths: list, flip_p=0.0):
|
||||||
"""
|
"""
|
||||||
Create ImageTrainItem objects with metadata for hydration later
|
Create ImageTrainItem objects with metadata for hydration later
|
||||||
|
@ -95,16 +146,19 @@ class DataLoaderMultiAspect():
|
||||||
|
|
||||||
for pathname in tqdm.tqdm(image_paths):
|
for pathname in tqdm.tqdm(image_paths):
|
||||||
caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
|
caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
|
||||||
|
caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_from_filename)
|
||||||
|
|
||||||
txt_file_path = os.path.splitext(pathname)[0] + ".txt"
|
file_path_without_ext = os.path.splitext(pathname)[0]
|
||||||
caption_file_path = os.path.splitext(pathname)[0] + ".caption"
|
yaml_file_path = file_path_without_ext + ".yaml"
|
||||||
|
txt_file_path = file_path_without_ext + ".txt"
|
||||||
|
caption_file_path = file_path_without_ext + ".caption"
|
||||||
|
|
||||||
if os.path.exists(txt_file_path):
|
if os.path.exists(yaml_file_path):
|
||||||
caption = self.__read_caption_from_file(txt_file_path, caption_from_filename)
|
caption = self.__read_caption_from_yaml(yaml_file_path, caption)
|
||||||
|
elif os.path.exists(txt_file_path):
|
||||||
|
caption = self.__read_caption_from_file(txt_file_path, caption)
|
||||||
elif os.path.exists(caption_file_path):
|
elif os.path.exists(caption_file_path):
|
||||||
caption = self.__read_caption_from_file(caption_file_path, caption_from_filename)
|
caption = self.__read_caption_from_file(caption_file_path, caption)
|
||||||
else:
|
|
||||||
caption = caption_from_filename
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
image = Image.open(pathname)
|
image = Image.open(pathname)
|
||||||
|
|
|
@ -103,6 +103,7 @@ class EveryDreamBatch(Dataset):
|
||||||
return dls.shared_dataloader.runts
|
return dls.shared_dataloader.runts
|
||||||
|
|
||||||
def shuffle(self, epoch_n):
    # Advance the seed so each epoch produces a different tag shuffle order
    # (the seed is consumed by ImageCaption.get_shuffled_caption at item time).
    # NOTE(review): epoch_n is unused here; presumably kept for interface
    # compatibility with callers — TODO confirm.
    self.seed += 1
    if dls.shared_dataloader:
        dls.shared_dataloader.shuffle()
        self.image_train_items = dls.shared_dataloader.get_all_images()
|
||||||
|
@ -134,16 +135,15 @@ class EveryDreamBatch(Dataset):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.shuffle_tags and "," in train_item['caption']:
|
if self.shuffle_tags:
|
||||||
tags = train_item["caption"].split(",")
|
example["caption"] = train_item["caption"].get_shuffled_caption(self.seed)
|
||||||
random.Random(self.seed).shuffle(tags)
|
else:
|
||||||
self.seed += 1
|
example["caption"] = train_item["caption"].get_caption()
|
||||||
train_item["caption"] = ", ".join(tags)
|
|
||||||
|
|
||||||
example["image"] = image_transforms(train_item["image"])
|
example["image"] = image_transforms(train_item["image"])
|
||||||
|
|
||||||
if random.random() > self.conditional_dropout:
|
if random.random() > self.conditional_dropout:
|
||||||
example["tokens"] = self.tokenizer(train_item["caption"],
|
example["tokens"] = self.tokenizer(example["caption"],
|
||||||
truncation=True,
|
truncation=True,
|
||||||
padding="max_length",
|
padding="max_length",
|
||||||
max_length=self.tokenizer.model_max_length,
|
max_length=self.tokenizer.model_max_length,
|
||||||
|
@ -156,7 +156,7 @@ class EveryDreamBatch(Dataset):
|
||||||
).input_ids
|
).input_ids
|
||||||
|
|
||||||
example["tokens"] = torch.tensor(example["tokens"])
|
example["tokens"] = torch.tensor(example["tokens"])
|
||||||
example["caption"] = train_item["caption"] # for sampling if needed
|
|
||||||
example["runt_size"] = train_item["runt_size"]
|
example["runt_size"] = train_item["runt_size"]
|
||||||
|
|
||||||
return example
|
return example
|
||||||
|
|
|
@ -13,16 +13,90 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
import PIL
|
import bisect
|
||||||
import numpy as np
|
import logging
|
||||||
from torchvision import transforms, utils
|
|
||||||
import random
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import logging
|
import random
|
||||||
|
|
||||||
|
import PIL
|
||||||
|
import numpy as np
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
_RANDOM_TRIM = 0.04
|
_RANDOM_TRIM = 0.04
|
||||||
|
|
||||||
|
|
||||||
|
class ImageCaption:
    """
    Represents the various parts of an image caption
    """

    def __init__(self, main_prompt: str, tags: list[str], tag_weights: list[float], max_target_length: int, use_weights: bool):
        """
        :param main_prompt: The part of the caption which should always be included
        :param tags: list of tags to pick from to fill the caption
        :param tag_weights: weights to indicate which tags are more desired and should be picked preferably
        :param max_target_length: The desired maximum length of a generated caption
        :param use_weights: if true, weights are considered when shuffling tags
        """
        self.__main_prompt = main_prompt
        # Copy both lists: the previous code aliased the caller's lists, so
        # padding the weights below (and shuffling later) mutated caller data.
        self.__tags = list(tags)
        self.__tag_weights = list(tag_weights)
        self.__max_target_length = max_target_length
        self.__use_weights = use_weights

        if use_weights and len(self.__tags) > len(self.__tag_weights):
            # Pad missing weights with the default of 1.0.
            self.__tag_weights.extend([1.0] * (len(self.__tags) - len(self.__tag_weights)))

        if use_weights and len(self.__tag_weights) > len(self.__tags):
            # Drop surplus weights so both lists stay in lockstep.
            self.__tag_weights = self.__tag_weights[:len(self.__tags)]

    def get_shuffled_caption(self, seed: int) -> str:
        """
        returns the caption as a string with a random selection of the tags in random order

        :param seed: used to initialize the randomizer
        :return: generated caption string
        """
        max_target_tag_length = self.__max_target_length - len(self.__main_prompt)

        if self.__use_weights:
            tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length)
        else:
            tags_caption = self.__get_shuffled_tags(seed, self.__tags)

        return self.__join_parts(tags_caption)

    def get_caption(self) -> str:
        """Returns the full caption (main prompt plus all tags) in their stored order."""
        return self.__join_parts(", ".join(self.__tags))

    def __join_parts(self, tags_caption: str) -> str:
        # Avoid a dangling ", " when either the main prompt or the tag list is empty.
        if not tags_caption:
            return self.__main_prompt
        if not self.__main_prompt:
            return tags_caption
        return self.__main_prompt + ", " + tags_caption

    @staticmethod
    def __get_weighted_shuffled_tags(seed: int, tags: list[str], weights: list[float], max_target_tag_length: int) -> str:
        picker = random.Random(seed)
        tags_copy = tags.copy()
        weights_copy = weights.copy()

        chosen = []
        caption_length = 0
        while tags_copy and caption_length < max_target_tag_length:
            # Weighted pick without replacement: sample a point on the
            # cumulative-weight line and locate the owning tag via bisect.
            cum_weights = []
            weight_sum = 0.0
            for weight in weights_copy:
                weight_sum += weight
                cum_weights.append(weight_sum)

            point = picker.uniform(0, weight_sum)
            pos = bisect.bisect_left(cum_weights, point)

            weights_copy.pop(pos)
            tag = tags_copy.pop(pos)
            chosen.append(tag)
            caption_length += len(tag) + 2  # +2 accounts for the ", " separator

        # Join cleanly: the previous code prefixed every tag with ", ", which
        # produced a double comma right after the main prompt.
        return ", ".join(chosen)

    @staticmethod
    def __get_shuffled_tags(seed: int, tags: list[str]) -> str:
        # Shuffle a copy: the previous code shuffled the stored tag list in
        # place, silently reordering the caption between epochs.
        shuffled = tags.copy()
        random.Random(seed).shuffle(shuffled)
        return ", ".join(shuffled)
|
||||||
|
|
||||||
|
|
||||||
class ImageTrainItem():
|
class ImageTrainItem():
|
||||||
"""
|
"""
|
||||||
image: PIL.Image
|
image: PIL.Image
|
||||||
|
@ -31,7 +105,8 @@ class ImageTrainItem():
|
||||||
pathname: path to image file
|
pathname: path to image file
|
||||||
flip_p: probability of flipping image (0.0 to 1.0)
|
flip_p: probability of flipping image (0.0 to 1.0)
|
||||||
"""
|
"""
|
||||||
def __init__(self, image: PIL.Image, caption: str, target_wh: list, pathname: str, flip_p=0.0):
|
|
||||||
|
def __init__(self, image: PIL.Image, caption: ImageCaption, target_wh: list, pathname: str, flip_p=0.0):
|
||||||
self.caption = caption
|
self.caption = caption
|
||||||
self.target_wh = target_wh
|
self.target_wh = target_wh
|
||||||
self.pathname = pathname
|
self.pathname = pathname
|
||||||
|
|
|
@ -0,0 +1,67 @@
|
||||||
|
# Shuffling tags randomly during training
|
||||||
|
|
||||||
|
## General shuffling
|
||||||
|
|
||||||
|
To help the model generalize better, EveryDream has an option to shuffle tags during the training.
|
||||||
|
|
||||||
|
This behavior can be activated using the parameter _--shuffle_tags_. The default is off.
|
||||||
|
|
||||||
|
The provided caption, extracted either from the file name or the provided caption file,
|
||||||
|
will be split at each "_,_" into separate chunks.
|
||||||
|
|
||||||
|
The first chunk will always be included in the caption provided during the training,
|
||||||
|
the additional chunks are shuffled into a random order.
|
||||||
|
|
||||||
|
Each epoch the order is reshuffled. _(Remember that each image is shown once per epoch to the model)_
|
||||||
|
|
||||||
|
|
||||||
|
## Weighted shuffling
|
||||||
|
|
||||||
|
EveryDream can read caption definitions from YAML files, for fine-tuned definitions.
|
||||||
|
|
||||||
|
EveryDream will check for each image if a file with the same name and the extension _.yaml_ is provided.
|
||||||
|
|
||||||
|
The expected format of the YAML file:
|
||||||
|
````yaml
|
||||||
|
main_prompt: A portrait of Cloud Strife
|
||||||
|
tags:
|
||||||
|
- tag: low angle shot
|
||||||
|
- tag: looking to the side
|
||||||
|
- tag: holding buster sword
|
||||||
|
weight: 1.5
|
||||||
|
- tag: clouds in background
|
||||||
|
weight: 0.5
|
||||||
|
- tag: smiling
|
||||||
|
weight: 0.8
|
||||||
|
max_caption_length: 1024
|
||||||
|
````
|
||||||
|
|
||||||
|
The main prompt will always be the first part included in the caption.
|
||||||
|
The main prompt is optional, you can provide none if you do not want a fixed part at the beginning of the caption.
|
||||||
|
|
||||||
|
This is followed by a list of tags. The tags will be shuffled into a random order and added to the caption.
|
||||||
|
The tags list is optional.
|
||||||
|
|
||||||
|
The default weight of each tag is _1.0_. A different weight can be optionally specified.
|
||||||
|
Tags with a higher weight have a higher chance to appear in the front of the caption tag list.
|
||||||
|
|
||||||
|
The optional parameter _max_caption_length_ allows the definition of a maximum length of the assembled caption.
|
||||||
|
Only whole tags will be processed. If the addition of the next tag exceeds the _max_caption_length_,
|
||||||
|
it will not be added, and the caption will be provided without the other tags for this epoch.
|
||||||
|
|
||||||
|
This can be used to train the model that an image can include a certain aspect, even if it is not
|
||||||
|
explicitly mentioned in the caption.
|
||||||
|
|
||||||
|
|
||||||
|
## General notes regarding token length
|
||||||
|
|
||||||
|
For SD, the current implementation of EveryDream can only process the first 75 tokens
|
||||||
|
provided in the caption during training.
|
||||||
|
|
||||||
|
This is a base limitation of the SD models. Workarounds exist to extend this number but are currently not
|
||||||
|
implemented in EveryDream.
|
||||||
|
|
||||||
|
The effect of the limit is that the caption will always be truncated when the maximum number of tokens is
|
||||||
|
exceeded. This process does not consider if the cutoff is in the middle of a tag or even in the middle of a
|
||||||
|
word if it is translated into several tokens.
|
||||||
|
|
Loading…
Reference in New Issue