From a3618409bca58f6a891fd1de4ffc80c1ad803c96 Mon Sep 17 00:00:00 2001
From: Jan Gerritsen
Date: Sat, 7 Jan 2023 17:29:09 +0100
Subject: [PATCH 1/3] Support more control regarding caption tag shuffling
 using yaml files

---
 data/data_loader.py      |  28 +++++---
 data/every_dream.py      |  13 ++--
 data/image_train_item.py | 139 ++++++++++++++++++++++++++++-----------
 3 files changed, 126 insertions(+), 54 deletions(-)

diff --git a/data/data_loader.py b/data/data_loader.py
index a0924c6..c1ae1d7 100644
--- a/data/data_loader.py
+++ b/data/data_loader.py
@@ -18,7 +18,7 @@ import os
 import logging
 from PIL import Image
 import random
-from data.image_train_item import ImageTrainItem
+from data.image_train_item import ImageTrainItem, ImageCaption
 import data.aspects as aspects
 from colorama import Fore, Style
 import zipfile
@@ -76,17 +76,30 @@ class DataLoaderMultiAspect():
         return self.image_caption_pairs
 
     @staticmethod
-    def __read_caption_from_file(file_path, fallback_caption):
-        caption = fallback_caption
+    def __read_caption_from_file(file_path, fallback_caption: ImageCaption) -> ImageCaption:
         try:
             with open(file_path, encoding='utf-8', mode='r') as caption_file:
-                caption = caption_file.read()
+                caption_text = caption_file.read()
+                caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_text)
         except:
             logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
             caption = fallback_caption
             pass
         return caption
 
+    @staticmethod
+    def __split_caption_into_tags(caption_string: str) -> ImageCaption:
+        """
+        Splits a string by "," into the main prompt and additional tags with equal weights
+        """
+        split_caption = caption_string.split(",")
+        main_prompt = split_caption.pop(0).strip()
+        tags = []
+        for tag in split_caption:
+            tags.append(tag.strip())
+
+        return ImageCaption(main_prompt, tags, [1.0] * len(tags))
+
     def __prescan_images(self, image_paths: list, flip_p=0.0):
         """
         Create ImageTrainItem objects with metadata for hydration later
@@ -95,16 +108,15 @@ class DataLoaderMultiAspect():
         for pathname in tqdm.tqdm(image_paths):
             caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
+            caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_from_filename)
 
             txt_file_path = os.path.splitext(pathname)[0] + ".txt"
             caption_file_path = os.path.splitext(pathname)[0] + ".caption"
 
             if os.path.exists(txt_file_path):
-                caption = self.__read_caption_from_file(txt_file_path, caption_from_filename)
+                caption = self.__read_caption_from_file(txt_file_path, caption)
             elif os.path.exists(caption_file_path):
-                caption = self.__read_caption_from_file(caption_file_path, caption_from_filename)
-            else:
-                caption = caption_from_filename
+                caption = self.__read_caption_from_file(caption_file_path, caption)
 
             try:
                 image = Image.open(pathname)
diff --git a/data/every_dream.py b/data/every_dream.py
index 56bcd83..9c049ba 100644
--- a/data/every_dream.py
+++ b/data/every_dream.py
@@ -134,16 +134,15 @@ class EveryDreamBatch(Dataset):
             ]
         )
 
-        if self.shuffle_tags and "," in train_item['caption']:
-            tags = train_item["caption"].split(",")
-            random.Random(self.seed).shuffle(tags)
-            self.seed += 1
-            train_item["caption"] = ", ".join(tags)
+        if self.shuffle_tags:
+            example["caption"] = train_item["caption"].get_shuffled_caption(self.seed)
+        else:
+            example["caption"] = train_item["caption"].get_caption()
 
         example["image"] = image_transforms(train_item["image"])
 
         if random.random() > self.conditional_dropout:
-            example["tokens"] = self.tokenizer(train_item["caption"],
+            example["tokens"] = self.tokenizer(example["caption"],
                                                truncation=True,
                                                padding="max_length",
                                                max_length=self.tokenizer.model_max_length,
@@ -156,7 +155,7 @@ class EveryDreamBatch(Dataset):
                                               ).input_ids
 
         example["tokens"] = torch.tensor(example["tokens"])
-        example["caption"] = train_item["caption"] # for sampling if needed
+        example["runt_size"] = train_item["runt_size"]
 
         return example
diff --git a/data/image_train_item.py b/data/image_train_item.py
index 1ecc800..8967031 100644
--- a/data/image_train_item.py
+++ b/data/image_train_item.py
@@ -13,25 +13,86 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 """
-import PIL
-import numpy as np
-from torchvision import transforms, utils
-import random
+import bisect
+import logging
 import math
 import os
-import logging
+import random
+
+import PIL
+import numpy as np
+from torchvision import transforms
 
 _RANDOM_TRIM = 0.04
 
-class ImageTrainItem():
+
+class ImageCaption:
+    """
+    Represents the various parts of an image caption
+    """
+
+    def __init__(self, main_prompt: str, tags: list[str], tag_weights: list[float]):
+        """
+        :param main_prompt: The part of the caption which should always be included
+        :param tags: list of tags to pick from to fill the caption
+        :param tag_weights: weights to indicate which tags are more desired and should be picked preferably
+        """
+        self.__main_prompt = main_prompt
+        self.__tags = tags
+        self.__tag_weights = tag_weights
+        if len(tags) > len(tag_weights):
+            self.__tag_weights.extend([1.0] * (len(tags) - len(tag_weights)))
+
+    def get_shuffled_caption(self, seed: int, target_length=150) -> str:
+        """
+        returns the caption as a string with a random selection of the tags in random order
+        :param seed: used to initialize the randomizer
+        :param target_length: maximum desired length of the caption
+        :return: generated caption string
+        """
+        target_tag_length = target_length - len(self.__main_prompt)
+        tags_caption = self.__get_tags_caption(seed, self.__tags, self.__tag_weights, target_tag_length)
+
+        return self.__main_prompt + tags_caption
+
+    def get_caption(self) -> str:
+        return self.__main_prompt + ", ".join(self.__tags)
+
+    @staticmethod
+    def __get_tags_caption(seed: int, tags: list[str], weights: list[float], target_length: int) -> str:
+        caption = ""
+
+        picker = random.Random(seed)
+        tags_copy = tags.copy()
+        weights_copy = weights.copy()
+
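+        # Weighted sampling without replacement: each pass rebuilds the list of
+        # cumulative weights, draws a uniform point in [0, weight_sum) and uses
+        # bisect to find the tag whose slot the point falls into.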
+        while len(tags_copy) != 0 and len(caption) < target_length:
+            cum_weights = []
+            weight_sum = 0.0
+            for weight in weights_copy:
+                weight_sum += weight
+                cum_weights.append(weight_sum)
+
+            point = picker.uniform(0, weight_sum)
+            pos = bisect.bisect_left(cum_weights, point)
+
+            weights_copy.pop(pos)
+            tag = tags_copy.pop(pos)
+            caption += ", " + tag
+
+        return caption
+
+
+class ImageTrainItem():
     """
     image: PIL.Image
     identifier: caption,
     target_aspect: (width, height),
     pathname: path to image file
     flip_p: probability of flipping image (0.0 to 1.0)
-    """
-    def __init__(self, image: PIL.Image, caption: str, target_wh: list, pathname: str, flip_p=0.0):
+    """
+
+    def __init__(self, image: PIL.Image, caption: ImageCaption, target_wh: list, pathname: str, flip_p=0.0):
         self.caption = caption
         self.target_wh = target_wh
         self.pathname = pathname
@@ -50,50 +111,50 @@ class ImageTrainItem():
         save: save the cropped image to disk, for manual inspection of resize/crop
         crop_jitter: randomly shift crop by N pixels when using multiple aspect ratios to improve training quality
         """
-        #print(self.pathname, self.image)
+        # print(self.pathname, self.image)
         try:
-            #if not hasattr(self, 'image'):
+            # if not hasattr(self, 'image'):
             self.image = PIL.Image.open(self.pathname).convert('RGB')
 
             width, height = self.image.size
-            if crop: 
+            if crop:
                 cropped_img = self.__autocrop(self.image)
-                self.image = cropped_img.resize((512,512), resample=PIL.Image.BICUBIC)
+                self.image = cropped_img.resize((512, 512), resample=PIL.Image.BICUBIC)
             else:
                 width, height = self.image.size
-                jitter_amount = random.randint(0,crop_jitter)
+                jitter_amount = random.randint(0, crop_jitter)
 
                 if self.target_wh[0] == self.target_wh[1]:
                     if width > height:
                         left = random.randint(0, width - height)
-                        self.image = self.image.crop((left, 0, height+left, height))
+                        self.image = self.image.crop((left, 0, height + left, height))
                         width = height
                     elif height > width:
                         top = random.randint(0, height - width)
-                        self.image = self.image.crop((0, top, width, width+top))
+                        self.image = self.image.crop((0, top, width, width + top))
                         height = width
                     elif width > self.target_wh[0]:
-                        slice = min(int(self.target_wh[0] * _RANDOM_TRIM), width-self.target_wh[0])
+                        slice = min(int(self.target_wh[0] * _RANDOM_TRIM), width - self.target_wh[0])
                         slicew_ratio = random.random()
-                        left = int(slice*slicew_ratio)
-                        right = width-int(slice*(1-slicew_ratio))
+                        left = int(slice * slicew_ratio)
+                        right = width - int(slice * (1 - slicew_ratio))
                         sliceh_ratio = random.random()
-                        top = int(slice*sliceh_ratio)
-                        bottom = height- int(slice*(1-sliceh_ratio))
+                        top = int(slice * sliceh_ratio)
+                        bottom = height - int(slice * (1 - sliceh_ratio))
 
                         self.image = self.image.crop((left, top, right, bottom))
-                else: 
-                    image_aspect = width / height
+                else:
+                    image_aspect = width / height
                     target_aspect = self.target_wh[0] / self.target_wh[1]
                     if image_aspect > target_aspect:
                         new_width = int(height * target_aspect)
-                        jitter_amount = max(min(jitter_amount, int(abs(width-new_width)/2)), 0)
+                        jitter_amount = max(min(jitter_amount, int(abs(width - new_width) / 2)), 0)
                         left = jitter_amount
                         right = left + new_width
                         self.image = self.image.crop((left, 0, right, height))
                     else:
                         new_height = int(width / target_aspect)
-                        jitter_amount = max(min(jitter_amount, int(abs(height-new_height)/2)), 0)
+                        jitter_amount = max(min(jitter_amount, int(abs(height - new_height) / 2)), 0)
                         top = jitter_amount
                         bottom = top + new_height
                         self.image = self.image.crop((0, top, width, bottom))
@@ -106,17 +167,17 @@ class ImageTrainItem():
             exit()
 
         if type(self.image) is not np.ndarray:
-            if save: 
+            if save:
                 base_name = os.path.basename(self.pathname)
                 if not os.path.exists("test/output"):
                     os.makedirs("test/output")
                 self.image.save(f"test/output/{base_name}")
-            
+
             self.image = np.array(self.image).astype(np.uint8)
 
-            #self.image = (self.image / 127.5 - 1.0).astype(np.float32)
-
-        #print(self.image.shape)
+            # self.image = (self.image / 127.5 - 1.0).astype(np.float32)
+
+        # print(self.image.shape)
 
         return self
@@ -128,25 +189,25 @@ class ImageTrainItem():
         x, y = image.size
 
         if x != y:
-            if (x>y):
-                rand_x = x-y
-                sigma = max(rand_x*q,1)
+            if (x > y):
+                rand_x = x - y
+                sigma = max(rand_x * q, 1)
             else:
-                rand_y = y-x
-                sigma = max(rand_y*q,1)
+                rand_y = y - x
+                sigma = max(rand_y * q, 1)
 
-            if (x>y):
+            if (x > y):
                 x_crop_gauss = abs(random.gauss(0, sigma))
-                x_crop = min(x_crop_gauss,(x-y)/2)
+                x_crop = min(x_crop_gauss, (x - y) / 2)
                 x_crop = math.trunc(x_crop)
                 y_crop = 0
             else:
                 y_crop_gauss = abs(random.gauss(0, sigma))
                 x_crop = 0
-                y_crop = min(y_crop_gauss,(y-x)/2)
+                y_crop = min(y_crop_gauss, (y - x) / 2)
                 y_crop = math.trunc(y_crop)
-            
+
             min_xy = min(x, y)
             image = image.crop((x_crop, y_crop, x_crop + min_xy, y_crop + min_xy))
 
-        return image
\ No newline at end of file
+        return image

From 3d2709ace9fa309966cb5c02bc42374ca295f3c7 Mon Sep 17 00:00:00 2001
From: Jan Gerritsen
Date: Sat, 7 Jan 2023 19:57:23 +0100
Subject: [PATCH 2/3] Implemented loading captions from yaml file

---
 data/data_loader.py | 40 +++++++++++++++++++++++++++++++++++-----
 data/every_dream.py |  1 +
 2 files changed, 36 insertions(+), 5 deletions(-)

diff --git a/data/data_loader.py b/data/data_loader.py
index c1ae1d7..faa906f 100644
--- a/data/data_loader.py
+++ b/data/data_loader.py
@@ -16,6 +16,8 @@ limitations under the License.
 
 import os
 import logging
+
+import yaml
 from PIL import Image
 import random
 from data.image_train_item import ImageTrainItem, ImageCaption
@@ -54,7 +56,7 @@ class DataLoaderMultiAspect():
         random.Random(seed).shuffle(self.image_paths)
         self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p) # ImageTrainItem[]
         self.image_caption_pairs = self.__bucketize_images(self.prepared_train_data, batch_size=batch_size, debug_level=debug_level)
-        
+
     def shuffle(self):
         self.runts = []
         self.seed = self.seed + 1
@@ -87,6 +89,30 @@ class DataLoaderMultiAspect():
             pass
         return caption
 
+    @staticmethod
+    def __read_caption_from_yaml(file_path: str, fallback_caption: ImageCaption) -> ImageCaption:
+        with open(file_path, "r") as stream:
+            try:
+                file_content = yaml.safe_load(stream)
+                main_prompt = file_content.get("main_prompt", "")
+                unparsed_tags = file_content.get("tags", [])
+
+                tags = []
+                tag_weights = []
+                for unparsed_tag in unparsed_tags:
+                    tag = unparsed_tag.get("tag", "").strip()
+                    if len(tag) == 0:
+                        continue
+
+                    tags.append(tag)
+                    tag_weights.append(unparsed_tag.get("weight", 1.0))
+
+                return ImageCaption(main_prompt, tags, tag_weights)
+
+            except:
+                logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
+                return fallback_caption
+
     @staticmethod
     def __split_caption_into_tags(caption_string: str) -> ImageCaption:
         """
@@ -110,10 +136,14 @@ class DataLoaderMultiAspect():
             caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
             caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_from_filename)
 
-            txt_file_path = os.path.splitext(pathname)[0] + ".txt"
-            caption_file_path = os.path.splitext(pathname)[0] + ".caption"
+            file_path_without_ext = os.path.splitext(pathname)[0]
+            yaml_file_path = file_path_without_ext + ".yaml"
+            txt_file_path = file_path_without_ext + ".txt"
+            caption_file_path = file_path_without_ext + ".caption"
 
-            if os.path.exists(txt_file_path):
+            if os.path.exists(yaml_file_path):
+                caption = self.__read_caption_from_yaml(yaml_file_path, caption)
+            elif os.path.exists(txt_file_path):
                 caption = self.__read_caption_from_file(txt_file_path, caption)
             elif os.path.exists(caption_file_path):
                 caption = self.__read_caption_from_file(caption_file_path, caption)
@@ -177,7 +207,7 @@ class DataLoaderMultiAspect():
         multiply = 1
         multiply_path = os.path.join(recurse_root, "multiply.txt")
         if os.path.exists(multiply_path):
-            try: 
+            try:
                 with open(multiply_path, encoding='utf-8', mode='r') as f:
                     multiply = int(float(f.read().strip()))
                     logging.info(f" * DLMA multiply.txt in {recurse_root} set to {multiply}")
diff --git a/data/every_dream.py b/data/every_dream.py
index 9c049ba..1d3bfe4 100644
--- a/data/every_dream.py
+++ b/data/every_dream.py
@@ -103,6 +103,7 @@ class EveryDreamBatch(Dataset):
         return dls.shared_dataloader.runts
 
     def shuffle(self, epoch_n):
+        self.seed += 1
         if dls.shared_dataloader:
             dls.shared_dataloader.shuffle()
             self.image_train_items = dls.shared_dataloader.get_all_images()

From f47ceadcc7144ac93a352e395e8d0a9870e958f8 Mon Sep 17 00:00:00 2001
From: Jan Gerritsen
Date: Sat, 7 Jan 2023 22:59:51 +0100
Subject: [PATCH 3/3] Implemented an optimization for the shuffling if all
 tags have the same weight and added documentation.

---
 README.md                |  2 ++
 data/data_loader.py      | 18 +++++++++--
 data/image_train_item.py | 38 ++++++++++++++++--------
 doc/SHUFFLING_TAGS.md    | 67 ++++++++++++++++++++++++++++++++++++++++
 4 files changed, 110 insertions(+), 15 deletions(-)
 create mode 100644 doc/SHUFFLING_TAGS.md

diff --git a/README.md b/README.md
index b2d6018..784020f 100644
--- a/README.md
+++ b/README.md
@@ -34,3 +34,5 @@ Behind the scenes look at how the trainer handles multiaspect and crop jitter
 [Advanced Tweaking](doc/ATWEAKING.md)
 
 [Chaining training sessions](doc/CHAINING.md)
+
+[Shuffling Tags](doc/SHUFFLING_TAGS.md)
diff --git a/data/data_loader.py b/data/data_loader.py
index faa906f..269f7af 100644
--- a/data/data_loader.py
+++ b/data/data_loader.py
@@ -29,6 +29,8 @@ import PIL
 
 PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
 
+DEFAULT_MAX_CAPTION_LENGTH = 2048
+
 class DataLoaderMultiAspect():
     """
     Data loader for multi-aspect-ratio training and bucketing
@@ -97,17 +99,27 @@ class DataLoaderMultiAspect():
                 main_prompt = file_content.get("main_prompt", "")
                 unparsed_tags = file_content.get("tags", [])
 
+                max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH)
+
                 tags = []
                 tag_weights = []
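+                # Track whether any two tag weights differ; if they all match,
+                # the cheaper unweighted shuffle can be used when generating captions.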
+                last_weight = None
+                weights_differ = False
                 for unparsed_tag in unparsed_tags:
                     tag = unparsed_tag.get("tag", "").strip()
                     if len(tag) == 0:
                         continue
 
                     tags.append(tag)
-                    tag_weights.append(unparsed_tag.get("weight", 1.0))
+                    tag_weight = unparsed_tag.get("weight", 1.0)
+                    tag_weights.append(tag_weight)
 
-                return ImageCaption(main_prompt, tags, tag_weights)
+                    if last_weight is not None and weights_differ is False:
+                        weights_differ = last_weight != tag_weight
+
+                    last_weight = tag_weight
+
+                return ImageCaption(main_prompt, tags, tag_weights, max_caption_length, weights_differ)
 
             except:
                 logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
@@ -124,7 +136,7 @@ class DataLoaderMultiAspect():
         for tag in split_caption:
             tags.append(tag.strip())
 
-        return ImageCaption(main_prompt, tags, [1.0] * len(tags))
+        return ImageCaption(main_prompt, tags, [1.0] * len(tags), DEFAULT_MAX_CAPTION_LENGTH, False)
 
     def __prescan_images(self, image_paths: list, flip_p=0.0):
         """
diff --git a/data/image_train_item.py b/data/image_train_item.py
index 8967031..c86ac70 100644
--- a/data/image_train_item.py
+++ b/data/image_train_item.py
@@ -31,42 +31,51 @@ class ImageCaption:
     Represents the various parts of an image caption
     """
 
-    def __init__(self, main_prompt: str, tags: list[str], tag_weights: list[float]):
+    def __init__(self, main_prompt: str, tags: list[str], tag_weights: list[float], max_target_length: int, use_weights: bool):
         """
         :param main_prompt: The part of the caption which should always be included
         :param tags: list of tags to pick from to fill the caption
         :param tag_weights: weights to indicate which tags are more desired and should be picked preferably
+        :param max_target_length: The desired maximum length of a generated caption
+        :param use_weights: if true, weights are considered when shuffling tags
         """
         self.__main_prompt = main_prompt
         self.__tags = tags
         self.__tag_weights = tag_weights
-        if len(tags) > len(tag_weights):
+        self.__max_target_length = max_target_length
+        self.__use_weights = use_weights
+        if use_weights and len(tags) > len(tag_weights):
             self.__tag_weights.extend([1.0] * (len(tags) - len(tag_weights)))
 
-    def get_shuffled_caption(self, seed: int, target_length=150) -> str:
+        if use_weights and len(tag_weights) > len(tags):
+            self.__tag_weights = tag_weights[:len(tags)]
+
+    def get_shuffled_caption(self, seed: int) -> str:
         """
         returns the caption as a string with a random selection of the tags in random order
         :param seed: used to initialize the randomizer
-        :param target_length: maximum desired length of the caption
         :return: generated caption string
         """
-        target_tag_length = target_length - len(self.__main_prompt)
-        tags_caption = self.__get_tags_caption(seed, self.__tags, self.__tag_weights, target_tag_length)
+        max_target_tag_length = self.__max_target_length - len(self.__main_prompt)
 
-        return self.__main_prompt + tags_caption
+        if self.__use_weights:
+            # the weighted helper already prefixes each tag with ", "
+            tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length)
+            return self.__main_prompt + tags_caption
+
+        return self.__main_prompt + ", " + self.__get_shuffled_tags(seed, self.__tags)
 
     def get_caption(self) -> str:
-        return self.__main_prompt + ", ".join(self.__tags)
+        return self.__main_prompt + ", " + ", ".join(self.__tags)
 
     @staticmethod
-    def __get_tags_caption(seed: int, tags: list[str], weights: list[float], target_length: int) -> str:
-        caption = ""
-
+    def __get_weighted_shuffled_tags(seed: int, tags: list[str], weights: list[float], max_target_tag_length: int) -> str:
         picker = random.Random(seed)
         tags_copy = tags.copy()
         weights_copy = weights.copy()
 
         # Weighted sampling without replacement: each pass rebuilds the list of
         # cumulative weights, draws a uniform point in [0, weight_sum) and uses
         # bisect to find the tag whose slot the point falls into.
-        while len(tags_copy) != 0 and len(caption) < target_length:
+        caption = ""
+        while len(tags_copy) != 0 and len(caption) < max_target_tag_length:
             cum_weights = []
             weight_sum = 0.0
             for weight in weights_copy:
@@ -82,6 +91,11 @@ class ImageCaption:
 
         return caption
 
+    @staticmethod
+    def __get_shuffled_tags(seed: int, tags: list[str]) -> str:
+        # shuffle a copy so the stored tag order is not mutated between epochs
+        tags = tags.copy()
+        random.Random(seed).shuffle(tags)
+        return ", ".join(tags)
+
 
 class ImageTrainItem():
     """
diff --git a/doc/SHUFFLING_TAGS.md b/doc/SHUFFLING_TAGS.md
new file mode 100644
index 0000000..5bf0f44
--- /dev/null
+++ b/doc/SHUFFLING_TAGS.md
@@ -0,0 +1,67 @@
+# Shuffling tags randomly during training
+
+## General shuffling
+
+To help the model generalize better, EveryDream has an option to shuffle tags during the training.
+
+This behavior can be activated using the parameter _--shuffle_tags_. The default is off.
+
+The provided caption, extracted either from the file name or the provided caption file,
+will be split at each "_,_" into separate chunks.
+
+The first chunk will always be included in the caption provided during the training,
+the additional chunks are shuffled into a random order.
+
+Each epoch the order is reshuffled. _(Remember that each image is shown once per epoch to the model)_
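+
+For example, the caption `a portrait of john doe, sitting, smiling, outdoors` may be presented to the
+model as `a portrait of john doe, outdoors, sitting, smiling` in one epoch and in a different order in
+the next epoch.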
+
+
+## Weighted shuffling
+
+EveryDream can read caption definitions from YAML files, which allows finer control over the tag shuffling.
+
+For each image, EveryDream checks whether a file with the same name and the extension _.yaml_ is provided.
+
+The expected format of the YAML file:
+````yaml
+main_prompt: A portrait of Cloud Strife
+tags:
+  - tag: low angle shot
+  - tag: looking to the side
+  - tag: holding buster sword
+    weight: 1.5
+  - tag: clouds in background
+    weight: 0.5
+  - tag: smiling
+    weight: 0.8
+max_caption_length: 1024
+````
+
+The main prompt will always be the first part included in the caption.
+The main prompt is optional, you can provide none if you do not want a fixed part at the beginning of the caption.
+
+This is followed by a list of tags. The tags will be shuffled into a random order and added to the caption.
+The tags list is optional.
+
+The default weight of each tag is _1.0_. A different weight can be optionally specified.
+Tags with a higher weight have a higher chance to appear in the front of the caption tag list.
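+
+The selection behaves like the following simplified sketch (the actual implementation lives in
+`data/image_train_item.py` and additionally respects the caption length limit): tags are drawn one at a
+time without replacement, each with a probability proportional to its weight, so heavier tags tend to be
+drawn, and therefore placed, earlier.
+
+````python
+import bisect
+import random
+
+def weighted_shuffle(tags: list[str], weights: list[float], seed: int) -> list[str]:
+    picker = random.Random(seed)
+    tags, weights = tags.copy(), weights.copy()
+    result = []
+    while tags:
+        # cumulative weights, e.g. [1.5, 2.0, 2.8] for the weights [1.5, 0.5, 0.8]
+        cum_weights = []
+        total = 0.0
+        for weight in weights:
+            total += weight
+            cum_weights.append(total)
+        # a uniform draw lands in a tag's slot with probability weight / total
+        point = picker.uniform(0, total)
+        pos = bisect.bisect_left(cum_weights, point)
+        weights.pop(pos)
+        result.append(tags.pop(pos))
+    return result
+````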
+
+The optional parameter _max_caption_length_ allows the definition of a maximum length for the assembled caption.
+Only whole tags will be processed. If the addition of the next tag would exceed the _max_caption_length_,
+it will not be added, and the caption will be provided without the remaining tags for this epoch.
+
+This can be used to train the model that an image can include a certain aspect, even if it is not
+explicitly mentioned in the caption.
+
+
+## General notes regarding token length
+
+For SD, the current implementation of EveryDream can only process the first 75 tokens
+provided in the caption during training.
+
+This is a base limitation of the SD models. Workarounds exist to extend this number, but they are currently not
+implemented in EveryDream.
+
+The effect of the limit is that the caption will always be truncated when the maximum number of tokens is
+exceeded. This process does not consider whether the cutoff lands in the middle of a tag, or even in the
+middle of a word if it is translated into several tokens.
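+
+To check where the cutoff would land for a given caption, the tokenizer can be called directly.
+The sketch below mirrors the tokenizer call the trainer makes; the model name is an assumption
+(it is the CLIP tokenizer commonly used for SD 1.x checkpoints):
+
+````python
+from transformers import CLIPTokenizer
+
+# assumed tokenizer; substitute the one belonging to your checkpoint
+tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+
+caption = "A portrait of Cloud Strife, " + ", ".join(40 * ["holding buster sword"])
+tokens = tokenizer(caption, truncation=True,
+                   max_length=tokenizer.model_max_length).input_ids
+# model_max_length is 77: 75 caption tokens plus the start and end markers.
+# Everything beyond that window is cut off, even in the middle of a word.
+print(len(tokenizer(caption).input_ids), "->", len(tokens))
+````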