Implemented loading captions from yaml file

This commit is contained in:
Jan Gerritsen 2023-01-07 19:57:23 +01:00
parent a3618409bc
commit 3d2709ace9
2 changed files with 36 additions and 5 deletions

View File

@ -16,6 +16,8 @@ limitations under the License.
import os
import logging
import yaml
from PIL import Image
import random
from data.image_train_item import ImageTrainItem, ImageCaption
@ -54,7 +56,7 @@ class DataLoaderMultiAspect():
random.Random(seed).shuffle(self.image_paths)
self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p) # ImageTrainItem[]
self.image_caption_pairs = self.__bucketize_images(self.prepared_train_data, batch_size=batch_size, debug_level=debug_level)
def shuffle(self):
self.runts = []
self.seed = self.seed + 1
@ -87,6 +89,30 @@ class DataLoaderMultiAspect():
pass
return caption
@staticmethod
def __read_caption_from_yaml(file_path: str, fallback_caption: ImageCaption) -> ImageCaption:
with open(file_path, "r") as stream:
try:
file_content = yaml.safe_load(stream)
main_prompt = file_content.get("main_prompt", "")
unparsed_tags = file_content.get("tags", [])
tags = []
tag_weights = []
for unparsed_tag in unparsed_tags:
tag = unparsed_tag.get("tag", "").strip()
if len(tag) == 0:
continue
tags.append(tag)
tag_weights.append(unparsed_tag.get("weight", 1.0))
return ImageCaption(main_prompt, tags, tag_weights)
except:
logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
return fallback_caption
@staticmethod
def __split_caption_into_tags(caption_string: str) -> ImageCaption:
"""
@ -110,10 +136,14 @@ class DataLoaderMultiAspect():
caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_from_filename)
txt_file_path = os.path.splitext(pathname)[0] + ".txt"
caption_file_path = os.path.splitext(pathname)[0] + ".caption"
file_path_without_ext = os.path.splitext(pathname)[0]
yaml_file_path = file_path_without_ext + ".yaml"
txt_file_path = file_path_without_ext + ".txt"
caption_file_path = file_path_without_ext + ".caption"
if os.path.exists(txt_file_path):
if os.path.exists(yaml_file_path):
caption = self.__read_caption_from_yaml(yaml_file_path, caption)
elif os.path.exists(txt_file_path):
caption = self.__read_caption_from_file(txt_file_path, caption)
elif os.path.exists(caption_file_path):
caption = self.__read_caption_from_file(caption_file_path, caption)
@ -177,7 +207,7 @@ class DataLoaderMultiAspect():
multiply = 1
multiply_path = os.path.join(recurse_root, "multiply.txt")
if os.path.exists(multiply_path):
try:
try:
with open(multiply_path, encoding='utf-8', mode='r') as f:
multiply = int(float(f.read().strip()))
logging.info(f" * DLMA multiply.txt in {recurse_root} set to {multiply}")

View File

@ -103,6 +103,7 @@ class EveryDreamBatch(Dataset):
return dls.shared_dataloader.runts
def shuffle(self, epoch_n):
self.seed += 1
if dls.shared_dataloader:
dls.shared_dataloader.shuffle()
self.image_train_items = dls.shared_dataloader.get_all_images()