Merge pull request #7 from JanGerritsen/yaml_caption_files

Support more control over caption tag shuffling using YAML files
This commit is contained in:
Victor Hall 2023-01-09 13:58:14 -08:00 committed by GitHub
commit 99d8c6bc32
5 changed files with 257 additions and 59 deletions

View File

@@ -34,3 +34,5 @@ Behind the scenes look at how the trainer handles multiaspect and crop jitter
[Advanced Tweaking](doc/ATWEAKING.md)
[Chaining training sessions](doc/CHAINING.md)
[Shuffling Tags](doc/SHUFFLING_TAGS.md)

View File

@@ -16,9 +16,11 @@ limitations under the License.
import os
import logging
import yaml
from PIL import Image
import random
from data.image_train_item import ImageTrainItem
from data.image_train_item import ImageTrainItem, ImageCaption
import data.aspects as aspects
from colorama import Fore, Style
import zipfile
@@ -27,6 +29,8 @@ import PIL
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
DEFAULT_MAX_CAPTION_LENGTH = 2048
class DataLoaderMultiAspect():
"""
Data loader for multi-aspect-ratio training and bucketing
@@ -54,7 +58,7 @@ class DataLoaderMultiAspect():
random.Random(seed).shuffle(self.image_paths)
self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p) # ImageTrainItem[]
self.image_caption_pairs = self.__bucketize_images(self.prepared_train_data, batch_size=batch_size, debug_level=debug_level)
def shuffle(self):
self.runts = []
self.seed = self.seed + 1
@@ -76,17 +80,64 @@ class DataLoaderMultiAspect():
return self.image_caption_pairs
@staticmethod
def __read_caption_from_file(file_path, fallback_caption):
caption = fallback_caption
def __read_caption_from_file(file_path, fallback_caption: ImageCaption) -> ImageCaption:
try:
with open(file_path, encoding='utf-8', mode='r') as caption_file:
caption = caption_file.read()
caption_text = caption_file.read()
caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_text)
except:
logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
caption = fallback_caption
pass
return caption
@staticmethod
def __read_caption_from_yaml(file_path: str, fallback_caption: ImageCaption) -> ImageCaption:
with open(file_path, "r") as stream:
try:
file_content = yaml.safe_load(stream)
main_prompt = file_content.get("main_prompt", "")
unparsed_tags = file_content.get("tags", [])
max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH)
tags = []
tag_weights = []
last_weight = None
weights_differ = False
for unparsed_tag in unparsed_tags:
tag = unparsed_tag.get("tag", "").strip()
if len(tag) == 0:
continue
tags.append(tag)
tag_weight = unparsed_tag.get("weight", 1.0)
tag_weights.append(tag_weight)
if last_weight is not None and weights_differ is False:
weights_differ = last_weight != tag_weight
last_weight = tag_weight
return ImageCaption(main_prompt, tags, tag_weights, max_caption_length, weights_differ)
except:
logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
return fallback_caption
@staticmethod
def __split_caption_into_tags(caption_string: str) -> ImageCaption:
"""
Splits a string by "," into the main prompt and additional tags with equal weights
"""
split_caption = caption_string.split(",")
main_prompt = split_caption.pop(0).strip()
tags = []
for tag in split_caption:
tags.append(tag.strip())
return ImageCaption(main_prompt, tags, [1.0] * len(tags), DEFAULT_MAX_CAPTION_LENGTH, False)
def __prescan_images(self, image_paths: list, flip_p=0.0):
"""
Create ImageTrainItem objects with metadata for hydration later
@@ -95,16 +146,19 @@ class DataLoaderMultiAspect():
for pathname in tqdm.tqdm(image_paths):
caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_from_filename)
txt_file_path = os.path.splitext(pathname)[0] + ".txt"
caption_file_path = os.path.splitext(pathname)[0] + ".caption"
file_path_without_ext = os.path.splitext(pathname)[0]
yaml_file_path = file_path_without_ext + ".yaml"
txt_file_path = file_path_without_ext + ".txt"
caption_file_path = file_path_without_ext + ".caption"
if os.path.exists(txt_file_path):
caption = self.__read_caption_from_file(txt_file_path, caption_from_filename)
if os.path.exists(yaml_file_path):
caption = self.__read_caption_from_yaml(yaml_file_path, caption)
elif os.path.exists(txt_file_path):
caption = self.__read_caption_from_file(txt_file_path, caption)
elif os.path.exists(caption_file_path):
caption = self.__read_caption_from_file(caption_file_path, caption_from_filename)
else:
caption = caption_from_filename
caption = self.__read_caption_from_file(caption_file_path, caption)
try:
image = Image.open(pathname)
@@ -165,7 +219,7 @@ class DataLoaderMultiAspect():
multiply = 1
multiply_path = os.path.join(recurse_root, "multiply.txt")
if os.path.exists(multiply_path):
try:
try:
with open(multiply_path, encoding='utf-8', mode='r') as f:
multiply = int(float(f.read().strip()))
logging.info(f" * DLMA multiply.txt in {recurse_root} set to {multiply}")

View File

@@ -103,6 +103,7 @@ class EveryDreamBatch(Dataset):
return dls.shared_dataloader.runts
def shuffle(self, epoch_n):
self.seed += 1
if dls.shared_dataloader:
dls.shared_dataloader.shuffle()
self.image_train_items = dls.shared_dataloader.get_all_images()
@@ -134,16 +135,15 @@ class EveryDreamBatch(Dataset):
]
)
if self.shuffle_tags and "," in train_item['caption']:
tags = train_item["caption"].split(",")
random.Random(self.seed).shuffle(tags)
self.seed += 1
train_item["caption"] = ", ".join(tags)
if self.shuffle_tags:
example["caption"] = train_item["caption"].get_shuffled_caption(self.seed)
else:
example["caption"] = train_item["caption"].get_caption()
example["image"] = image_transforms(train_item["image"])
if random.random() > self.conditional_dropout:
example["tokens"] = self.tokenizer(train_item["caption"],
example["tokens"] = self.tokenizer(example["caption"],
truncation=True,
padding="max_length",
max_length=self.tokenizer.model_max_length,
@@ -156,7 +156,7 @@ class EveryDreamBatch(Dataset):
).input_ids
example["tokens"] = torch.tensor(example["tokens"])
example["caption"] = train_item["caption"] # for sampling if needed
example["runt_size"] = train_item["runt_size"]
return example

View File

@@ -13,25 +13,100 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import PIL
import numpy as np
from torchvision import transforms, utils
import random
import bisect
import logging
import math
import os
import logging
import random
import PIL
import numpy as np
from torchvision import transforms
_RANDOM_TRIM = 0.04
class ImageTrainItem():
class ImageCaption:
"""
Represents the various parts of an image caption
"""
def __init__(self, main_prompt: str, tags: list[str], tag_weights: list[float], max_target_length: int, use_weights: bool):
"""
:param main_prompt: The part of the caption which should always be included
:param tags: list of tags to pick from to fill the caption
:param tag_weights: weights to indicate which tags are more desired and should be picked preferentially
:param max_target_length: The desired maximum length of a generated caption
:param use_weights: if true, weights are considered when shuffling tags
"""
self.__main_prompt = main_prompt
self.__tags = tags
self.__tag_weights = tag_weights
self.__max_target_length = max_target_length
self.__use_weights = use_weights
if use_weights and len(tags) > len(tag_weights):
self.__tag_weights.extend([1.0] * (len(tags) - len(tag_weights)))
if use_weights and len(tag_weights) > len(tags):
self.__tag_weights = tag_weights[:len(tags)]
def get_shuffled_caption(self, seed: int) -> str:
"""
returns the caption as a string with a random selection of the tags in random order
:param seed: seed used to initialize the randomizer
:return: generated caption string
"""
max_target_tag_length = self.__max_target_length - len(self.__main_prompt)
if self.__use_weights:
tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length)
else:
tags_caption = self.__get_shuffled_tags(seed, self.__tags)
return self.__main_prompt + ", " + tags_caption
def get_caption(self) -> str:
return self.__main_prompt + ", " + ", ".join(self.__tags)
@staticmethod
def __get_weighted_shuffled_tags(seed: int, tags: list[str], weights: list[float], max_target_tag_length: int) -> str:
picker = random.Random(seed)
tags_copy = tags.copy()
weights_copy = weights.copy()
caption = ""
while len(tags_copy) != 0 and len(caption) < max_target_tag_length:
cum_weights = []
weight_sum = 0.0
for weight in weights_copy:
weight_sum += weight
cum_weights.append(weight_sum)
point = picker.uniform(0, weight_sum)
pos = bisect.bisect_left(cum_weights, point)
weights_copy.pop(pos)
tag = tags_copy.pop(pos)
caption += (", " + tag) if caption else tag  # avoid a leading ", " while the caption is still empty
return caption
@staticmethod
def __get_shuffled_tags(seed: int, tags: list[str]) -> str:
random.Random(seed).shuffle(tags)
return ", ".join(tags)
class ImageTrainItem():
"""
image: PIL.Image
identifier: caption,
target_aspect: (width, height),
pathname: path to image file
flip_p: probability of flipping image (0.0 to 1.0)
"""
def __init__(self, image: PIL.Image, caption: str, target_wh: list, pathname: str, flip_p=0.0):
"""
def __init__(self, image: PIL.Image, caption: ImageCaption, target_wh: list, pathname: str, flip_p=0.0):
self.caption = caption
self.target_wh = target_wh
self.pathname = pathname
@@ -50,50 +125,50 @@ class ImageTrainItem():
save: save the cropped image to disk, for manual inspection of resize/crop
crop_jitter: randomly shift crop by N pixels when using multiple aspect ratios to improve training quality
"""
#print(self.pathname, self.image)
# print(self.pathname, self.image)
try:
#if not hasattr(self, 'image'):
# if not hasattr(self, 'image'):
self.image = PIL.Image.open(self.pathname).convert('RGB')
width, height = self.image.size
if crop:
if crop:
cropped_img = self.__autocrop(self.image)
self.image = cropped_img.resize((512,512), resample=PIL.Image.BICUBIC)
self.image = cropped_img.resize((512, 512), resample=PIL.Image.BICUBIC)
else:
width, height = self.image.size
jitter_amount = random.randint(0,crop_jitter)
jitter_amount = random.randint(0, crop_jitter)
if self.target_wh[0] == self.target_wh[1]:
if width > height:
left = random.randint(0, width - height)
self.image = self.image.crop((left, 0, height+left, height))
self.image = self.image.crop((left, 0, height + left, height))
width = height
elif height > width:
top = random.randint(0, height - width)
self.image = self.image.crop((0, top, width, width+top))
self.image = self.image.crop((0, top, width, width + top))
height = width
elif width > self.target_wh[0]:
slice = min(int(self.target_wh[0] * _RANDOM_TRIM), width-self.target_wh[0])
slice = min(int(self.target_wh[0] * _RANDOM_TRIM), width - self.target_wh[0])
slicew_ratio = random.random()
left = int(slice*slicew_ratio)
right = width-int(slice*(1-slicew_ratio))
left = int(slice * slicew_ratio)
right = width - int(slice * (1 - slicew_ratio))
sliceh_ratio = random.random()
top = int(slice*sliceh_ratio)
bottom = height- int(slice*(1-sliceh_ratio))
top = int(slice * sliceh_ratio)
bottom = height - int(slice * (1 - sliceh_ratio))
self.image = self.image.crop((left, top, right, bottom))
else:
image_aspect = width / height
else:
image_aspect = width / height
target_aspect = self.target_wh[0] / self.target_wh[1]
if image_aspect > target_aspect:
new_width = int(height * target_aspect)
jitter_amount = max(min(jitter_amount, int(abs(width-new_width)/2)), 0)
jitter_amount = max(min(jitter_amount, int(abs(width - new_width) / 2)), 0)
left = jitter_amount
right = left + new_width
self.image = self.image.crop((left, 0, right, height))
else:
new_height = int(width / target_aspect)
jitter_amount = max(min(jitter_amount, int(abs(height-new_height)/2)), 0)
jitter_amount = max(min(jitter_amount, int(abs(height - new_height) / 2)), 0)
top = jitter_amount
bottom = top + new_height
self.image = self.image.crop((0, top, width, bottom))
@@ -106,17 +181,17 @@ class ImageTrainItem():
exit()
if type(self.image) is not np.ndarray:
if save:
if save:
base_name = os.path.basename(self.pathname)
if not os.path.exists("test/output"):
os.makedirs("test/output")
self.image.save(f"test/output/{base_name}")
self.image = np.array(self.image).astype(np.uint8)
#self.image = (self.image / 127.5 - 1.0).astype(np.float32)
#print(self.image.shape)
# self.image = (self.image / 127.5 - 1.0).astype(np.float32)
# print(self.image.shape)
return self
@@ -128,25 +203,25 @@ class ImageTrainItem():
x, y = image.size
if x != y:
if (x>y):
rand_x = x-y
sigma = max(rand_x*q,1)
if (x > y):
rand_x = x - y
sigma = max(rand_x * q, 1)
else:
rand_y = y-x
sigma = max(rand_y*q,1)
rand_y = y - x
sigma = max(rand_y * q, 1)
if (x>y):
if (x > y):
x_crop_gauss = abs(random.gauss(0, sigma))
x_crop = min(x_crop_gauss,(x-y)/2)
x_crop = min(x_crop_gauss, (x - y) / 2)
x_crop = math.trunc(x_crop)
y_crop = 0
else:
y_crop_gauss = abs(random.gauss(0, sigma))
x_crop = 0
y_crop = min(y_crop_gauss,(y-x)/2)
y_crop = min(y_crop_gauss, (y - x) / 2)
y_crop = math.trunc(y_crop)
min_xy = min(x, y)
image = image.crop((x_crop, y_crop, x_crop + min_xy, y_crop + min_xy))
return image
return image
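For orientation, a minimal usage sketch of the ImageCaption class introduced in this file (keyword arguments follow the __init__ signature above; values are illustrative):

````python
from data.image_train_item import ImageCaption

# Weighted caption: a fixed main prompt plus three tags with differing weights
caption = ImageCaption(
    main_prompt="A portrait of Cloud Strife",
    tags=["low angle shot", "smiling", "clouds in background"],
    tag_weights=[1.0, 0.8, 0.5],
    max_target_length=1024,
    use_weights=True,
)

print(caption.get_caption())                   # deterministic: main prompt + all tags in order
print(caption.get_shuffled_caption(seed=555))  # weighted random order, reseeded per epoch
````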

doc/SHUFFLING_TAGS.md Normal file (67 additions)
View File

@@ -0,0 +1,67 @@
# Shuffling tags randomly during training
## General shuffling
To help the model generalize better, EveryDream has an option to shuffle tags during training.
This behavior can be activated with the parameter _--shuffle_tags_. The default is off.
The provided caption, extracted either from the file name or from the provided caption file,
is split at each "_,_" into separate chunks.
The first chunk is always included in the caption provided during training;
the additional chunks are shuffled into a random order.
The order is reshuffled each epoch. _(Remember that each image is shown to the model once per epoch.)_
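A minimal standalone sketch of this behavior (illustrative only; in the trainer it is implemented by the ImageCaption class in data/image_train_item.py):

````python
import random

def shuffled_caption(caption_string: str, seed: int) -> str:
    # Split at each ","; the first chunk is the fixed main prompt
    chunks = [chunk.strip() for chunk in caption_string.split(",")]
    main_prompt, tags = chunks[0], chunks[1:]
    random.Random(seed).shuffle(tags)  # reseeded each epoch, so the order changes
    return ", ".join([main_prompt] + tags)

print(shuffled_caption("a photo of a cat, sitting, on a mat, sunny day", seed=42))
````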
## Weighted shuffling
EveryDream can read caption definitions from YAML files, allowing finer control over the caption.
For each image, EveryDream checks whether a file with the same name and the extension _.yaml_ is provided.
The expected format of the YAML file:
````yaml
main_prompt: A portrait of Cloud Strife
tags:
- tag: low angle shot
- tag: looking to the side
- tag: holding buster sword
weight: 1.5
- tag: clouds in background
weight: 0.5
- tag: smiling
weight: 0.8
max_caption_length: 1024
````
The main prompt will always be the first part included in the caption.
The main prompt is optional; omit it if you do not want a fixed part at the beginning of the caption.
This is followed by a list of tags. The tags will be shuffled into a random order and added to the caption.
The tags list is optional.
The default weight of each tag is _1.0_. A different weight can optionally be specified.
Tags with a higher weight have a higher chance of appearing near the front of the caption tag list.
The optional parameter _max_caption_length_ allows the definition of a maximum length of the assembled caption.
Only whole tags are processed: if adding the next tag would exceed the _max_caption_length_,
it is not added, and the caption is provided without the remaining tags for this epoch.
This can be used to train the model that an image can include a certain aspect, even if it is not
explicitly mentioned in the caption.
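Under the hood, tags are drawn one at a time with probability proportional to their weight, using cumulative weights and sampling without replacement. A simplified sketch of that selection (leaving out the _max_caption_length_ cutoff):

````python
import bisect
import random

def weighted_shuffle(tags: list[str], weights: list[float], seed: int) -> list[str]:
    picker = random.Random(seed)
    tags, weights = tags.copy(), weights.copy()
    picked = []
    while tags:
        # Build the cumulative weight table, then draw a point in [0, total]
        cum_weights, total = [], 0.0
        for weight in weights:
            total += weight
            cum_weights.append(total)
        pos = bisect.bisect_left(cum_weights, picker.uniform(0, total))
        weights.pop(pos)
        picked.append(tags.pop(pos))
    return picked

print(weighted_shuffle(["low angle shot", "smiling", "clouds in background"],
                       [1.5, 0.8, 0.5], seed=7))
````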
## General notes regarding token length
For SD, the current implementation of EveryDream can only process the first 75 tokens
provided in the caption during training.
This is a base limitation of the SD models. Workarounds exist to extend this number, but they are currently
not implemented in EveryDream.
The effect of the limit is that the caption is always truncated when the maximum number of tokens is
exceeded. This process does not consider whether the cutoff falls in the middle of a tag, or even in the
middle of a word if that word is translated into several tokens.
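To check how a caption fares against this limit, a small sketch (assuming the Hugging Face _transformers_ CLIPTokenizer; the model name is illustrative and not taken from this repository):

````python
from transformers import CLIPTokenizer

# SD's CLIP text encoder has 77 positions: 75 caption tokens plus BOS/EOS
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

caption = "A portrait of Cloud Strife, " + ", ".join(["holding buster sword"] * 50)
ids = tokenizer(caption, truncation=True,
                max_length=tokenizer.model_max_length).input_ids
print(len(ids))  # 77: everything past the limit is cut off, even mid-tag
````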