Loader rework and optimizations for aspect ratio support
commit aacbde8bc7 (parent b7779591b1)
@@ -1,6 +1,8 @@
 import os
 from PIL import Image
 import gc
+import random
+from ldm.data.image_train_item import ImageTrainItem

 ASPECTS = [[512,512], # 1 262144\
     [576,448],[448,576], # 1.29 258048\
@@ -11,19 +13,20 @@ ASPECTS = [[512,512], # 1 262144\
     [960,256],[256,960], # 3.75 245760\
     [1024,256],[256,1024] # 4 245760\
 ]

 class DataLoaderMultiAspect():
-    def __init__(self, data_root, seed=555, debug_level=0, batch_size=2):
+    def __init__(self, data_root, seed=555, debug_level=0, batch_size=1, flip_p=0.0):
         self.image_paths = []
         self.debug_level = debug_level
+        self.flip_p = flip_p

-        self.__recurse_images(self=self, recurse_root=data_root)
-        print(" Preloading images...")
-        prepared_train_data = self.__crop_resize_images(debug_level, self.image_paths)
-        #print(f"prepared data __init__: {prepared_train_data}")
+        self.__recurse_data_root(self=self, recurse_root=data_root)
+        random.Random(seed).shuffle(self.image_paths)
+        prepared_train_data = self.__prescan_images(debug_level, self.image_paths, flip_p)
         self.image_caption_pairs = self.__bucketize_images(prepared_train_data, batch_size=batch_size, debug_level=debug_level)

         print(f"**** Done loading. Loaded {len(self.image_paths)} images from data_root: {data_root} ****") if self.debug_level > 0 else None
         print(self.image_paths) if self.debug_level > 1 else None
+        print(f" * DLMA Example {self.image_caption_pairs[0]} images")

         gc.collect()
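The reworked __init__ above now runs as a fixed pipeline: recurse data_root for image files, shuffle the paths with a seeded RNG, prescan each path into an ImageTrainItem (metadata only), then bucketize by aspect ratio. A small standalone sketch of why the seeded shuffle matters; the sample paths are made up:

# Sketch: random.Random(seed).shuffle(...) is reproducible, so every loader
# construction (train or val) sees the image paths in the same order.
import random

paths_a = [f"img_{n}.png" for n in range(8)]
paths_b = list(paths_a)

random.Random(555).shuffle(paths_a)
random.Random(555).shuffle(paths_b)

assert paths_a == paths_b  # same seed -> same order
print(paths_a)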
@@ -31,45 +34,39 @@ class DataLoaderMultiAspect():
         return self.image_caption_pairs

     @staticmethod
-    def __crop_resize_images(debug_level, image_paths):
-        decorated_image_paths = []
-        print("* Loading images using multi-aspect-ratio loader *") if debug_level > 1 else None
-        i = 0
-        for pathname in image_paths:
-            print(pathname) if debug_level > 1 else None
-            parts = os.path.basename(pathname).split("_")
-            parts[-1] = parts[-1].split(".")[0]
-            identifier = parts[0]
+    def __prescan_images(debug_level: int, image_paths: list, flip_p=0.0):
+        decorated_image_train_items = []
+
+        for pathname in image_paths:
+            parts = os.path.basename(pathname).split("_")
+
+            # untested
+            # txt_filename = parts[0] + ".txt"
+            # if os.path.exists(txt_filename):
+            # try:
+            # with open(txt_filename, 'r') as f:
+            # identifier = f.read()
+            # identifier.rstrip()
+            # except:
+            # print(f" *** Error reading {txt_filename} to get caption")
+            # identifier = parts[0]
+            # pass
+            # else:
+            # identifier = parts[0]
+
+            identifier = parts[0]

             image = Image.open(pathname)
             width, height = image.size
             image_aspect = width / height

-            closest_aspect = min(ASPECTS, key=lambda x:abs(x[0]/x[1]-image_aspect))
+            target_wh = min(ASPECTS, key=lambda x:abs(x[0]/x[1]-image_aspect))

-            target_aspect = closest_aspect[0]/closest_aspect[1]
+            image_train_item = ImageTrainItem(image=None, caption=identifier, target_wh=target_wh, pathname=pathname, flip_p=flip_p)

-            if closest_aspect[0] == closest_aspect[1]:
-                pass
-            if target_aspect < image_aspect:
-                crop_width = (width - (height * closest_aspect[0] / closest_aspect[1])) / 2
-                #print(f" ** Cropping width: {crop_width}") if debug_level > 1 else None
-                image = image.crop((crop_width, 0, width - crop_width, height))
-            else:
-                crop_height = (height - (width * closest_aspect[1] / closest_aspect[0])) / 2
-                #print(f" ** Cropping height: {crop_height}") if debug_level > 1 else None
-                image = image.crop((0, crop_height, width, height - crop_height))
-
-            image = image.resize((closest_aspect[0], closest_aspect[1]), Image.BICUBIC)
-
-            if debug_level > 1:
-                print(f" ** Multi-aspect debug: saving resized image to outputs/{i}.png")
-                image.save(f"outputs/{i}.png",format="png")
-                i += 1
-
-            decorated_image_paths.append([image, identifier])
-
-        return decorated_image_paths
+            # put placeholder image in the list and return meta data
+            decorated_image_train_items.append(image_train_item)
+        return decorated_image_train_items

     @staticmethod
     def __bucketize_images(prepared_train_data, batch_size=1, debug_level=0):
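The heart of __prescan_images is the target_wh lookup: the image's width/height ratio is compared to every ASPECTS entry and the nearest one wins; pixel data is not kept, only metadata goes into the ImageTrainItem. A standalone illustration using only the ASPECTS entries visible in this diff (the full table in the file contains more intermediate ratios):

# Sketch of the target_wh selection, with a trimmed ASPECTS table.
ASPECTS_SAMPLE = [[512, 512], [576, 448], [448, 576], [960, 256], [256, 960]]

def pick_target_wh(width, height, aspects=ASPECTS_SAMPLE):
    image_aspect = width / height
    # same expression as the diff: minimize |bucket_aspect - image_aspect|
    return min(aspects, key=lambda wh: abs(wh[0] / wh[1] - image_aspect))

print(pick_target_wh(1200, 800))  # aspect 1.50 -> [576, 448] (ratio 1.29 is closest in the sample)
print(pick_target_wh(512, 680))   # aspect 0.75 -> [448, 576]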
@@ -77,18 +74,23 @@ class DataLoaderMultiAspect():
         buckets = {}
         for image_caption_pair in prepared_train_data:
-            image = image_caption_pair[0]
+            image = image_caption_pair.image
             width, height = image.size

             if (width, height) not in buckets:
                 buckets[(width, height)] = []
-            buckets[(width, height)].append(image_caption_pair)
+            buckets[(width, height)].append(image_caption_pair) # [image, identifier, target_aspect, closest_aspect_wh[w,h], pathname]

-        for bucket in buckets:
-            truncated_count = len(buckets[bucket]) % batch_size
-            current_bucket_size = len(buckets[bucket])
-            buckets[bucket] = buckets[bucket][:current_bucket_size - truncated_count]
-            print(f" ** Bucket {bucket} with {current_bucket_size} will truncate {truncated_count} images due to batch size {batch_size}") if debug_level > 0 else None
+        print(f" ** Number of buckets: {len(buckets)}")
+
+        if len(buckets) > 1: # don't bother truncating if everything is the same aspect ratio
+            for bucket in buckets:
+                truncate_count = len(buckets[bucket]) % batch_size
+                current_bucket_size = len(buckets[bucket])
+                buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
+                print(f" ** Bucket {bucket} with {current_bucket_size} will drop {truncate_count} images due to batch size {batch_size}") if debug_level > 0 else None

         # flatten the buckets
         image_caption_pairs = []
         for bucket in buckets:
             image_caption_pairs.extend(buckets[bucket])
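The new truncation rule only fires when more than one bucket exists, and then drops len(bucket) % batch_size images from the tail of each bucket so every bucket holds a whole number of batches. A quick arithmetic sketch with made-up bucket sizes:

# Sketch of the per-bucket truncation in __bucketize_images.
batch_size = 4
bucket_sizes = {(512, 512): 13, (576, 448): 8, (256, 960): 3}

for wh, count in bucket_sizes.items():
    dropped = count % batch_size  # images that cannot fill a complete batch
    print(f"{wh}: keep {count - dropped}, drop {dropped}")
# (512, 512): keep 12, drop 1
# (576, 448): keep 8, drop 0
# (256, 960): keep 0, drop 3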
@@ -96,15 +98,18 @@ class DataLoaderMultiAspect():
         return image_caption_pairs

     @staticmethod
-    def __recurse_images(self, recurse_root):
+    def __recurse_data_root(self, recurse_root):
         i = 0
         for f in os.listdir(recurse_root):
             current = os.path.join(recurse_root, f)
+            # get file ext

             if os.path.isfile(current):
-                i += 1
-                self.image_paths.append(current)
-
-        print(f" ** Found {str(i).rjust(5,' ')} files in", recurse_root) if self.debug_level > 0 else None
+                ext = os.path.splitext(f)[1]
+                if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp']:
+                    i += 1
+                    self.image_paths.append(current)

         sub_dirs = []
@@ -114,4 +119,24 @@ class DataLoaderMultiAspect():
                 sub_dirs.append(current)

         for dir in sub_dirs:
-            self.__recurse_images(self=self, recurse_root=dir)
+            self.__recurse_data_root(self=self, recurse_root=dir)
+
+    # @staticmethod
+    # def hydrate_image(self, image_path, target_aspect, closest_aspect_wh):
+    #     image = Image.open(example[4]) # 5 is the path
+    #     print(image)
+    #     width, height = image.size
+    #     image_aspect = width / height
+    #     target_aspect = width / height
+
+    #     if example[3][0] == example[3][1]:
+    #         pass
+    #     if target_aspect < image_aspect:
+    #         crop_width = (width - (width * example[3][0] / example[3][1])) / 2
+    #         image = image.crop((crop_width, 0, width - crop_width, height))
+    #     else:
+    #         crop_height = (height - (width * example[3][1] / example[3][0])) / 2
+    #         image = image.crop((0, crop_height, width, height - crop_height))
+
+    #     example[0] = image.resize((example[3][0], example[3][1]), Image.BICUBIC)
+    #     return example
@@ -0,0 +1,2 @@
+# stop lightning's repeated instantiation of batch train/val/test classes causing multiple sweeps of the same data off disk
+shared_dataloader = None
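dl_singleton is nothing more than a module-level slot: the first dataset class to initialize fills it, and subsequent train/val/test instantiations reuse the already-scanned loader instead of sweeping the same directory tree again. A minimal sketch of the check both dataset classes below perform (the helper name is illustrative, not part of the commit):

# Sketch of the shared-dataloader pattern used by EveryDreamBatch and EDValidateBatch.
import ldm.data.dl_singleton as dls
from ldm.data.data_loader import DataLoaderMultiAspect as dlma

def get_shared_items(data_root, batch_size=1, flip_p=0.0, debug_level=0):
    if not dls.shared_dataloader:
        # only the first caller pays the disk-scan cost
        dls.shared_dataloader = dlma(data_root=data_root, debug_level=debug_level,
                                     batch_size=batch_size, flip_p=flip_p)
    return dls.shared_dataloader.get_all_images()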
@@ -1,30 +1,36 @@
 import numpy as np
 from torch.utils.data import Dataset
 from torchvision import transforms
 from pathlib import Path
 from ldm.data.data_loader import DataLoaderMultiAspect as dlma
 import math

+import ldm.data.dl_singleton as dls

 class EDValidateBatch(Dataset):
     def __init__(self,
                  data_root,
                  flip_p=0.0,
                  repeats=1,
                  debug_level=0,
-                 batch_size=1
+                 batch_size=1,
+                 set='val',
                  ):
-        print(f"EDValidateBatch batch size: {self.batch_size}") if debug_level > 0 else None

         self.data_root = data_root
         self.batch_size = batch_size

-        self.image_caption_pairs = dlma(data_root=data_root, debug_level=debug_level, batch_size=self.batch_size).get_all_images()
-
-        # most_subscribed_aspect_ratio = self.most_subscribed_aspect_ratio()
-        # self.image_caption_pairs = [image_caption_pair for image_caption_pair in self.image_caption_pairs if image_caption_pair[0].size == aspect_ratio]
+        if not dls.shared_dataloader:
+            print("Creating new dataloader singleton")
+            dls.shared_dataloader = dlma(data_root=data_root, debug_level=debug_level, batch_size=self.batch_size)
+
+        self.image_caption_pairs = dls.shared_dataloader.get_all_images()

         self.num_images = len(self.image_caption_pairs)

-        self._length = max(math.trunc(self.num_images * repeats), 1)
+        self._length = max(math.trunc(self.num_images * repeats), batch_size) - self.num_images % self.batch_size
+
+        print()
+        print(f" ** Validation Set: {set}, num_images: {self.num_images}, length: {self._length}, repeats: {repeats}, batch_size: {self.batch_size}, ")
+        print(f" ** Validation steps: {self._length / batch_size:.0f}")
+        print()

         self.flip = transforms.RandomHorizontalFlip(p=flip_p)
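The new _length expression sizes the validation set to whole batches: take max(trunc(num_images * repeats), batch_size), then subtract the leftover num_images % batch_size. A worked example with assumed counts:

# Sketch of the EDValidateBatch length calculation (numbers are assumptions).
import math

num_images, repeats, batch_size = 10, 1, 4
length = max(math.trunc(num_images * repeats), batch_size) - num_images % batch_size
print(length)                        # max(10, 4) - 2 = 8 samples
print(f"{length / batch_size:.0f}")  # 2 validation steps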
@@ -34,8 +40,6 @@ class EDValidateBatch(Dataset):
     def __getitem__(self, i):
         idx = i % len(self.image_caption_pairs)
         example = self.get_image(self.image_caption_pairs[idx])
-        #print caption and image size
-        print(f"Caption: {example['image'].shape} {example['caption']}")
         return example

     def get_image(self, image_caption_pair):
@@ -54,22 +58,3 @@ class EDValidateBatch(Dataset):
         example["caption"] = identifier

         return example
-
-    def filter_aspect_ratio(self, aspect_ratio):
-        # filter the images to only include the given aspect ratio
-        self.image_caption_pairs = [image_caption_pair for image_caption_pair in self.image_caption_pairs if image_caption_pair[0].size == aspect_ratio]
-        self.num_images = len(self.image_caption_pairs)
-        self._length = max(math.trunc(self.num_images * self.repeats), 2)
-
-    def most_subscribed_aspect_ratio(self):
-        # find the image size with the highest number of images
-        aspect_ratios = {}
-        for image_caption_pair in self.image_caption_pairs:
-            image = image_caption_pair[0]
-            aspect_ratio = image.size
-            if aspect_ratio in aspect_ratios:
-                aspect_ratios[aspect_ratio] += 1
-            else:
-                aspect_ratios[aspect_ratio] = 1
-
-        return max(aspect_ratios, key=aspect_ratios.get)
@@ -1,9 +1,11 @@
 import numpy as np
 from torch.utils.data import Dataset
 from torchvision import transforms
 from pathlib import Path
 from ldm.data.data_loader import DataLoaderMultiAspect as dlma
 import math
+import ldm.data.dl_singleton as dls
+from PIL import Image
+import gc

 class EveryDreamBatch(Dataset):
     def __init__(self,
@@ -11,43 +13,59 @@ class EveryDreamBatch(Dataset):
                  repeats=10,
                  flip_p=0.0,
                  debug_level=0,
-                 batch_size=1
+                 batch_size=1,
+                 set='train'
                  ):
-        print(f"EveryDreamBatch batch size: {batch_size}")
+        #print(f"EveryDreamBatch batch size: {batch_size}")
         self.data_root = data_root
         self.batch_size = batch_size
+        self.flip_p = flip_p

-        self.image_caption_pairs = dlma(data_root=data_root, debug_level=debug_level, batch_size=self.batch_size).get_all_images()
-
-        self.num_images = len(self.image_caption_pairs)
+        if not dls.shared_dataloader:
+            print(" * Creating new dataloader singleton")
+            dls.shared_dataloader = dlma(data_root=data_root, debug_level=debug_level, batch_size=self.batch_size, flip_p=self.flip_p)
+
+        self.image_train_items = dls.shared_dataloader.get_all_images()
+        #print(f" * EDB Example {self.image_train_items[0]}")
+
+        self.num_images = len(self.image_train_items)

         self._length = math.trunc(self.num_images * repeats)

-        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
-
-        print(f" * Training steps: {self._length / batch_size}")
+        print()
+        print(f" ** Trainer Set: {set}, steps: {self._length / batch_size:.0f}, num_images: {self.num_images}, batch_size: {self.batch_size}, length w/repeats: {self._length}")
+        print()

     def __len__(self):
         return self._length

     def __getitem__(self, i):
         idx = i % self.num_images
-        example = self.get_image(self.image_caption_pairs[idx])
+        #example = self.get_image(self.image_caption_pairs[idx])
+        image_train_item = self.image_train_items[idx]
+        #print(f" *** example {example}")
+
+        hydrated_image_train_item = image_train_item.hydrate()
+
+        example = self.get_image_for_trainer(hydrated_image_train_item)
         return example

-    def get_image(self, image_caption_pair):
+    def unload_images_over(self, limit):
+        print(f" ********** Unloading images over limit {limit}")
+        i = 0
+        while i < len(self.image_train_items):
+            print(self.image_train_items[i])
+            if i > limit:
+                self.image_train_items[i][0] = Image.new(mode='RGB', size=(1, 1))
+            i += 1
+        gc.collect()
+
+    def get_image_for_trainer(self, image_train_item):
         example = {}

-        image = image_caption_pair[0]
+        image_train_tmp = image_train_item.as_formatted()

-        if not image.mode == "RGB":
-            image = image.convert("RGB")
-
-        identifier = image_caption_pair[1]
-
-        image = self.flip(image)
-        image = np.array(image).astype(np.uint8)
-        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
-        example["caption"] = identifier
+        example["image"] = image_train_tmp.image
+        example["caption"] = image_train_tmp.caption

         return example
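__getitem__ now defers all pixel work: the stored ImageTrainItem carries only the path, caption, and target size, and hydrate() resizes, flips, and normalizes just the sample actually requested. A condensed sketch of that access path (the 'image'/'caption' keys come from the diff; the function itself is illustrative):

# Sketch of the lazy per-sample flow in EveryDreamBatch.__getitem__.
def fetch_example(image_train_items, i):
    item = image_train_items[i % len(image_train_items)]
    hydrated = item.hydrate()           # resize, RGB-convert, flip, normalize on demand
    return {"image": hydrated.image,    # float32 array scaled to [-1, 1]
            "caption": hydrated.caption}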
@@ -0,0 +1,31 @@
+
+from PIL import Image
+import numpy as np
+from torchvision import transforms
+
+class ImageTrainItem(): # [image, identifier, target_aspect, closest_aspect_wh[w,h], pathname]
+    def __init__(self, image: Image, caption: str, target_wh: list, pathname: str, flip_p=0.0):
+        self.caption = caption
+        self.target_wh = target_wh
+        #self.target_aspect = target_aspect
+        self.pathname = pathname
+        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
+
+        if image is None:
+            self.image = Image.new(mode='RGB',size=(1,1))
+        else:
+            self.image = image
+        #image_train_item.image = image.resize((image_train_item.closest_aspect_wh[0], image_train_item.closest_aspect_wh[1]), Image.BICUBIC)
+
+    def hydrate(self):
+        self.image = self.image.resize(self.target_wh, Image.BICUBIC)
+
+        if not self.image.mode == "RGB":
+            self.image = self.image.convert("RGB")
+
+        self.image = self.flip(self.image)
+        self.image = np.array(self.image).astype(np.uint8)
+
+        self.image = (self.image / 127.5 - 1.0).astype(np.float32)
+
+        return self
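hydrate() ends by mapping 8-bit pixel values into the [-1, 1] float range the trainer expects: p / 127.5 - 1.0 sends 0 to -1.0 and 255 to 1.0. A tiny standalone check of that mapping:

# Sketch of the pixel normalization performed at the end of ImageTrainItem.hydrate().
import numpy as np

pixels = np.array([0, 127, 255], dtype=np.uint8)
normalized = (pixels / 127.5 - 1.0).astype(np.float32)
print(normalized)  # [-1.0, ~-0.004, 1.0]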
@@ -0,0 +1,53 @@
+# script to test data loader by itself
+# run from training root, edit the data_root manually
+# python ldm/data/test_dl.py
+import every_dream
+import time
+
+s = time.perf_counter()
+
+data_root = "r:/everydream-trainer/training_samples/ff7r"
+
+batch_size = 1
+every_dream_batch = every_dream.EveryDreamBatch(data_root=data_root, flip_p=0.0, debug_level=0, batch_size=batch_size, repeats=1)
+
+print(f" *TEST* batch type: {type(every_dream_batch)}")
+i = 0
+is_next = True
+curr_batch = []
+
+while is_next and i < 30 and i < len(every_dream_batch):
+    try:
+        example = every_dream_batch[i]
+        if example is not None:
+            #print(f"example type: {type(example)}") # dict
+            #print(f"example keys: {example.keys()}") # dict_keys(['image', 'caption'])
+            #print(f"example image type: {type(example['image'])}") # numpy.ndarray
+            if i%batch_size == 0:
+                curr_batch = example['image'].shape
+            img_in_right_batch = curr_batch == example['image'].shape
+            print(f" *TEST* example image shape: {example['image'].shape} {i%batch_size} {img_in_right_batch}")
+            print(f" *TEST* example caption: {example['caption']}")
+
+            if not img_in_right_batch:
+                raise Exception("Current image in wrong batch")
+            #print(f"example caption: {example['caption']}") # str
+        else:
+            is_next = False
+        i += 1
+    except IndexError:
+        is_next = False
+        print(f"IndexError: {i}")
+        pass
+# for idx, batches in every_dream_batch:
+#     print(f"inner example type: {type(batches)}")
+#     print(type(batches))
+#     print(type(batches[0]))
+#     print(dir(batches))
+    #h, w = batches.image.size
+    #print(f"{idx:05d}-{idx%6:02d}EveryDreamBatch image caption pair: w:{w} h:{h} {batches.caption[1]}")
+print(f" *TEST* test cycles: {i}")
+print(f" *TEST* EveryDreamBatch epoch image length: {len(every_dream_batch)}")
+elapsed = time.perf_counter() - s
+print(f"{__file__} executed in {elapsed:5.2f} seconds.")
@@ -0,0 +1,18 @@
+# script to test data loader by itself
+# run from training root, edit the data_root manually
+# python ldm/data/test_dl.py
+import data_loader
+
+data_root = "r:/everydream-trainer/training_samples/multiaspect"
+
+data_loader = data_loader.DataLoaderMultiAspect(data_root=data_root, repeats=1, seed=555, debug_level=2)
+
+image_caption_pairs = data_loader.get_all_images()
+
+print(f"Loaded {len(image_caption_pairs)} image-caption pairs")
+
+for image_caption_pair in image_caption_pairs:
+    print(image_caption_pair)
+    print(image_caption_pair[1])
+
+print(f"**** Done loading. Loaded {len(image_caption_pairs)} images from data_root: {data_root} ****")
@@ -0,0 +1,47 @@
+# script to test data loader by itself
+# run from training root, edit the data_root manually
+# python ldm/data/test_dl.py
+import ed_validate
+
+data_root = "r:/everydream-trainer/training_samples/multiaspect4"
+
+batch_size = 6
+ed_val_batch = ed_validate.EDValidateBatch(data_root=data_root, flip_p=0.0, debug_level=0, batch_size=batch_size, repeats=1)
+
+print(f"batch type: {type(ed_val_batch)}")
+i = 0
+is_next = True
+curr_batch = []
+while is_next and i < 84:
+    try:
+        example = ed_val_batch[i]
+        if example is not None:
+            #print(f"example type: {type(example)}") # dict
+            #print(f"example keys: {example.keys()}") # dict_keys(['image', 'caption'])
+            #print(f"example image type: {type(example['image'])}") # numpy.ndarray
+            if i%batch_size == 0:
+                curr_batch = example['image'].shape
+            img_in_right_batch = curr_batch == example['image'].shape
+            print(f"example image shape: {example['image'].shape} {i%batch_size} {img_in_right_batch}") # (256, 256, 3)
+
+            if not img_in_right_batch:
+                raise Exception("Current image in wrong batch")
+            #print(f"example caption: {example['caption']}") # str
+        else:
+            is_next = False
+        i += 1
+    except IndexError:
+        is_next = False
+        print(f"IndexError: {i}")
+        pass
+# for idx, batches in every_dream_batch:
+#     print(f"inner example type: {type(batches)}")
+#     print(type(batches))
+#     print(type(batches[0]))
+#     print(dir(batches))
+    #h, w = batches.image.size
+    #print(f"{idx:05d}-{idx%6:02d}EveryDreamBatch image caption pair: w:{w} h:{h} {batches.caption[1]}")
+
+ed_val_batch.image_caption_pairs = [image_caption_pair for image_caption_pair in self.image_caption_pairs if image_caption_pair[0].size == aspect_ratio]
+
+print(f"EveryDreamBatch epoch image length: {len(ed_val_batch)}")