Data loader rework and optimization for aspect ratio support

Victor Hall 2022-11-06 01:25:03 -04:00
parent b7779591b1
commit aacbde8bc7
8 changed files with 282 additions and 103 deletions

ldm/data/data_loader.py

@@ -1,6 +1,8 @@
import os
from PIL import Image
import gc
import random
from ldm.data.image_train_item import ImageTrainItem
ASPECTS = [[512,512], # 1 262144\
[576,448],[448,576], # 1.29 258048\
@@ -11,19 +13,20 @@ ASPECTS = [[512,512], # 1 262144\
[960,256],[256,960], # 3.75 245760\
[1024,256],[256,1024] # 4 245760\
]
class DataLoaderMultiAspect():
def __init__(self, data_root, seed=555, debug_level=0, batch_size=2):
def __init__(self, data_root, seed=555, debug_level=0, batch_size=1, flip_p=0.0):
self.image_paths = []
self.debug_level = debug_level
self.flip_p = flip_p
self.__recurse_images(self=self, recurse_root=data_root)
print(" Preloading images...")
prepared_train_data = self.__crop_resize_images(debug_level, self.image_paths)
#print(f"prepared data __init__: {prepared_train_data}")
self.__recurse_data_root(self=self, recurse_root=data_root)
random.Random(seed).shuffle(self.image_paths)
prepared_train_data = self.__prescan_images(debug_level, self.image_paths, flip_p)
self.image_caption_pairs = self.__bucketize_images(prepared_train_data, batch_size=batch_size, debug_level=debug_level)
print(f"**** Done loading. Loaded {len(self.image_paths)} images from data_root: {data_root} ****") if self.debug_level > 0 else None
print(self.image_paths) if self.debug_level > 1 else None
print(f" * DLMA Example {self.image_caption_pairs[0]} images")
gc.collect()
@@ -31,45 +34,39 @@ class DataLoaderMultiAspect():
return self.image_caption_pairs
@staticmethod
def __crop_resize_images(debug_level, image_paths):
decorated_image_paths = []
print("* Loading images using multi-aspect-ratio loader *") if debug_level > 1 else None
i = 0
for pathname in image_paths:
print(pathname) if debug_level > 1 else None
parts = os.path.basename(pathname).split("_")
parts[-1] = parts[-1].split(".")[0]
identifier = parts[0]
def __prescan_images(debug_level: int, image_paths: list, flip_p=0.0):
decorated_image_train_items = []
for pathname in image_paths:
parts = os.path.basename(pathname).split("_")
# untested
# txt_filename = parts[0] + ".txt"
# if os.path.exists(txt_filename):
# try:
# with open(txt_filename, 'r') as f:
# identifier = f.read()
# identifier.rstrip()
# except:
# print(f" *** Error reading {txt_filename} to get caption")
# identifier = parts[0]
# pass
# else:
# identifier = parts[0]
identifier = parts[0]
image = Image.open(pathname)
width, height = image.size
image_aspect = width / height
closest_aspect = min(ASPECTS, key=lambda x:abs(x[0]/x[1]-image_aspect))
target_wh = min(ASPECTS, key=lambda x:abs(x[0]/x[1]-image_aspect))
target_aspect = closest_aspect[0]/closest_aspect[1]
image_train_item = ImageTrainItem(image=None, caption=identifier, target_wh=target_wh, pathname=pathname, flip_p=flip_p)
if closest_aspect[0] == closest_aspect[1]:
pass
if target_aspect < image_aspect:
crop_width = (width - (height * closest_aspect[0] / closest_aspect[1])) / 2
#print(f" ** Cropping width: {crop_width}") if debug_level > 1 else None
image = image.crop((crop_width, 0, width - crop_width, height))
else:
crop_height = (height - (width * closest_aspect[1] / closest_aspect[0])) / 2
#print(f" ** Cropping height: {crop_height}") if debug_level > 1 else None
image = image.crop((0, crop_height, width, height - crop_height))
image = image.resize((closest_aspect[0], closest_aspect[1]), Image.BICUBIC)
if debug_level > 1:
print(f" ** Multi-aspect debug: saving resized image to outputs/{i}.png")
image.save(f"outputs/{i}.png",format="png")
i += 1
decorated_image_paths.append([image, identifier])
return decorated_image_paths
# put placeholder image in the list and return meta data
decorated_image_train_items.append(image_train_item)
return decorated_image_train_items
@staticmethod
def __bucketize_images(prepared_train_data, batch_size=1, debug_level=0):
@@ -77,18 +74,23 @@ class DataLoaderMultiAspect():
buckets = {}
for image_caption_pair in prepared_train_data:
image = image_caption_pair[0]
image = image_caption_pair.image
width, height = image.size
if (width, height) not in buckets:
buckets[(width, height)] = []
buckets[(width, height)].append(image_caption_pair)
buckets[(width, height)].append(image_caption_pair) # [image, identifier, target_aspect, closest_aspect_wh[w,h], pathname]
for bucket in buckets:
truncated_count = len(buckets[bucket]) % batch_size
current_bucket_size = len(buckets[bucket])
buckets[bucket] = buckets[bucket][:current_bucket_size - truncated_count]
print(f" ** Bucket {bucket} with {current_bucket_size} will truncate {truncated_count} images due to batch size {batch_size}") if debug_level > 0 else None
print(f" ** Number of buckets: {len(buckets)}")
if len(buckets) > 1: # don't bother truncating if everything is the same aspect ratio
for bucket in buckets:
truncate_count = len(buckets[bucket]) % batch_size
current_bucket_size = len(buckets[bucket])
buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
print(f" ** Bucket {bucket} with {current_bucket_size} will drop {truncate_count} images due to batch size {batch_size}") if debug_level > 0 else None
# flatten the buckets
image_caption_pairs = []
for bucket in buckets:
image_caption_pairs.extend(buckets[bucket])
@@ -96,15 +98,18 @@ class DataLoaderMultiAspect():
return image_caption_pairs
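A self-contained sketch of the bucketing and truncation step above, using plain (target_wh, caption) tuples rather than ImageTrainItem objects (an assumption for brevity): items are grouped by target size, and when more than one bucket exists each bucket is trimmed to a multiple of batch_size so batches never mix aspect ratios. The sample data is invented.

# Sketch of __bucketize_images: group by target size, trim to batch_size multiples.
batch_size = 2
items = [((512, 512), "a"), ((512, 512), "b"), ((512, 512), "c"),
         ((640, 384), "d"), ((640, 384), "e")]

buckets = {}
for target_wh, caption in items:
    buckets.setdefault(target_wh, []).append((target_wh, caption))

if len(buckets) > 1:  # single-bucket data is left untouched, as above
    for wh in buckets:
        keep = len(buckets[wh]) - len(buckets[wh]) % batch_size
        buckets[wh] = buckets[wh][:keep]

flattened = [item for wh in buckets for item in buckets[wh]]
print(flattened)  # "c" is dropped so the (512, 512) bucket is a multiple of batch_size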
@staticmethod
def __recurse_images(self, recurse_root):
def __recurse_data_root(self, recurse_root):
i = 0
for f in os.listdir(recurse_root):
current = os.path.join(recurse_root, f)
# get file ext
if os.path.isfile(current):
i += 1
self.image_paths.append(current)
print(f" ** Found {str(i).rjust(5,' ')} files in", recurse_root) if self.debug_level > 0 else None
ext = os.path.splitext(f)[1]
if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp']:
i += 1
self.image_paths.append(current)
sub_dirs = []
@@ -114,4 +119,24 @@ class DataLoaderMultiAspect():
sub_dirs.append(current)
for dir in sub_dirs:
self.__recurse_images(self=self, recurse_root=dir)
self.__recurse_data_root(self=self, recurse_root=dir)
# @staticmethod
# def hydrate_image(self, image_path, target_aspect, closest_aspect_wh):
# image = Image.open(example[4]) # 5 is the path
# print(image)
# width, height = image.size
# image_aspect = width / height
# target_aspect = width / height
# if example[3][0] == example[3][1]:
# pass
# if target_aspect < image_aspect:
# crop_width = (width - (width * example[3][0] / example[3][1])) / 2
# image = image.crop((crop_width, 0, width - crop_width, height))
# else:
# crop_height = (height - (width * example[3][1] / example[3][0])) / 2
# image = image.crop((0, crop_height, width, height - crop_height))
# example[0] = image.resize((example[3][0], example[3][1]), Image.BICUBIC)
# return example

ldm/data/dl_singleton.py Normal file

@@ -0,0 +1,2 @@
# stop lightning's repeated instantiation of batch train/val/test classes causing multiple sweeps of the same data off disk
shared_dataloader = None
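An illustrative sketch of the check-then-create pattern the Dataset classes below use with this module-level singleton; the helper name get_shared_dataloader is hypothetical, but the check and constructor arguments mirror the ed_validate.py and every_dream.py changes in this commit.

# Sketch only; the wrapper function is hypothetical.
import ldm.data.dl_singleton as dls
from ldm.data.data_loader import DataLoaderMultiAspect as dlma

def get_shared_dataloader(data_root, debug_level=0, batch_size=1, flip_p=0.0):
    if not dls.shared_dataloader:
        # first caller builds the loader and scans the data root
        dls.shared_dataloader = dlma(data_root=data_root, debug_level=debug_level,
                                     batch_size=batch_size, flip_p=flip_p)
    # later callers (train/val/test datasets) reuse the same instance
    return dls.shared_dataloader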

ldm/data/ed_validate.py

@@ -1,30 +1,36 @@
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms
from pathlib import Path
from ldm.data.data_loader import DataLoaderMultiAspect as dlma
import math
import ldm.data.dl_singleton as dls
class EDValidateBatch(Dataset):
def __init__(self,
data_root,
flip_p=0.0,
repeats=1,
debug_level=0,
batch_size=1
batch_size=1,
set='val',
):
print(f"EDValidateBatch batch size: {self.batch_size}") if debug_level > 0 else None
self.data_root = data_root
self.batch_size = batch_size
self.image_caption_pairs = dlma(data_root=data_root, debug_level=debug_level, batch_size=self.batch_size).get_all_images()
# most_subscribed_aspect_ratio = self.most_subscribed_aspect_ratio()
# self.image_caption_pairs = [image_caption_pair for image_caption_pair in self.image_caption_pairs if image_caption_pair[0].size == aspect_ratio]
if not dls.shared_dataloader:
print("Creating new dataloader singleton")
dls.shared_dataloader = dlma(data_root=data_root, debug_level=debug_level, batch_size=self.batch_size)
self.image_caption_pairs = dls.shared_dataloader.get_all_images()
self.num_images = len(self.image_caption_pairs)
self._length = max(math.trunc(self.num_images * repeats), 1)
self._length = max(math.trunc(self.num_images * repeats), batch_size) - self.num_images % self.batch_size
print()
print(f" ** Validation Set: {set}, num_images: {self.num_images}, length: {self._length}, repeats: {repeats}, batch_size: {self.batch_size}, ")
print(f" ** Validation steps: {self._length / batch_size:.0f}")
print()
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@@ -34,8 +40,6 @@ class EDValidateBatch(Dataset):
def __getitem__(self, i):
idx = i % len(self.image_caption_pairs)
example = self.get_image(self.image_caption_pairs[idx])
#print caption and image size
print(f"Caption: {example['image'].shape} {example['caption']}")
return example
def get_image(self, image_caption_pair):
@@ -54,22 +58,3 @@ class EDValidateBatch(Dataset):
example["caption"] = identifier
return example
def filter_aspect_ratio(self, aspect_ratio):
# filter the images to only include the given aspect ratio
self.image_caption_pairs = [image_caption_pair for image_caption_pair in self.image_caption_pairs if image_caption_pair[0].size == aspect_ratio]
self.num_images = len(self.image_caption_pairs)
self._length = max(math.trunc(self.num_images * self.repeats), 2)
def most_subscribed_aspect_ratio(self):
# find the image size with the highest number of images
aspect_ratios = {}
for image_caption_pair in self.image_caption_pairs:
image = image_caption_pair[0]
aspect_ratio = image.size
if aspect_ratio in aspect_ratios:
aspect_ratios[aspect_ratio] += 1
else:
aspect_ratios[aspect_ratio] = 1
return max(aspect_ratios, key=aspect_ratios.get)
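A worked example of the new EDValidateBatch _length formula above, with invented counts: subtracting the remainder trims the validation set to a whole number of batches.

import math

# Invented counts to illustrate the _length formula in EDValidateBatch.
num_images, repeats, batch_size = 13, 1, 4
length = max(math.trunc(num_images * repeats), batch_size) - num_images % batch_size
print(length, length / batch_size)  # 12 images -> 3.0 validation steps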

ldm/data/every_dream.py

@@ -1,9 +1,11 @@
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms
from pathlib import Path
from ldm.data.data_loader import DataLoaderMultiAspect as dlma
import math
import ldm.data.dl_singleton as dls
from PIL import Image
import gc
class EveryDreamBatch(Dataset):
def __init__(self,
@@ -11,43 +13,59 @@ class EveryDreamBatch(Dataset):
repeats=10,
flip_p=0.0,
debug_level=0,
batch_size=1
batch_size=1,
set='train'
):
print(f"EveryDreamBatch batch size: {batch_size}")
#print(f"EveryDreamBatch batch size: {batch_size}")
self.data_root = data_root
self.batch_size = batch_size
self.image_caption_pairs = dlma(data_root=data_root, debug_level=debug_level, batch_size=self.batch_size).get_all_images()
self.flip_p = flip_p
self.num_images = len(self.image_caption_pairs)
if not dls.shared_dataloader:
print(" * Creating new dataloader singleton")
dls.shared_dataloader = dlma(data_root=data_root, debug_level=debug_level, batch_size=self.batch_size, flip_p=self.flip_p)
self.image_train_items = dls.shared_dataloader.get_all_images()
#print(f" * EDB Example {self.image_train_items[0]}")
self.num_images = len(self.image_train_items)
self._length = math.trunc(self.num_images * repeats)
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
print(f" * Training steps: {self._length / batch_size}")
print()
print(f" ** Trainer Set: {set}, steps: {self._length / batch_size:.0f}, num_images: {self.num_images}, batch_size: {self.batch_size}, length w/repeats: {self._length}")
print()
def __len__(self):
return self._length
def __getitem__(self, i):
idx = i % self.num_images
example = self.get_image(self.image_caption_pairs[idx])
#example = self.get_image(self.image_caption_pairs[idx])
image_train_item = self.image_train_items[idx]
#print(f" *** example {example}")
hydrated_image_train_item = image_train_item.hydrate()
example = self.get_image_for_trainer(hydrated_image_train_item)
return example
def get_image(self, image_caption_pair):
def unload_images_over(self, limit):
print(f" ********** Unloading images over limit {limit}")
i = 0
while i < len(self.image_train_items):
print(self.image_train_items[i])
if i > limit:
self.image_train_items[i][0] = Image.new(mode='RGB', size=(1, 1))
i += 1
gc.collect()
def get_image_for_trainer(self, image_train_item):
example = {}
image = image_caption_pair[0]
image_train_tmp = image_train_item.as_formatted()
if not image.mode == "RGB":
image = image.convert("RGB")
identifier = image_caption_pair[1]
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
example["caption"] = identifier
example["image"] = image_train_tmp.image
example["caption"] = image_train_tmp.caption
return example
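A small sketch of the repeats and wrap-around indexing used by EveryDreamBatch.__getitem__ above: the epoch length is num_images * repeats, and indices wrap with a modulo so each pass re-visits the same lazily hydrated items. The item list and the hydrate stand-in below are invented for illustration.

import math

# Sketch of repeats + modulo indexing; data and the stand-in are invented.
items = ["img_a", "img_b", "img_c"]
repeats = 2
length = math.trunc(len(items) * repeats)   # epoch length with repeats

def getitem(i):
    idx = i % len(items)                    # wrap around the real image list
    return f"hydrated({items[idx]})"        # stands in for ImageTrainItem.hydrate()

print([getitem(i) for i in range(length)])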

ldm/data/image_train_item.py Normal file

@@ -0,0 +1,31 @@
from PIL import Image
import numpy as np
from torchvision import transforms
class ImageTrainItem(): # [image, identifier, target_aspect, closest_aspect_wh[w,h], pathname]
def __init__(self, image: Image, caption: str, target_wh: list, pathname: str, flip_p=0.0):
self.caption = caption
self.target_wh = target_wh
#self.target_aspect = target_aspect
self.pathname = pathname
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
if image is None:
self.image = Image.new(mode='RGB',size=(1,1))
else:
self.image = image
#image_train_item.image = image.resize((image_train_item.closest_aspect_wh[0], image_train_item.closest_aspect_wh[1]), Image.BICUBIC)
def hydrate(self):
self.image = self.image.resize(self.target_wh, Image.BICUBIC)
if not self.image.mode == "RGB":
self.image = self.image.convert("RGB")
self.image = self.flip(self.image)
self.image = np.array(self.image).astype(np.uint8)
self.image = (self.image / 127.5 - 1.0).astype(np.float32)
return self
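A short usage sketch for the new ImageTrainItem class above, using an in-memory PIL image so it runs without any files on disk; the image size, color, caption, and pathname are invented. hydrate() resizes to the target bucket, applies the (here disabled) random flip, and converts the pixels to a float32 array scaled into [-1, 1], as shown in the code above.

from PIL import Image
from ldm.data.image_train_item import ImageTrainItem

# Invented example: an in-memory solid-color image stands in for a training file.
img = Image.new(mode="RGB", size=(800, 600), color=(200, 80, 40))
item = ImageTrainItem(image=img, caption="a plain orange image",
                      target_wh=[576, 448], pathname="example.png", flip_p=0.0)

item = item.hydrate()
print(item.image.shape, item.image.dtype)   # (448, 576, 3) float32
print(item.image.min(), item.image.max())   # pixel values scaled into [-1, 1]
print(item.caption)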

ldm/data/test_batch.py Normal file

@@ -0,0 +1,53 @@
# script to test data loader by itself
# run from training root, edit the data_root manually
# python ldm/data/test_batch.py
import every_dream
import time
s = time.perf_counter()
data_root = "r:/everydream-trainer/training_samples/ff7r"
batch_size = 1
every_dream_batch = every_dream.EveryDreamBatch(data_root=data_root, flip_p=0.0, debug_level=0, batch_size=batch_size, repeats=1)
print(f" *TEST* batch type: {type(every_dream_batch)}")
i = 0
is_next = True
curr_batch = []
while is_next and i < 30 and i < len(every_dream_batch):
try:
example = every_dream_batch[i]
if example is not None:
#print(f"example type: {type(example)}") # dict
#print(f"example keys: {example.keys()}") # dict_keys(['image', 'caption'])
#print(f"example image type: {type(example['image'])}") # numpy.ndarray
if i%batch_size == 0:
curr_batch = example['image'].shape
img_in_right_batch = curr_batch == example['image'].shape
print(f" *TEST* example image shape: {example['image'].shape} {i%batch_size} {img_in_right_batch}")
print(f" *TEST* example caption: {example['caption']}")
if not img_in_right_batch:
raise Exception("Current image in wrong batch")
#print(f"example caption: {example['caption']}") # str
else:
is_next = False
i += 1
except IndexError:
is_next = False
print(f"IndexError: {i}")
pass
# for idx, batches in every_dream_batch:
# print(f"inner example type: {type(batches)}")
# print(type(batches))
# print(type(batches[0]))
# print(dir(batches))
#h, w = batches.image.size
#print(f"{idx:05d}-{idx%6:02d}EveryDreamBatch image caption pair: w:{w} h:{h} {batches.caption[1]}")
print(f" *TEST* test cycles: {i}")
print(f" *TEST* EveryDreamBatch epoch image length: {len(every_dream_batch)}")
elapsed = time.perf_counter() - s
print(f"{__file__} executed in {elapsed:5.2f} seconds.")

ldm/data/test_dl.py Normal file

@@ -0,0 +1,18 @@
# script to test data loader by itself
# run from training root, edit the data_root manually
# python ldm/data/test_dl.py
import data_loader
data_root = "r:/everydream-trainer/training_samples/multiaspect"
data_loader = data_loader.DataLoaderMultiAspect(data_root=data_root, seed=555, debug_level=2)
image_caption_pairs = data_loader.get_all_images()
print(f"Loaded {len(image_caption_pairs)} image-caption pairs")
for image_caption_pair in image_caption_pairs:
print(image_caption_pair)
print(image_caption_pair.caption)
print(f"**** Done loading. Loaded {len(image_caption_pairs)} images from data_root: {data_root} ****")

ldm/data/test_validate.py Normal file

@@ -0,0 +1,47 @@
# script to test data loader by itself
# run from training root, edit the data_root manually
# python ldm/data/test_validate.py
import ed_validate
data_root = "r:/everydream-trainer/training_samples/multiaspect4"
batch_size = 6
ed_val_batch = ed_validate.EDValidateBatch(data_root=data_root, flip_p=0.0, debug_level=0, batch_size=batch_size, repeats=1)
print(f"batch type: {type(ed_val_batch)}")
i = 0
is_next = True
curr_batch = []
while is_next and i < 84:
try:
example = ed_val_batch[i]
if example is not None:
#print(f"example type: {type(example)}") # dict
#print(f"example keys: {example.keys()}") # dict_keys(['image', 'caption'])
#print(f"example image type: {type(example['image'])}") # numpy.ndarray
if i%batch_size == 0:
curr_batch = example['image'].shape
img_in_right_batch = curr_batch == example['image'].shape
print(f"example image shape: {example['image'].shape} {i%batch_size} {img_in_right_batch}") # (256, 256, 3)
if not img_in_right_batch:
raise Exception("Current image in wrong batch")
#print(f"example caption: {example['caption']}") # str
else:
is_next = False
i += 1
except IndexError:
is_next = False
print(f"IndexError: {i}")
pass
# for idx, batches in every_dream_batch:
# print(f"inner example type: {type(batches)}")
# print(type(batches))
# print(type(batches[0]))
# print(dir(batches))
#h, w = batches.image.size
#print(f"{idx:05d}-{idx%6:02d}EveryDreamBatch image caption pair: w:{w} h:{h} {batches.caption[1]}")
# optional filter to a single aspect ratio; requires `aspect_ratio` to be defined and should reference ed_val_batch rather than self
# ed_val_batch.image_caption_pairs = [pair for pair in ed_val_batch.image_caption_pairs if pair[0].size == aspect_ratio]
print(f"EveryDreamBatch epoch image length: {len(ed_val_batch)}")