# EveryDream-trainer/ldm/data/data_loader.py
import os
from PIL import Image
import PIL
import random
from ldm.data.image_train_item import ImageTrainItem
import ldm.data.aspects as aspects
from tqdm import tqdm
PIL.Image.MAX_IMAGE_PIXELS = 933120000
class DataLoaderMultiAspect():
    """
    Data loader for multi-aspect-ratio training and bucketing

    data_root: root folder of training data
    seed: RNG seed for the deterministic shuffle of discovered image paths
    debug_level: 0 = quiet, >0 prints per-bucket diagnostics
    batch_size: number of images per batch (buckets are trimmed to a multiple of this)
    flip_p: probability of flipping image horizontally (i.e. 0-0.5)
    resolution: base training resolution used to derive the aspect buckets
    """
    def __init__(self, data_root, seed=555, debug_level=0, batch_size=1, flip_p=0.0, resolution=512):
        self.image_paths = []
        self.debug_level = debug_level
        self.flip_p = flip_p

        self.aspects = aspects.get_aspect_buckets(resolution)
        print(f"* DLMA resolution {resolution}, buckets: {self.aspects}")
        print(" Preloading images...")

        self.__recurse_data_root(data_root)
        # seeded shuffle so the same seed always reproduces the same ordering
        random.Random(seed).shuffle(self.image_paths)
        prepared_train_data = self.__prescan_images(debug_level, self.image_paths, flip_p)  # ImageTrainItem[]
        self.image_caption_pairs = self.__bucketize_images(prepared_train_data, batch_size=batch_size, debug_level=debug_level)

        if debug_level > 0:
            print(f" * DLMA Example: {self.image_caption_pairs[0]} images")

    def get_all_images(self):
        """Return the bucketized list of ImageTrainItem objects."""
        return self.image_caption_pairs

    @staticmethod
    def __read_caption_from_file(file_path, fallback_caption):
        """
        Read a caption from file_path (utf-8); on any read/decode error print a
        warning and return fallback_caption instead (best-effort by design).
        """
        try:
            with open(file_path, encoding='utf-8', mode='r') as caption_file:
                caption = caption_file.read()
        # was a bare `except:` — narrowed to Exception so KeyboardInterrupt /
        # SystemExit are no longer swallowed, while keeping best-effort fallback
        except Exception:
            print(f" *** Error reading {file_path} to get caption, falling back to filename")
            caption = fallback_caption
        return caption

    def __prescan_images(self, debug_level: int, image_paths: list, flip_p=0.0):
        """
        Create ImageTrainItem objects with metadata for hydration later
        """
        decorated_image_train_items = []
        for pathname in tqdm(image_paths):
            # fallback caption: filename up to the first underscore
            caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]

            file_path_base = os.path.splitext(pathname)[0]
            txt_file_path = file_path_base + ".txt"
            caption_file_path = file_path_base + ".caption"
            # sidecar .txt wins over .caption, else the filename fallback
            if os.path.exists(txt_file_path):
                caption = self.__read_caption_from_file(txt_file_path, caption_from_filename)
            elif os.path.exists(caption_file_path):
                caption = self.__read_caption_from_file(caption_file_path, caption_from_filename)
            else:
                caption = caption_from_filename

            # open only to read the dimensions; the `with` closes the file handle
            # promptly (the original leaked one handle per image until GC)
            with Image.open(pathname) as image:
                width, height = image.size
            image_aspect = width / height

            # nearest bucket by aspect ratio (lambda arg renamed: the original
            # shadowed the `aspects` module alias)
            target_wh = min(self.aspects, key=lambda wh: abs(wh[0] / wh[1] - image_aspect))

            # image=None: pixels are hydrated later, only metadata is kept here
            image_train_item = ImageTrainItem(image=None, caption=caption, target_wh=target_wh, pathname=pathname, flip_p=flip_p)
            decorated_image_train_items.append(image_train_item)
        return decorated_image_train_items

    @staticmethod
    def __bucketize_images(prepared_train_data: list, batch_size=1, debug_level=0):
        """
        Put images into buckets based on aspect ratio with batch_size*n images per bucket, discards remainder
        """
        buckets = {}
        for image_caption_pair in prepared_train_data:
            target_wh = image_caption_pair.target_wh
            bucket_key = (target_wh[0], target_wh[1])
            buckets.setdefault(bucket_key, []).append(image_caption_pair)

        print(f" ** Number of buckets: {len(buckets)}")

        # trim each bucket to a multiple of batch_size; with a single bucket the
        # remainder is deliberately kept (matches original behavior)
        if len(buckets) > 1:
            for bucket in buckets:
                current_bucket_size = len(buckets[bucket])
                truncate_count = current_bucket_size % batch_size
                buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
                if debug_level > 0:
                    print(f" ** Bucket {bucket} with {current_bucket_size} will drop {truncate_count} images due to batch size {batch_size}")

        # flatten the buckets
        image_caption_pairs = []
        for bucket in buckets:
            image_caption_pairs.extend(buckets[bucket])
        return image_caption_pairs

    def __recurse_data_root(self, recurse_root):
        """
        Recursively collect image file paths under recurse_root into self.image_paths.
        Files in a directory are appended before its subdirectories are visited.
        """
        sub_dirs = []
        # single directory scan (the original listed each directory twice)
        for entry in os.listdir(recurse_root):
            current = os.path.join(recurse_root, entry)
            if os.path.isfile(current):
                ext = os.path.splitext(entry)[1]
                # NOTE(review): extension check is case-sensitive (".JPG" is
                # skipped) — preserved from the original; confirm if intended
                if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp']:
                    self.image_paths.append(current)
            elif os.path.isdir(current):
                sub_dirs.append(current)

        for sub_dir in sub_dirs:
            self.__recurse_data_root(sub_dir)