diff --git a/data/data_loader.py b/data/data_loader.py index 28fcd00..d97e38e 100644 --- a/data/data_loader.py +++ b/data/data_loader.py @@ -15,13 +15,13 @@ limitations under the License. """ import bisect import logging -from functools import reduce +import os.path +from collections import defaultdict import math -import copy import random -from data.image_train_item import ImageTrainItem, ImageCaption -import PIL +from data.image_train_item import ImageTrainItem +import PIL.Image PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default @@ -39,9 +39,9 @@ class DataLoaderMultiAspect(): self.prepared_train_data = image_train_items random.Random(self.seed).shuffle(self.prepared_train_data) self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating()) - self.epoch_size = math.floor(sum([i.multiplier for i in self.prepared_train_data])) - if self.epoch_size != len(self.prepared_train_data): - logging.info(f" * DLMA initialized with {len(image_train_items)} source images. After applying multipliers, each epoch will train on at least {self.epoch_size} images.") + expected_epoch_size = math.floor(sum([i.multiplier for i in self.prepared_train_data])) + if expected_epoch_size != len(self.prepared_train_data): + logging.info(f" * DLMA initialized with {len(image_train_items)} source images. After applying multipliers, each epoch will train on at least {expected_epoch_size} images.") else: logging.info(f" * DLMA initialized with {len(image_train_items)} images.") @@ -50,12 +50,12 @@ class DataLoaderMultiAspect(): self.__update_rating_sums() - def __pick_multiplied_set(self, randomizer): + def __pick_multiplied_set(self, randomizer: random.Random): """ Deals with multiply.txt whole and fractional numbers """ picked_images = [] - fractional_images = [] + fractional_images_per_directory = defaultdict(list[ImageTrainItem]) for iti in self.prepared_train_data: multiplier = iti.multiplier while multiplier >= 1: @@ -63,20 +63,15 @@ class DataLoaderMultiAspect(): multiplier -= 1 # fractional remainders must be dealt with separately if multiplier > 0: - fractional_images.append((iti, multiplier)) + directory = os.path.dirname(iti.pathname) + fractional_images_per_directory[directory].append(iti) - target_epoch_size = self.epoch_size - while len(picked_images) < target_epoch_size and len(fractional_images) > 0: - # cycle through fractional_images, randomly shifting each over to picked_images based on its multiplier - iti, multiplier = fractional_images.pop(0) - if randomizer.uniform(0, 1) < multiplier: - # shift it over to picked_images - picked_images.append(iti) - else: - # put it back and move on to the next - fractional_images.append((iti, multiplier)) - - assert len(picked_images) == target_epoch_size, "Something went wrong while attempting to apply multipliers" + # resolve fractional parts per-directory + for _, fractional_items in fractional_images_per_directory.items(): + randomizer.shuffle(fractional_items) + multiplier = fractional_items[0].multiplier % 1.0 + count_to_take = math.ceil(multiplier * len(fractional_items)) + picked_images.extend(fractional_items[:count_to_take]) return picked_images diff --git a/train.py b/train.py index 32f35ba..c63c471 100644 --- a/train.py +++ b/train.py @@ -550,6 +550,7 @@ def main(args): except Exception as e: traceback.print_exc() logging.error(" * Failed to load checkpoint *") + raise if args.gradient_checkpointing: unet.enable_gradient_checkpointing()