EveryDream-trainer/ldm/data/every_dream.py

from torch.utils.data import Dataset
from ldm.data.data_loader import DataLoaderMultiAspect as dlma
import math
import ldm.data.dl_singleton as dls
from ldm.data.image_train_item import ImageTrainItem
import random

class EveryDreamBatch(Dataset):
    """
    data_root: root path of all your training images, will be recursively searched for images
    repeats: how many times to repeat each image in the dataset
    flip_p: probability of flipping the image horizontally
    debug_level: 0=none, 1=print drops due to unfilled batches on aspect ratio buckets, 2=debug info per image, 3=save crops to disk for inspection
    batch_size: how many images to return in a batch
    conditional_dropout: probability of dropping the caption for a given image
    resolution: max resolution (relative to square)
    crop_jitter: number of pixels to jitter the crop by, only for non-square images
    """
    def __init__(self,
                 data_root,
                 repeats=10,
                 flip_p=0.0,
                 debug_level=0,
                 batch_size=1,
                 set='train',
                 conditional_dropout=0.02,
                 resolution=512,
                 crop_jitter=20,
                 seed=555,
                 ):
        self.data_root = data_root
        self.batch_size = batch_size
        self.debug_level = debug_level
        self.conditional_dropout = conditional_dropout
        self.crop_jitter = crop_jitter
        self.unloaded_to_idx = 0

        # a seed of -1 means pick a random seed for this run
        if seed == -1:
            seed = random.randint(0, 9999)

        # the aspect-ratio-bucketing dataloader is a process-wide singleton,
        # so the image scan and bucketing only happen once
        if not dls.shared_dataloader:
            print(" * Creating new dataloader singleton")
            dls.shared_dataloader = dlma(data_root=data_root, seed=seed, debug_level=debug_level, batch_size=self.batch_size, flip_p=flip_p, resolution=resolution)

        self.image_train_items = dls.shared_dataloader.get_all_images()
        self.num_images = len(self.image_train_items)
        self._length = math.trunc(self.num_images * repeats)

        print()
        print(f" ** Trainer Set: {set}, steps: {self._length / batch_size:.0f}, num_images: {self.num_images}, batch_size: {self.batch_size}, length w/repeats: {self._length}")
        print()

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        idx = i % self.num_images
        image_train_item = self.image_train_items[idx]

        example = self.__get_image_for_trainer(image_train_item, self.debug_level)

        # reset the unload watermark when the index wraps around to a new epoch
        if self.unloaded_to_idx > idx:
            self.unloaded_to_idx = 0

        # every few batches, drop hydrated image data for items well behind the
        # current index to keep host memory usage bounded
        if idx % (self.batch_size * 3) == 0 and idx > (self.batch_size * 5):
            start_del = self.unloaded_to_idx
            self.unloaded_to_idx = int(idx / self.batch_size) * self.batch_size - self.batch_size * 4

            for j in range(start_del, self.unloaded_to_idx):
                if hasattr(self.image_train_items[j], 'image'):
                    del self.image_train_items[j].image

            if self.debug_level > 1:
                print(f" * Unloaded images from idx {start_del} to {self.unloaded_to_idx}")

        return example

    def __get_image_for_trainer(self, image_train_item: ImageTrainItem, debug_level=0):
        example = {}
        save = debug_level > 2

        # lazily load and jitter-crop the image; save crops to disk when debug_level > 2
        image_train_tmp = image_train_item.hydrate(crop=False, save=save, crop_jitter=self.crop_jitter)

        example["image"] = image_train_tmp.image

        # conditional dropout: occasionally replace the caption with a blank prompt
        if random.random() > self.conditional_dropout:
            example["caption"] = image_train_tmp.caption
        else:
            example["caption"] = " "

        return example
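
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the trainer): shows how
# EveryDreamBatch could be instantiated and inspected directly. The "input"
# data_root below is a placeholder path, and indexing the dataset by hand is
# purely for demonstration; the actual trainer constructs this dataset from
# its config and wraps it in its own loading pipeline.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # placeholder folder of captioned training images (hypothetical path)
    train_set = EveryDreamBatch(data_root="input", repeats=1, batch_size=2, debug_level=1)

    print(f"dataset length (num_images * repeats): {len(train_set)}")

    # peek at the first few examples; each is a dict with an "image" and a
    # "caption", where the caption may be blanked out by conditional dropout
    for i in range(min(4, len(train_set))):
        example = train_set[i]
        print(i, repr(example["caption"]))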