EveryDream-trainer/ldm/data/every_dream.py

from torch.utils.data import Dataset
from ldm.data.data_loader import DataLoaderMultiAspect as dlma
import math
import ldm.data.dl_singleton as dls
from ldm.data.image_train_item import ImageTrainItem
import random

class EveryDreamBatch(Dataset):
    """
    data_root: root path of all your training images, will be recursively searched for images
    repeats: how many times to repeat each image in the dataset
    flip_p: probability of flipping the image horizontally
    debug_level: 0=none, 1=print drops due to unfilled batches on aspect ratio buckets, 2=debug info per image, 3=save crops to disk for inspection
    batch_size: how many images to return in a batch
    conditional_dropout: probability of dropping the caption for a given image
    resolution: max resolution (relative to square)
    crop_jitter: number of pixels to jitter the crop by, only for non-square images
    set: name of this split (e.g. 'train'), used for logging only
    seed: seed for the shared dataloader; -1 picks a random seed
    """
    def __init__(self,
                 data_root,
                 repeats=10,
                 flip_p=0.0,
                 debug_level=0,
                 batch_size=1,
                 set='train',
                 conditional_dropout=0.02,
                 resolution=512,
                 crop_jitter=20,
                 seed=555,
                 ):
        self.data_root = data_root
        self.batch_size = batch_size
        self.debug_level = debug_level
        self.conditional_dropout = conditional_dropout
        self.crop_jitter = crop_jitter
        self.unloaded_to_idx = 0

        if seed == -1:
            seed = random.randint(0, 9999)

        # all instances share a single DataLoaderMultiAspect so the image scan and
        # aspect-ratio bucketing only happen once per run
        if not dls.shared_dataloader:
            print(" * Creating new dataloader singleton")
            dls.shared_dataloader = dlma(data_root=data_root, seed=seed, debug_level=debug_level, batch_size=self.batch_size, flip_p=flip_p, resolution=resolution)

        self.image_train_items = dls.shared_dataloader.get_all_images()
        self.num_images = len(self.image_train_items)

        self._length = math.trunc(self.num_images * repeats)

        print()
        print(f" ** Trainer Set: {set}, steps: {self._length / batch_size:.0f}, num_images: {self.num_images}, batch_size: {self.batch_size}, length w/repeats: {self._length}")
        print()
    def __len__(self):
        return self._length

    def __getitem__(self, i):
        idx = i % self.num_images
        image_train_item = self.image_train_items[idx]

        example = self.__get_image_for_trainer(image_train_item, self.debug_level)

        # index wrapped around to the start of the list (new repeat), so reset the unload window
        if self.unloaded_to_idx > idx:
            self.unloaded_to_idx = 0

        # every few batches, free memory by dropping the hydrated image data for items
        # that are well behind the current index
        if idx % (self.batch_size*3) == 0 and idx > (self.batch_size * 5):
            start_del = self.unloaded_to_idx
            self.unloaded_to_idx = int(idx / self.batch_size)*self.batch_size - self.batch_size*4

            for j in range(start_del, self.unloaded_to_idx):
                if hasattr(self.image_train_items[j], 'image'):
                    del self.image_train_items[j].image

            if self.debug_level > 1:
                print(f" * Unloaded images from idx {start_del} to {self.unloaded_to_idx}")

        return example
    def __get_image_for_trainer(self, image_train_item: ImageTrainItem, debug_level=0):
        example = {}

        save = debug_level > 2

        # hydrate loads and preprocesses the image; when debug_level > 2 the crop is saved to disk for inspection
        image_train_tmp = image_train_item.hydrate(crop=False, save=save, crop_jitter=self.crop_jitter)

        example["image"] = image_train_tmp.image

        # conditional dropout: occasionally replace the caption with a blank so the model
        # also trains on an unconditioned example
        if random.random() > self.conditional_dropout:
            example["caption"] = image_train_tmp.caption
        else:
            example["caption"] = " "

        return example
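

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: builds the dataset and pulls
    # one example. The data_root below is a hypothetical path; point it at your own image folder.
    ds = EveryDreamBatch(data_root="training_samples/my_project", repeats=1, batch_size=2, debug_level=1)
    print(f"length with repeats: {len(ds)}")

    sample = ds[0]  # dict with "image" (preprocessed image data) and "caption" (string, blank if dropped)
    print(sample["caption"])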