From 3d2709ace9fa309966cb5c02bc42374ca295f3c7 Mon Sep 17 00:00:00 2001 From: Jan Gerritsen Date: Sat, 7 Jan 2023 19:57:23 +0100 Subject: [PATCH] Implemented loading captions from yaml file --- data/data_loader.py | 40 +++++++++++++++++++++++++++++++++++----- data/every_dream.py | 1 + 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/data/data_loader.py b/data/data_loader.py index c1ae1d7..faa906f 100644 --- a/data/data_loader.py +++ b/data/data_loader.py @@ -16,6 +16,8 @@ limitations under the License. import os import logging + +import yaml from PIL import Image import random from data.image_train_item import ImageTrainItem, ImageCaption @@ -54,7 +56,7 @@ class DataLoaderMultiAspect(): random.Random(seed).shuffle(self.image_paths) self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p) # ImageTrainItem[] self.image_caption_pairs = self.__bucketize_images(self.prepared_train_data, batch_size=batch_size, debug_level=debug_level) - + def shuffle(self): self.runts = [] self.seed = self.seed + 1 @@ -87,6 +89,30 @@ class DataLoaderMultiAspect(): pass return caption + @staticmethod + def __read_caption_from_yaml(file_path: str, fallback_caption: ImageCaption) -> ImageCaption: + with open(file_path, "r") as stream: + try: + file_content = yaml.safe_load(stream) + main_prompt = file_content.get("main_prompt", "") + unparsed_tags = file_content.get("tags", []) + + tags = [] + tag_weights = [] + for unparsed_tag in unparsed_tags: + tag = unparsed_tag.get("tag", "").strip() + if len(tag) == 0: + continue + + tags.append(tag) + tag_weights.append(unparsed_tag.get("weight", 1.0)) + + return ImageCaption(main_prompt, tags, tag_weights) + + except: + logging.error(f" *** Error reading {file_path} to get caption, falling back to filename") + return fallback_caption + @staticmethod def __split_caption_into_tags(caption_string: str) -> ImageCaption: """ @@ -110,10 +136,14 @@ class DataLoaderMultiAspect(): caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0] caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_from_filename) - txt_file_path = os.path.splitext(pathname)[0] + ".txt" - caption_file_path = os.path.splitext(pathname)[0] + ".caption" + file_path_without_ext = os.path.splitext(pathname)[0] + yaml_file_path = file_path_without_ext + ".yaml" + txt_file_path = file_path_without_ext + ".txt" + caption_file_path = file_path_without_ext + ".caption" - if os.path.exists(txt_file_path): + if os.path.exists(yaml_file_path): + caption = self.__read_caption_from_yaml(yaml_file_path, caption) + elif os.path.exists(txt_file_path): caption = self.__read_caption_from_file(txt_file_path, caption) elif os.path.exists(caption_file_path): caption = self.__read_caption_from_file(caption_file_path, caption) @@ -177,7 +207,7 @@ class DataLoaderMultiAspect(): multiply = 1 multiply_path = os.path.join(recurse_root, "multiply.txt") if os.path.exists(multiply_path): - try: + try: with open(multiply_path, encoding='utf-8', mode='r') as f: multiply = int(float(f.read().strip())) logging.info(f" * DLMA multiply.txt in {recurse_root} set to {multiply}") diff --git a/data/every_dream.py b/data/every_dream.py index 9c049ba..1d3bfe4 100644 --- a/data/every_dream.py +++ b/data/every_dream.py @@ -103,6 +103,7 @@ class EveryDreamBatch(Dataset): return dls.shared_dataloader.runts def shuffle(self, epoch_n): + self.seed += 1 if dls.shared_dataloader: dls.shared_dataloader.shuffle() self.image_train_items = dls.shared_dataloader.get_all_images()