Merge pull request #141 from tjennings/main

image preloading and data loading optimizations.
This commit is contained in:
Victor Hall 2023-04-14 21:25:32 -04:00 committed by GitHub
commit 7ba9d43db5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 14 deletions

View File

@ -14,6 +14,8 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
import os
import torch
from torch.utils.data import Dataset
from data.data_loader import DataLoaderMultiAspect
@ -61,18 +63,18 @@ class EveryDreamBatch(Dataset):
self.image_train_items = []
self.__update_image_train_items(1.0)
self.name = name
num_images = len(self.image_train_items)
logging.info(f" ** Dataset '{name}': {num_images / self.batch_size:.0f} batches, num_images: {num_images}, batch_size: {self.batch_size}")
def shuffle(self, epoch_n: int, max_epochs: int):
    """
    Bump the seed and rebuild the list of training items.

    epoch_n: epoch index; only used when the dataset is rated
    max_epochs: total epoch count; divisor of the rated-dataset fraction
    """
    self.seed += 1

    # Unrated datasets always pass 1.0 through to the item update.
    fraction = 1.0
    if self.rated_dataset:
        # Fraction falls linearly with epoch_n, scaled by the configured
        # rated_dataset_dropout_target.
        consumed = epoch_n * self.rated_dataset_dropout_target
        fraction = (max_epochs - consumed) / max_epochs

    self.__update_image_train_items(fraction)
def __len__(self):
@ -135,18 +137,18 @@ class EveryDreamBatch(Dataset):
if image_train_tmp.cond_dropout is not None:
example["cond_dropout"] = image_train_tmp.cond_dropout
example["runt_size"] = image_train_tmp.runt_size
return example
def __update_image_train_items(self, dropout_fraction: float):
    """
    Replace self.image_train_items with the buckets returned by the data loader.

    dropout_fraction: forwarded unchanged to get_shuffled_image_buckets
    """
    shuffled_buckets = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
    self.image_train_items = shuffled_buckets
def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
    """
    Wrap `dataset` in a torch DataLoader that uses this module's collate_fn.

    dataset: the dataset to iterate
    batch_size: number of items per batch; also the upper bound on worker count
    returns: a DataLoader with shuffle=False and
             num_workers = min(batch_size, available CPU count)
    """
    # os.cpu_count() returns None when the CPU count cannot be determined;
    # fall back to one worker so min() never compares an int with None.
    cpu_count = os.cpu_count() or 1

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=min(batch_size, cpu_count),
        collate_fn=collate_fn
    )
    return dataloader

View File

@ -106,7 +106,7 @@ class ImageCaption:
weights_copy.pop(pos)
tag = tags_copy.pop(pos)
if caption:
caption += ", "
caption += tag
@ -122,7 +122,7 @@ class ImageTrainItem:
"""
image: PIL.Image
identifier: caption,
target_aspect: (width, height),
target_aspect: (width, height),
pathname: path to image file
flip_p: probability of flipping image (0.0 to 1.0)
rating: the relative rating of the images. The rating is measured in comparison to the other images.
@ -144,14 +144,14 @@ class ImageTrainItem:
self.image = image
self.image_size = image.size
self.target_size = None
self.is_undersized = False
self.error = None
self.__compute_target_width_height()
def load_image(self):
image = PIL.Image.open(self.pathname).convert('RGB')
try:
image = PIL.Image.open(self.pathname).convert('RGB')
image = ImageOps.exif_transpose(image)
except SyntaxError as e:
pass
@ -232,15 +232,15 @@ class ImageTrainItem:
# print(self.image.shape)
return self
def __compute_target_width_height(self):
self.target_wh = None
try:
with self.load_image() as image:
with PIL.Image.open(self.pathname) as image:
width, height = image.size
image_aspect = width / height
target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
self.is_undersized = (width * height) < (target_wh[0] * target_wh[1])
self.target_wh = target_wh
except Exception as e: