from torch.utils.data import Dataset
|
|
from ldm.data.data_loader import DataLoaderMultiAspect as dlma
|
|
import math
|
|
import ldm.data.dl_singleton as dls
|
|
from ldm.data.image_train_item import ImageTrainItem
|
|
import random
|
|
|
|
class EveryDreamBatch(Dataset):
    """
    Dataset serving aspect-ratio-bucketed training images from a shared
    DataLoaderMultiAspect singleton, with optional caption dropout and a
    rolling unload of already-served image data to bound memory use.

    data_root: root path of all your training images, will be recursively searched for images
    repeats: how many times to repeat each image in the dataset
    flip_p: probability of flipping the image horizontally
    debug_level: 0=none, 1=print drops due to unfilled batches on aspect ratio buckets, 2=debug info per image, 3=save crops to disk for inspection
    batch_size: how many images to return in a batch
    conditional_dropout: probability of dropping the caption for a given image
    resolution: max resolution (relative to square)
    jitter: number of pixels to jitter the crop by, only for non-square images
    """
    def __init__(self,
                 data_root,
                 repeats=10,
                 flip_p=0.0,
                 debug_level=0,
                 batch_size=1,
                 set='train',
                 conditional_dropout=0.02,
                 resolution=512,
                 crop_jitter=20,
                 seed=555,
                 ):
        self.data_root = data_root
        self.batch_size = batch_size
        self.debug_level = debug_level
        self.conditional_dropout = conditional_dropout
        self.crop_jitter = crop_jitter
        # Exclusive upper bound of the index range whose pixel data has
        # already been freed by _unload_stale_images.
        self.unloaded_to_idx = 0

        if seed == -1:
            seed = random.randint(0, 9999)

        # The dataloader is a process-wide singleton so multiple sets
        # (train/val) share one scan of the image tree.
        if not dls.shared_dataloader:
            print(" * Creating new dataloader singleton")
            dls.shared_dataloader = dlma(data_root=data_root, seed=seed, debug_level=debug_level, batch_size=self.batch_size, flip_p=flip_p, resolution=resolution)

        self.image_train_items = dls.shared_dataloader.get_all_images()

        self.num_images = len(self.image_train_items)

        # repeats may be fractional; trunc keeps _length an int.
        self._length = math.trunc(self.num_images * repeats)

        print()
        print(f" ** Trainer Set: {set}, steps: {self._length / batch_size:.0f}, num_images: {self.num_images}, batch_size: {self.batch_size}, length w/repeats: {self._length}")
        print()

    def __len__(self):
        """Return the dataset length including repeats."""
        return self._length

    def __getitem__(self, i):
        """Return one training example dict; i wraps modulo num_images."""
        idx = i % self.num_images
        image_train_item = self.image_train_items[idx]
        example = self.__get_image_for_trainer(image_train_item, self.debug_level)
        self._unload_stale_images(idx)
        return example

    def _unload_stale_images(self, idx):
        """Free pixel data of images served well before *idx* to cap memory.

        Keeps roughly the last batch_size*4 images loaded and runs the sweep
        only every batch_size*3 indices to amortize its cost. When *idx*
        wraps to a new epoch, the sweep restarts from the beginning.
        """
        # idx moved backwards => wrapped around to a new epoch; restart.
        if self.unloaded_to_idx > idx:
            self.unloaded_to_idx = 0

        if idx % (self.batch_size * 3) == 0 and idx > (self.batch_size * 5):
            start_del = self.unloaded_to_idx
            # Advance the freed range up to four batches behind idx.
            self.unloaded_to_idx = int(idx / self.batch_size) * self.batch_size - self.batch_size * 4

            for j in range(start_del, self.unloaded_to_idx):
                # hydrate() sets .image lazily; only delete if present.
                if hasattr(self.image_train_items[j], 'image'):
                    del self.image_train_items[j].image
            if self.debug_level > 1:
                print(f" * Unloaded images from idx {start_del} to {self.unloaded_to_idx}")

    def __get_image_for_trainer(self, image_train_item: "ImageTrainItem", debug_level=0):
        """Hydrate one item and build the example dict.

        Returns {"image": ..., "caption": ...}; with probability
        self.conditional_dropout the caption is replaced by a single space
        (unconditional training signal).
        """
        example = {}

        # debug_level 3+ saves crops to disk for inspection (see class doc).
        save = debug_level > 2

        image_train_tmp = image_train_item.hydrate(crop=False, save=save, crop_jitter=self.crop_jitter)

        example["image"] = image_train_tmp.image

        if random.random() > self.conditional_dropout:
            example["caption"] = image_train_tmp.caption
        else:
            example["caption"] = " "

        return example
|