Implemented loading captions from a YAML file
This commit is contained in:
parent
a3618409bc
commit
3d2709ace9
|
@ -16,6 +16,8 @@ limitations under the License.
|
|||
|
||||
import os
|
||||
import logging
|
||||
|
||||
import yaml
|
||||
from PIL import Image
|
||||
import random
|
||||
from data.image_train_item import ImageTrainItem, ImageCaption
|
||||
|
@ -54,7 +56,7 @@ class DataLoaderMultiAspect():
|
|||
random.Random(seed).shuffle(self.image_paths)
|
||||
self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p) # ImageTrainItem[]
|
||||
self.image_caption_pairs = self.__bucketize_images(self.prepared_train_data, batch_size=batch_size, debug_level=debug_level)
|
||||
|
||||
|
||||
def shuffle(self):
|
||||
self.runts = []
|
||||
self.seed = self.seed + 1
|
||||
|
@ -87,6 +89,30 @@ class DataLoaderMultiAspect():
|
|||
pass
|
||||
return caption
|
||||
|
||||
@staticmethod
def __read_caption_from_yaml(file_path: str, fallback_caption: ImageCaption) -> ImageCaption:
    """
    Build an ImageCaption from a YAML caption file.

    The file is expected to hold a mapping with an optional "main_prompt"
    string and an optional "tags" list. Each entry of "tags" is a mapping
    with a "tag" string and an optional "weight" (defaults to 1.0).
    Entries whose "tag" is empty or whitespace-only are skipped.

    :param file_path: path of the YAML file to read
    :param fallback_caption: caption returned when the file cannot be read or parsed
    :return: the caption described by the file, or fallback_caption on any error
    """
    try:
        # Read as UTF-8 explicitly so the result does not depend on the
        # platform's locale encoding.
        with open(file_path, "r", encoding="utf-8") as stream:
            file_content = yaml.safe_load(stream)

        main_prompt = file_content.get("main_prompt", "")
        # A bare `tags:` key parses to None; treat it like a missing key
        # instead of discarding the whole caption.
        unparsed_tags = file_content.get("tags") or []

        tags = []
        tag_weights = []
        for unparsed_tag in unparsed_tags:
            tag = unparsed_tag.get("tag", "").strip()
            if not tag:
                continue

            tags.append(tag)
            tag_weights.append(unparsed_tag.get("weight", 1.0))

        return ImageCaption(main_prompt, tags, tag_weights)

    except Exception:
        # Best-effort loader: any read/parse/shape problem falls back to the
        # caller-supplied caption. `Exception` (rather than a bare except)
        # lets KeyboardInterrupt and SystemExit propagate.
        logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
        return fallback_caption
|
||||
|
||||
@staticmethod
|
||||
def __split_caption_into_tags(caption_string: str) -> ImageCaption:
|
||||
"""
|
||||
|
@ -110,10 +136,14 @@ class DataLoaderMultiAspect():
|
|||
caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
|
||||
caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_from_filename)
|
||||
|
||||
txt_file_path = os.path.splitext(pathname)[0] + ".txt"
|
||||
caption_file_path = os.path.splitext(pathname)[0] + ".caption"
|
||||
file_path_without_ext = os.path.splitext(pathname)[0]
|
||||
yaml_file_path = file_path_without_ext + ".yaml"
|
||||
txt_file_path = file_path_without_ext + ".txt"
|
||||
caption_file_path = file_path_without_ext + ".caption"
|
||||
|
||||
if os.path.exists(txt_file_path):
|
||||
if os.path.exists(yaml_file_path):
|
||||
caption = self.__read_caption_from_yaml(yaml_file_path, caption)
|
||||
elif os.path.exists(txt_file_path):
|
||||
caption = self.__read_caption_from_file(txt_file_path, caption)
|
||||
elif os.path.exists(caption_file_path):
|
||||
caption = self.__read_caption_from_file(caption_file_path, caption)
|
||||
|
@ -177,7 +207,7 @@ class DataLoaderMultiAspect():
|
|||
multiply = 1
|
||||
multiply_path = os.path.join(recurse_root, "multiply.txt")
|
||||
if os.path.exists(multiply_path):
|
||||
try:
|
||||
try:
|
||||
with open(multiply_path, encoding='utf-8', mode='r') as f:
|
||||
multiply = int(float(f.read().strip()))
|
||||
logging.info(f" * DLMA multiply.txt in {recurse_root} set to {multiply}")
|
||||
|
|
|
@ -103,6 +103,7 @@ class EveryDreamBatch(Dataset):
|
|||
return dls.shared_dataloader.runts
|
||||
|
||||
def shuffle(self, epoch_n):
|
||||
self.seed += 1
|
||||
if dls.shared_dataloader:
|
||||
dls.shared_dataloader.shuffle()
|
||||
self.image_train_items = dls.shared_dataloader.get_all_images()
|
||||
|
|
Loading…
Reference in New Issue